########################################################################
# Copyright (C) 2010,2017,2018,2021,2022 VMWare, Inc.                       #
# All Rights Reserved                                                  #
########################################################################

"This module provides classes and functions for working with XML data."

import datetime
import os
import sys
import re
import glob

# Work out the name of the vibtools install directory. When a 'buildNumber'
# module can be imported, the directory name carries its BUILDNUMBER as a
# suffix; otherwise the plain, unversioned name is used. GetSchemaDir() uses
# this value to build the '/opt/vmware/<vibtoolsDir>/schemas' candidate path.
try:
   import buildNumber
   vibtoolsDir = "vibtools-%s" % buildNumber.BUILDNUMBER
except ImportError:
   vibtoolsDir = "vibtools"

# Module-level handle to the ElementTree implementation in use. It has to
# exist (as None) before FindElementTree() is first called, because that
# function consults it to short-circuit repeated look-ups.
etree = None

def FindElementTree():
   """Locate a usable implementation of the ElementTree API.
      Candidate modules are tried in a fixed order of preference and the
      first one that imports successfully wins. Once the module-level
      'etree' variable has been set, it is returned directly without
      importing anything again.
         Returns: A module instance.
         Raises: ImportError, if none of the candidates can be imported.
         Example usage:
            >>> import util
            >>> etree = util.FindElementTree()
            >>> tree = etree.parse("my.xml")
   """
   if etree is not None:
      return etree

   candidates = ("lxml.etree",
                 "xml.etree.cElementTree",
                 "xml.etree.ElementTree",
                 "cElementTree",
                 "elementtree.ElementTree")
   for dotted in candidates:
      try:
         found = __import__(dotted)
      except ImportError:
         continue
      # __import__("a.b.c") imports the whole chain but hands back the
      # top-level package "a", so walk down the remaining components to
      # reach the sub-module that was actually requested.
      for component in dotted.split(".")[1:]:
         found = getattr(found, component)
      return found

   raise ImportError("Failed to find ElementTree implementation.")

etree = FindElementTree()

def StripNamespace(tagname):
   """Strip an ElementTree-style namespace (encapsulated in curly braces at the
      beginning of the tag name), from the tag name. It is also safe to call
      this function on tag names which do not have a namespace, including the
      empty string.
         Parameters:
            * tagname - A string giving the tag name.
         Returns: The tag name with any leading "{namespace}" prefix removed.
                  If the opening brace is never closed, the tag name is
                  returned unchanged.
   """
   # Slice rather than index: tagname[0] would raise IndexError for "".
   if tagname[:1] == "{":
      # find() yields -1 when there is no "}", which makes the slice start
      # at 0 and so returns the whole string.
      return tagname[tagname.find("}") + 1:]
   return tagname

def IndentElementTree(elem, indent=0):
   """Recursively rewrite the whitespace of Element elem and all of its
      descendants so that the serialized tree is indented by two spaces per
      nesting level.
         Parameters:
            * elem   - An etree.Element object. Note that it will be modified
                       in-place: 'tail' is overwritten on every element, and
                       'text' is overwritten on every element that has
                       children.
            * indent - The nesting depth of elem. Defaults to 0.
   """
   own_padding = "\n" + "  " * indent
   elem.tail = own_padding
   for child in elem:
      IndentElementTree(child, indent + 1)

   if len(elem) == 0:
      return

   # The first child starts one level deeper than elem itself ...
   elem.text = "\n" + "  " * (indent + 1)
   # ... and after the last child we come back to elem's own level so that
   # its closing tag lines up with its opening tag.
   elem[-1].tail = own_padding

def ParseXsdBoolean(text):
   """Convert an xsd:boolean string into a Python bool.
      Surrounding whitespace is ignored and the comparison is not
      case-sensitive.
         Parameters:
            * text - The xsd:boolean string.
         Returns: True for "1" or "true"; False for "0" or "false".
         Raises:
            * ValueError - If the input string is not in the expected format.
   """
   literals = {"0": False, "false": False,
               "1": True, "true": True}
   normalized = text.strip().lower()
   if normalized not in literals:
      # Report the caller's original text, not the normalized form.
      raise ValueError("'%s' is not a valid boolean." % text)
   return literals[normalized]



# This class provides a way to standardize all datetime objects to a single
# time zone. In ParseXsdDateTime, we create a temporary tzinfo object based on
# the offset that we parse from the time stamp, but then we convert it to UTC.
# The idea is to have all datetime objects reference a single UtcInfo object,
# rather than have each one reference its own instance of tzinfo.

class UtcInfo(datetime.tzinfo):
   """A tzinfo implementation representing UTC (zero offset, no DST)."""
   ZERO = datetime.timedelta(0)

   def utcoffset(self, dt):
      return self.ZERO

   def dst(self, dt):
      return self.ZERO

   def tzname(self, dt):
      return 'UTC'

class _FixedOffsetInfo(datetime.tzinfo):
   """Temporary tzinfo describing a constant offset from UTC. Instances are
      only used inside ParseXsdDateTime to convert a parsed time stamp to UTC.
   """
   def __init__(self, offset):
      self._offset = offset

   def utcoffset(self, dt):
      return self._offset

   def dst(self, dt):
      # A fixed offset has no daylight-saving component.
      return UtcInfo.ZERO

   def tzname(self, dt):
      return ""

# Note that we do this outside of the function so that we don't have to re-
# compile the regex every time. Also note that we only use this to parse the
# string--the datetime.datetime constructor may yet throw an exception because
# the parsed values are crazy. E.g., minute = 99, month = 40, tz = +25:00.
_xsd_dt_re = re.compile(r"(?P<year>\d{4})-(?P<month>\d\d)-(?P<day>\d\d)."
                        r"(?P<hour>\d\d):(?P<minute>\d\d):(?P<second>\d\d)"
                        r"(\.(?P<ms>\d{1,6}))?"
                        r"(Z|(?P<tzdir>[+-])(?P<tzhr>\d\d):(?P<tzmin>\d\d))?$")
_utctzinfo = UtcInfo()
def ParseXsdDateTime(text):
   """Parse an xsd:dateTime string to a Python datetime object.
      A time stamp carrying an explicit "+hh:mm" / "-hh:mm" offset is
      converted to UTC. A time stamp ending in "Z", or carrying no time zone
      designator at all, is taken to be UTC already.
         Parameters:
            * text - The xsd:dateTime string. Surrounding whitespace is
                     ignored.
         Returns: A datetime.datetime object as UTC.
         Raises:
            * ValueError - If the input string is not in the expected format,
                           or if the parsed fields do not form a valid
                           date/time or offset.
   """
   m = _xsd_dt_re.match(text.strip())
   if m is None:
      raise ValueError("'%s' is not a valid date/time string." % text)

   dtargs = dict()
   for key in ("year", "month", "day", "hour", "minute", "second"):
      value = m.group(key)
      try:
         dtargs[key] = int(value)
      except Exception:
         raise ValueError("'%s' is an invalid value for %s." % (value, key))

   # The "ms" group holds 1-6 fractional-second digits; right-pad with zeros
   # so that e.g. ".5" becomes 500000 microseconds.
   ms = m.group("ms")
   if ms is not None:
      try:
         dtargs["microsecond"] = int(ms.ljust(6, "0"))
      except Exception:
         raise ValueError("'%s' is an invalid value for microseconds." % ms)

   tzdir = m.group("tzdir")
   if tzdir is None:
      dtargs["tzinfo"] = _utctzinfo
   else:
      tzhr = m.group("tzhr")
      tzmin = m.group("tzmin")
      try:
         offsetminutes = int(tzhr) * 60 + int(tzmin)
      except Exception:
         raise ValueError("'%s:%s' is an invalid offset." %
                          (tzdir + tzhr, tzmin))
      # The sign applies to the whole offset, not just the hour field:
      # "-05:30" is -330 minutes, and "-00:30" is -30 minutes.
      if tzdir == "-":
         offsetminutes = -offsetminutes
      tzoffset = datetime.timedelta(minutes=offsetminutes)
      dtargs["tzinfo"] = _FixedOffsetInfo(tzoffset)

   try:
      dt = datetime.datetime(**dtargs)
      if tzdir is not None:
         # Normalize to the shared UTC tzinfo object.
         dt = dt.astimezone(_utctzinfo)
   except Exception as e:
      msg = "'%s' is not a valid date/time: %s." % (text, e)
      raise ValueError(msg)
   return dt

class ValidationError(Exception):
   """Raised by GetSchemaObj when a schema file cannot be loaded, when no
      validator class is available for it, or when a validator cannot be
      built from it.
   """

def GetSchemaDir():
   """Return the default schema directory.
      The following well-known locations are probed in order, and the first
      one that exists is returned:
         * ESXi host: '/usr/share/esximage/schemas'
         * vib-suite: '/opt/vmware/<vibtoolsDir>/schemas'
         * vCenter:   '/etc/vmware-imagebuilder/schemas'
      If none of them exists, the absolute path of the directory three levels
      above the directory containing this module is returned.
      NOTE(review): earlier documentation described the fallback as
      '../schemas' relative to this module, which is not what the code
      computes -- confirm which one is intended.
   """
   candidates = ("/usr/share/esximage/schemas",
                 "/opt/vmware/%s/schemas" % vibtoolsDir,
                 "/etc/vmware-imagebuilder/schemas")
   # The generator keeps the probing lazy: stop at the first existing path.
   existing = next((c for c in candidates if os.path.exists(c)), None)
   if existing is not None:
      return existing

   here = os.path.dirname(__file__)
   fallback = os.path.join(here, os.pardir, os.pardir, os.pardir)
   return os.path.abspath(fallback)


# Schema validators are cached here so that a schema only has to be loaded
# and compiled once, no matter how many XML documents are validated against
# it. Keys are the schema file names as passed to GetSchemaObj; values are the
# schema instances (which support a 'validate' method).
_schema_cache = {}

def GetSchemaObj(schemafile):
   """Returns an XML schema validator object, with a 'validate' method, which
      can be passed XML files to validate.  The lxml module must be available.
      Validators are cached per schema file name, so each schema is parsed at
      most once.
      The validator class is chosen from the (namespace-stripped) root tag of
      the schema document: 'grammar' selects RelaxNG, 'schema' selects
      XMLSchema, and anything else is assumed to be a DTD.
      Parameters:
         * schemafile - A path name to an .rng, .xsd, or .dtd format XML schema
      Returns:
         An instance of etree.RelaxNG, etree.XMLSchema, or etree.DTD
      Raises:
         ValidationError - if one of the classes above cannot be found, or the
                           XML schema file cannot be parsed
   """
   try:
      return _schema_cache[schemafile]
   except KeyError:
      pass

   try:
      tree = etree.parse(schemafile)
   except Exception as e:
      raise ValidationError("Error loading schema XML data: %s." % e)

   # Root tag -> (name of the validator class on etree, error to report when
   # the ElementTree implementation in use does not provide that class).
   known = {
      "grammar": ("RelaxNG", "No validation class for RelaxNG schema."),
      "schema": ("XMLSchema", "No validation class for XMLSchema schema."),
   }
   fallback = ("DTD", "No validation class for DTD schema.")  # assume DTD?

   root_tag = StripNamespace(tree.getroot().tag)
   class_name, missing_msg = known.get(root_tag, fallback)
   if not hasattr(etree, class_name):
      raise ValidationError(missing_msg)
   validator_class = getattr(etree, class_name)

   try:
      validator = validator_class(tree)
   except Exception as e:
      raise ValidationError(
         "Error parsing schema information from XML data: %s." % e)

   _schema_cache[schemafile] = validator
   return validator

def ValidateXml(xml, schemaobj):
   """Wrapper for XML validation that produces nicer errors.
      The XML is serialized with pretty-printing, re-parsed, and the re-parsed
      document is what gets validated. The serialized lines are handed to the
      result object so that it can quote the offending line in its messages.
      Parameters:
         * xml       - ElementTree instance to validate;
         * schemaobj - schema object, maybe from GetSchemaObj()
      Returns:
         An XmlUtils.ValidationResult instance
   """
   # Showing the offending line is MUCH friendlier than a bare line number
   # plus "Error datatype validation", but it means we have to keep the
   # serialized text around. Also, lxml will not report line numbers unless
   # we go through the tostring/pretty_print/fromstring round trip.
   encoding = 'unicode' if sys.version_info[0] >= 3 else 'us-ascii'
   serialized = etree.tostring(xml, encoding=encoding, pretty_print=True)
   lines = serialized.splitlines()
   schemaobj.validate(ParseXMLFromString(serialized))
   return ValidationResult(schemaobj, lines)

def ParseXMLFromString(text):
   """Wrapper to parse XML from a string.
      The plain fromstring of 'lxml' is vulnerable to XXE and billion-laughs
      attacks, so when lxml is the ElementTree implementation in use, the text
      is parsed with a parser that has entity resolution switched off. Other
      implementations do not take that parser option and are called without
      it.
      Parameters:
         * text        - The text that needs to be parsed.
      Returns:
         The root node or result return by a parser target.
   """
   if etree.__name__ != "lxml.etree":
      # resolve_entities is an lxml-only parser option.
      return etree.fromstring(text)

   hardened_parser = etree.XMLParser(resolve_entities=False)
   return etree.fromstring(text, hardened_parser)


class ValidationResult(object):
   """Represents results of XML Schema Validation.
      An instance of this class evaluates to True if there are no errors.

      Attributes:
         * errors      - A list of error log objects, each of which supports
                         str() -- but the string is not very friendly or readable.
         * errorstrings - A list of friendly, formatted error strings
   """
   # Some errors are not worth reporting on
   SKIP_THESE_ERRORS = ['RELAXNG_ERR_INTEREXTRA']

   def __init__(self, schema_obj, xmllines=None):
      """Collect the errors recorded in a schema object's error log.
         Parameters:
            * schema_obj - A schema object whose 'error_log' attribute holds
                           the errors from the most recent validation.
            * xmllines   - Optional list of the lines of the validated XML.
                           When given, datatype errors quote the offending
                           line (truncated to 80 characters) instead of
                           giving a line/column position.
      """
      self.errors = []
      self.errorstrings = []
      for err in schema_obj.error_log:
         if err.type_name in self.SKIP_THESE_ERRORS:
            continue
         self.errors.append(err)

         # Report the offending data itself if at all possible. The range
         # check comes first so that a line number outside the supplied text
         # (including 0) falls back to the positional message rather than
         # raising IndexError or quoting the wrong line.
         if (xmllines and 1 <= err.line <= len(xmllines) and
             err.type == etree.RelaxNGErrorTypes.RELAXNG_ERR_DATATYPE):
            msg = "Invalid: " + xmllines[err.line - 1][:80]
         else:
            msg = '(line %d: col %d) %s' % (err.line, err.column, err.message)
         self.errorstrings.append(msg)

   def __nonzero__(self):
      # True means "valid": no reportable errors were collected.
      return not self.errors
   __bool__ = __nonzero__
