#!/usr/bin/python
########################################################################
# Copyright (c) 2010-2023 VMware, Inc. All rights reserved.
# VMware Confidential
########################################################################

"""Wrapper class for calculating a checksum while reading/writing.
"""
import hashlib

from .Misc import seekable

class HashError(Exception):
   """Raised by HashedStream.read() when the digest calculated at EOF does
      not match the expected digest.
   """

class HashedStream(object):
   """File-like wrapper that hashes all data read from or written to a
      stream.

      Every chunk passing through read() or write() is fed into a hashlib
      object, so the digest of the data transferred so far is available via
      the digest and hexdigest properties. When an expected digest is given,
      read() verifies it once the underlying stream returns no more data.

      Parameters:
         * stream   - The wrapped file-like object. Only the methods that
                      are actually called are required: read() and/or
                      write(), close(), and seek() when reset() is used.
         * expected - Optional hex digest string the data is expected to
                      hash to; compared case-insensitively. None (the
                      default) disables verification.
         * method   - Hash algorithm name, passed to hashlib.new().
   """
   def __init__(self, stream, expected=None, method="sha256"):
      self.stream = stream
      self.method = method
      self.hashobj = self._new_hashobj()
      # Store the expected digest lower-cased so the comparison in read()
      # is case-insensitive.
      self.expected = None
      if expected is not None:
         self.expected = expected.lower()

   def _new_hashobj(self):
      """Return a fresh, empty hashlib object for self.method."""
      if self.method == "md5":
         try:
            # In Python version 3.9, new keyword 'usedforsecurity' is
            # introduced. If this argument is False, then it allows the
            # use of insecure and blocked hashing algorithms.
            return hashlib.new(self.method, usedforsecurity=False)
         except TypeError:
            # The keyword was rejected; fall back to the plain call.
            return hashlib.new(self.method)
      return hashlib.new(self.method)

   def reset(self):
      """Seek the stream back to offset 0 and restart the hash calculation.

         Raises:
            * IOError - If seekable() reports the stream as not seekable.
      """
      if not seekable(self.stream):
         raise IOError("failed to reset HashStream (file is not seekable)")
      self.stream.seek(0)
      self.hashobj = self._new_hashobj()

   @property
   def digest(self):
      """Digest of all data processed so far, as returned by the hash
         object's digest().
      """
      return self.hashobj.digest()

   @property
   def hexdigest(self):
      """Digest of all data processed so far, as a hex string."""
      return self.hashobj.hexdigest()

   def read(self, size):
      """Read from the stream, adding the returned data to the hash.

         Parameters:
            * size - Passed through unchanged to the stream's read().
         Returns: The data returned by the stream's read().
         Raises:
            * HashError - If an expected digest was given and, when the
                          stream returns no data for a non-zero size
                          request, the calculated digest differs from it.
      """
      data = self.stream.read(size)
      self.hashobj.update(data)
      # An empty result only signals EOF when data was actually requested.
      # A zero-length read returns nothing without the stream being
      # exhausted, so verifying there would compare a partial digest.
      at_eof = not data and size != 0
      if at_eof and self.expected is not None:
         result = self.hashobj.hexdigest().lower()
         if result != self.expected:
            msg = ("Calculated digest at EOF does not match expected result: "
                   "%s calculated, %s expected." % (result, self.expected))
            raise HashError(msg)
      return data

   def write(self, data):
      """Add data to the hash, then write it to the stream."""
      self.hashobj.update(data)
      self.stream.write(data)

   def close(self):
      """Close the underlying stream."""
      self.stream.close()
