# Copyright 2019-2020,2023 VMware, Inc.
# All rights reserved. -- VMware Confidential

"""Utilities for ImageManager
"""

from datetime import datetime
import logging
import os
import re

from . import Constants
from .. import Vib
log = logging.getLogger("ImageManagerUtil")

# Python datetime format that is the closest to, but not exactly matching vAPI's
# definition. See vapi-core/vapi/DateTime.cpp for the full format.
# str2Time() and time2Str() need to be used to convert from/to vAPI's format.
BASE_TASK_TIME_FORMAT = '%Y-%m-%dT%H:%M:%S.%f'

def getOptionalDictList(x):
   """Return [obj.toDict() for obj in x] when x is non-empty, else None.

      A falsy x (None or an empty collection) yields None rather than [].
   """
   if not x:
      return None
   return [item.toDict() for item in x]

def getCommaSepArg(x):
   """Return the items of x sorted and joined with ', '."""
   ordered = sorted(x)
   return ', '.join(ordered)

class Notification(object):
   """A single VAPI notification: an id, a message with arguments and an
      optional resolution message, stamped with the UTC time of creation.
      See com.vmware.esx.settings_daemon.Notifications.
   """
   def __init__(self, notificationId, msgId, msg, resMsgId, resMsg,
                msgArgs=None, resArgs=None):
      self.notificationId = notificationId
      # Main message.
      self.msgId, self.msg = msgId, msg
      self.msgArgs = msgArgs or []
      # Resolution message; toDict() omits it when all parts are empty.
      self.resMsgId, self.resMsg = resMsgId, resMsg
      self.resArgs = resArgs or []
      # Naive datetime holding the current UTC time.
      self.time = datetime.utcnow()

   def toDict(self):
      """Return this notification as a plain dict.

         'resolution' is None unless at least one of resMsgId, resMsg or
         resArgs is truthy; 'time' is rendered with time2Str().
      """
      message = {'id': self.msgId,
                 'default_message': self.msg,
                 'args': self.msgArgs}
      resolution = None
      if any((self.resMsgId, self.resMsg, self.resArgs)):
         # Resolution is optional.
         resolution = {'id': self.resMsgId,
                       'default_message': self.resMsg,
                       'args': self.resArgs}
      return {'id': self.notificationId,
              'message': message,
              'resolution': resolution,
              'time': time2Str(self.time)}

class Notifications(object):
   """Notifications grouped into info, warning and error lists.
      See com.vmware.esx.settings_daemon.Notifications.
   """
   def __init__(self, infoMsgs=None, warnMsgs=None, errMsgs=None):
      # Each category defaults to a fresh empty list.
      self.info = infoMsgs if infoMsgs else []
      self.warnings = warnMsgs if warnMsgs else []
      self.errors = errMsgs if errMsgs else []

   def toDict(self):
      """Return a dict with 'info', 'warnings' and 'errors' keys; each maps to
         the list of the category's toDict() results, or None if it is empty.
      """
      categories = (('info', self.info),
                    ('warnings', self.warnings),
                    ('errors', self.errors))
      return {key: getOptionalDictList(items) for key, items in categories}

def time2Str(timeObj):
   """Render a datetime as a VAPI time string.

      The value is formatted with BASE_TASK_TIME_FORMAT; the last three digits
      of the microsecond field are dropped (keeping milliseconds) and 'Z' is
      appended.
   """
   formatted = timeObj.strftime(BASE_TASK_TIME_FORMAT)
   millisPrecision = formatted[:len(formatted) - 3]
   return '%sZ' % millisPrecision

def str2Time(s):
   """Parse a time string produced by time2Str() into a naive datetime.

      The last character (the 'Z' appended by time2Str()) is removed and the
      millisecond field is padded with '000' so it matches the %f directive of
      BASE_TASK_TIME_FORMAT.
   """
   withoutSuffix = s[:-1]
   padded = withoutSuffix + '000'
   return datetime.strptime(padded, BASE_TASK_TIME_FORMAT)

def getFormattedMessage(msg, args):
   """Format a message for VAPI.

      Messages in Constants are copied from VLCM on VC and use positional
      placeholders that start at {1} rather than {0}. An empty string is
      passed as argument {0} so the same text can be used unchanged.

      :param msg: message template.
      :param args: sequence of values for placeholders {1}, {2}, ...; when it
                   is empty or None the template is returned unformatted.
      :return: the formatted message string.
   """
   if not args:
      return msg
   # Unpacking (instead of list concatenation) accepts any sequence, e.g. a
   # tuple, not only a list.
   return msg.format('', *args)

def getExceptionNotification(ex):
   """Get an error Notification from an exception.

      The exception type name is translated through
      Constants.ESXIMAGE_ERROR_ALIAS and looked up in
      Constants.ESXIMAGE_ERROR_MSG_ARG; names without an entry there map to
      'UnknownError'. If the exception has a non-None 'cause' attribute whose
      mapped name is not 'UnknownError', the cause is reported instead.

      :param ex: the exception to report.
      :return: a Notification whose id and message id are
               Constants.ESXIMAGE_PREFIX + error name, with the message
               formatted from the exception's attributes and no resolution.
      :raises AttributeError: if the exception lacks an attribute listed as a
               message argument for its error name.
   """
   UNKNOWN_ERR = 'UnknownError'

   def getMappedErrorName(ex):
      """Get mapped name of the error.
      """
      # Error name is figured using the alias map and the conversion map.
      # Aliases unify similar errors, and during conversion, UnknownError is
      # assigned for an error that is not explicitly handled.
      exType = type(ex).__name__
      errorAlias = Constants.ESXIMAGE_ERROR_ALIAS.get(exType, exType)
      if errorAlias in Constants.ESXIMAGE_ERROR_MSG_ARG:
         return errorAlias
      return UNKNOWN_ERR

   errorName = getMappedErrorName(ex)

   cause = getattr(ex, 'cause', None)
   if cause is not None:
      # Nested exception, get notification from the actual error if it is not
      # mapped to unknown error.
      causeErrorName = getMappedErrorName(cause)
      if causeErrorName != UNKNOWN_ERR:
         errorName = causeErrorName
         ex = cause

   notifId = Constants.ESXIMAGE_PREFIX + errorName
   msg, argNames = Constants.ESXIMAGE_ERROR_MSG_ARG[errorName]

   # Get arguments for the notification by attributes in the exception
   # object.
   msgArgs = []
   for arg in argNames:
      attr = getattr(ex, arg)
      if isinstance(attr, list):
         # Stringify each element: a list holding non-str values would make
         # join() raise TypeError while reporting the original error.
         msgArgs.append(','.join(str(item) for item in attr))
      else:
         msgArgs.append(str(attr))
   msg = getFormattedMessage(msg, msgArgs)

   # For error reporting, resolution is not used.
   return Notification(notifId, notifId, msg, "", "", msgArgs=msgArgs)

def getNotification(notificationId, msgId, msgArgs=None, resArgs=None,
                    type_=Constants.INFO):
   """Forms a Notification instance for VAPI use.
      **This does not use the Notification(s) classes defined in this module,
        but the one in the settingsd binding.

      :param notificationId: notification id; also the key used to look up an
                             optional resolution text in
                             Constants.RESOLUTION_MSG.
      :param msgId: message id; key into Constants.NOTIFICATION_MSG.
      :param msgArgs: arguments for the message placeholders ({1}, {2}, ...).
      :param resArgs: arguments for the resolution placeholders.
      :param type_: notification type, only passed to the binding when
                    Constants.NOTIFICATION_HAS_TYPE is set.
      :return: a settings_daemon_client Notification.
   """
   # Not available on 6.7.
   from com.vmware.esx.settings_daemon_client \
      import Notification as VapiNotification
   from com.vmware.vapi.std_client \
      import LocalizableMessage as VapiLocalizableMessage

   defMsg = getFormattedMessage(Constants.NOTIFICATION_MSG[msgId], msgArgs)
   msg = VapiLocalizableMessage(id=msgId, default_message=defMsg,
                                args=msgArgs or [])

   # Populate the optional resolution only when there is actually a message.
   resolution = None
   resMsg = getFormattedMessage(
      Constants.RESOLUTION_MSG.get(notificationId, ''), resArgs)
   if resMsg:
      resolution = VapiLocalizableMessage(
         id=msgId + Constants.RESOLUTION_SUFFIX,
         default_message=resMsg,
         args=resArgs or [])

   notifArgs = dict(id=notificationId,
                    time=datetime.utcnow(),
                    message=msg,
                    resolution=resolution)
   if Constants.NOTIFICATION_HAS_TYPE:
      # NOTE(review): presumably older bindings lack the 'type' field, hence
      # the flag — confirm against Constants.
      notifArgs['type'] = type_
   return VapiNotification(**notifArgs)

def parseQpScriptInfo(filePath, vibFilePath):
   """Parse a single quick patch script and returns script info.

      The info block is delimited by lines containing "BEGIN SCRIPT INFO" and
      "END SCRIPT INFO"; the begin marker must appear within the first
      SCRIPT_INFO_LINENUM_MAX lines. Inside the block, a line containing
      "type:" gives the script type and a line containing "timeout:" gives a
      positive integer timeout; the value is the text after the first ':'.

      :param filePath: path of the file to read.
      :param vibFilePath: path of the script inside the VIB, used in error
                          messages.
      :return: dict with "type" and "timeout" keys, or an empty dict when the
               file has no script info.
      :raises ValueError: on an invalid type or timeout value, or when the
               info block lacks the type or timeout entry.
   """
   SCRIPT_INFO_BEGIN = "BEGIN SCRIPT INFO"
   SCRIPT_INFO_END = "END SCRIPT INFO"
   SCRIPT_INFO_LINENUM_MAX = 20

   def checkComplete(info):
      """Raise ValueError unless both "type" and "timeout" were found."""
      if "type" not in info or "timeout" not in info:
         raise ValueError("Script %s does not have type or timeout info"
                          % vibFilePath)

   scriptInfo = {}
   withinInfo = False
   # errors='replace': every file of the quick patch dir may be passed in, so
   # undecodable bytes must not abort parsing; the markers are plain ASCII.
   with open(filePath, 'r', errors='replace') as f:
      for lineNum, line in enumerate(f, 1):
         if not withinInfo:
            if SCRIPT_INFO_BEGIN in line:
               withinInfo = True
            elif lineNum >= SCRIPT_INFO_LINENUM_MAX:
               # Script info should be just below the shebang line, don't
               # go too deep into a library file.
               break
            continue

         if "type:" in line:
            typeStr = line[line.find(':') + 1:].strip()
            if not typeStr or \
                  typeStr not in Vib.QuickPatchScript.SCRIPT_TYPES:
               raise ValueError("Invalid script type value '%s' in %s"
                                % (typeStr, vibFilePath))
            scriptInfo["type"] = typeStr
         elif "timeout:" in line:
            timeout = line[line.find(':') + 1:].strip()
            try:
               timeoutVal = int(timeout)
            except ValueError:
               timeoutVal = 0
            if timeoutVal <= 0:
               raise ValueError(
                     "Invalid script timeout value '%s' in %s"
                     % (timeout, vibFilePath))
            scriptInfo["timeout"] = timeoutVal
         elif SCRIPT_INFO_END in line:
            checkComplete(scriptInfo)
            return scriptInfo

   if scriptInfo:
      # The block was never terminated: do not hand back a partial result.
      checkComplete(scriptInfo)
   return scriptInfo

def getQpScriptInfoAndTags(stagePath, vibName):
   """Get quick patch script info and other quick patch tags: resource pool
      definition file path and security policy directory, from a quick patch
      script payload stage dir.

      :param stagePath: stage directory containing the quick patch script dir.
      :param vibName: name of the VIB; the entries are read from
                      <stagePath>/<Vib.QUICKPATCH_SCRIPT_DIR>/<vibName>.
      :return: tuple (scriptsInfo, rpFile, secPolDir). scriptsInfo maps each
               script's in-VIB path to the dict from parseQpScriptInfo(); the
               other two are in-VIB paths.
      :raises ValueError: if the security policy dir is malformed, more than
               one rp .yml file or secpolicy dir is found, or no script, rp
               file or secpolicy dir is found.
   """
   DOM_REGEX = re.compile(r'^tmp-[\w-]+Dom$')

   qpDir = os.path.join(stagePath, Vib.QUICKPATCH_SCRIPT_DIR, vibName)
   scriptsInfo = {}
   rpFile = secPolDir = None
   # Sort potential scripts by name. Materializing the scandir iterator also
   # exhausts it, which releases its directory handle.
   dirEntries = sorted(os.scandir(qpDir), key=lambda x: x.name)
   for entry in dirEntries:
      # Looking for secpolicy dir, scripts and resource pool yaml file in the
      # quick patch script dir.
      pathInVib = os.path.join(Vib.QUICKPATCH_SCRIPT_DIR, vibName, entry.name)

      if entry.is_dir() and entry.name == 'secpolicy':
         # Security policy dir. Verify it only contains the tmp-*Dom domain
         # files in the 'domains' sub-directory.
         domDir = os.path.join(entry.path, 'domains')

         if not os.path.isdir(domDir):
            raise ValueError("The security policy dir {!r} does not have the "
                             "'domains' sub-dir".format(entry.path))

         # A distinct loop name keeps the outer 'entry' intact; sorting gives
         # a deterministic order of reported errors.
         for domEntry in sorted(os.scandir(domDir), key=lambda x: x.name):
            if not domEntry.is_file():
               raise ValueError('Not a security domain file: {!r}'.format(
                                domEntry.path))
            if not DOM_REGEX.match(domEntry.name):
               raise ValueError('Security domain name {!r} does not match '
                  'regex {!r}'.format(domEntry.name, DOM_REGEX.pattern))
         if secPolDir is not None:
            raise ValueError('Only one security policy dir is allowed, found '
               '{} and {}'.format(secPolDir, pathInVib))
         secPolDir = pathInVib
         # The secpolicy dir is fully handled; it is not a script or rp file.
         continue

      if not entry.is_file():
         continue

      if entry.path.endswith('.yml'):
         if rpFile is not None:
            raise ValueError("Only one rp .yml file is allowed, found %s and %s"
                             % (rpFile, pathInVib))
         rpFile = pathInVib
      scriptInfo = parseQpScriptInfo(entry.path, pathInVib)
      if scriptInfo:
         scriptsInfo[pathInVib] = scriptInfo

   if not scriptsInfo:
      raise ValueError("No Live Patch script exists in %s" % stagePath)
   if not rpFile:
      raise ValueError("No resource pool definition file is found in %s"
                       % stagePath)
   if not secPolDir:
      raise ValueError("No security policy directory is found in %s"
                       % stagePath)

   return scriptsInfo, rpFile, secPolDir
