#-------------------------------------------------------------------
#  DbDictProxy.py
#
#  The DbDictProxy class.
#
#  Copyright 2014 Applied Invention, LLC.
#-------------------------------------------------------------------

'''The module containing the DbDictProxy class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from collections.abc import MutableMapping
from ai.axe.util import ReflectionUtil
from typing import AbstractSet
from typing import Any
from typing import Dict
from typing import Generic
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
# The type of a key that is stored in the dict.
DbDictKey = TypeVar('DbDictKey')

#===================================================================
# The type of a value that is stored in the dict.
DbDictValue = TypeVar('DbDictValue')

#===================================================================
# The wrapper class that holds each dict item, and is stored in the DB.
# Each instance holds one key and one value, in attributes whose names
# are given to DbDictProxy.  Since this class is created at runtime, it
# has no static type, so we have to use 'Any'.
# pylint: disable=invalid-name
DbDictItemObject = Any
# pylint: enable=invalid-name

#===================================================================
# The object that owns this dict.  (This dict is a field of the object.)
DbDictOwner = TypeVar('DbDictOwner')

#===================================================================
# Hack to get around bug: https://github.com/python/mypy/issues/5264
#
# DbDictProxy is declared as a subclass of
# 'MutableMappingBase[DbDictKey, DbDictValue]'.
#
#  - For the type checker, MutableMappingBase is MutableMapping itself,
#    so DbDictProxy is seen as a properly parameterized MutableMapping.
#
#  - At runtime, MutableMappingBase is a MutableMappingMaker instance.
#    Its '[...]' ignores the subscript arguments and returns the plain,
#    unsubscripted MutableMapping class, so the real base class is
#    collections.abc.MutableMapping.  (Presumably this exists because
#    collections.abc classes could not be subscripted at runtime on the
#    Python versions this was written for -- confirm before removing.)
if TYPE_CHECKING:

  # pylint: disable=invalid-name
  MutableMappingBase = MutableMapping
  # pylint: enable=invalid-name

else:

  # Runtime-only stand-in: subscripting an instance of this class with
  # any arguments returns collections.abc.MutableMapping.
  class MutableMappingMaker:
    def __getitem__(self, *args):
      return MutableMapping

  # pylint: disable=invalid-name
  MutableMappingBase = MutableMappingMaker()
  # pylint: enable=invalid-name

#===================================================================
class DbDictProxy(MutableMappingBase[DbDictKey, DbDictValue],
                  Generic[DbDictKey, DbDictValue]):
  '''A class that acts like a dict, but is backed by a database object list.

  Each dict entry is stored as one "item object" in the proxied list.
  An item object is an instance of the item class.  The entry's key is
  held in the attribute named by 'keyAttrName', and its value in the
  attribute named by 'valueAttrName'.  A key -> list-index table
  ('keyIndexes') is kept alongside the list so that key lookups do not
  need to scan it.

  When entries are added or deleted through this proxy, each item
  object's 'dbDictPosition' attribute is kept equal to its index in the
  list.  (Existing items are not renumbered when the proxy is created.)

  Differences from a real dict:
   - values() and items() return one-shot generators, not views.
   - copy() returns a plain dict snapshot, not another proxy.
  '''

  #-----------------------------------------------------------------
  def __init__(self,
               itemClass: Type[DbDictItemObject],
               keyAttrName: str,
               valueAttrName: str,
               srcList: List[DbDictItemObject]) -> None:
    '''Creates a new DbDictProxy.

    itemClass -- The class used to create a new item object when a new key
                 is added.  It is called with no arguments.
    keyAttrName -- The name of the item-object attribute that holds the key.
    valueAttrName -- The name of the item-object attribute that holds the
                     value.
    srcList -- The list of item objects to proxy.  It is used and mutated
               in place, not copied.  Its keys must be unique.
    '''

    # The list of objects that this dict is proxying.
    self.srcList: List[DbDictItemObject] = srcList

    # The class to use to create new objects.
    self.itemObjectClass: Type[DbDictItemObject] = itemClass

    # The name of the key attribute on the itemObjectClass.
    self.keyAttrName: str = keyAttrName

    # The name of the value attribute on the itemObjectClass.
    self.valueAttrName: str = valueAttrName

    # The dictionary of [key:object_list_index].
    self.keyIndexes: Dict[DbDictKey, int] = {}

    self.buildKeyIndexes()

  #-----------------------------------------------------------------
  def updateDbDictPosition(self) -> None:
    '''Updates each dbDictPosition to match its current position in the list.
    '''

    for i, item in enumerate(self.srcList):
      item.dbDictPosition = i # type: ignore

  #-----------------------------------------------------------------
  def buildKeyIndexes(self) -> None:
    '''Builds the keyIndexes dict from the underlying collection.

    Raises AssertionError (when assertions are enabled) if two item
    objects have the same key.
    '''

    self.keyIndexes.clear()

    for i, item in enumerate(self.srcList):
      key = getattr(item, self.keyAttrName)

      # If we have a duplicate key, something is wrong.
      assert key not in self.keyIndexes, key

      self.keyIndexes[key] = i

  #-----------------------------------------------------------------
  def createNewObject(self,
                      key: DbDictKey,
                      value: DbDictValue) -> DbDictItemObject:
    '''Creates a new object that holds the specified key and value.

    The object is not added to the list; the caller does that.
    '''

    obj = self.itemObjectClass()
    setattr(obj, self.keyAttrName, key)
    setattr(obj, self.valueAttrName, value)

    return obj

  #-----------------------------------------------------------------
  def __len__(self) -> int:

    return len(self.keyIndexes)

  #-----------------------------------------------------------------
  def __getitem__(self, key: DbDictKey) -> DbDictValue:

    # A missing key raises KeyError from the keyIndexes lookup, as dict does.
    obj = self.srcList[self.keyIndexes[key]]
    return getattr(obj, self.valueAttrName)

  #-----------------------------------------------------------------
  def __setitem__(self, key: DbDictKey, value: DbDictValue) -> None:

    if key in self.keyIndexes:
      # Existing key: update the value on the existing item object.
      obj = self.srcList[self.keyIndexes[key]]
      setattr(obj, self.valueAttrName, value)

    else:
      # New key: append a new item object at the end of the list.
      obj = self.createNewObject(key, value)
      obj.dbDictPosition = len(self.srcList) # type: ignore
      self.srcList.append(obj)
      self.keyIndexes[key] = len(self.srcList) - 1

  #-----------------------------------------------------------------
  def __delitem__(self, key: DbDictKey) -> None:

    if key not in self.keyIndexes:
      raise KeyError(key)

    else:
      del self.srcList[self.keyIndexes[key]]

      # Every item after the deleted one has shifted down by one, so both
      # the index table and the stored positions must be recomputed.
      self.buildKeyIndexes()
      self.updateDbDictPosition()

  #-----------------------------------------------------------------
  def __contains__(self, key: object) -> bool:

    return key in self.keyIndexes

  #-----------------------------------------------------------------
  def __iter__(self) -> Iterator[DbDictKey]:

    return iter(self.keyIndexes.keys())

  #-----------------------------------------------------------------
  def clear(self) -> None:

    del self.srcList[:]
    self.buildKeyIndexes()

  #-----------------------------------------------------------------
  def __eq__(self, other: object) -> bool:
    return dict(self) == other

  #-----------------------------------------------------------------
  def __ne__(self, other: object) -> bool:
    return dict(self) != other

  #-----------------------------------------------------------------
  def __repr__(self) -> str:
    return repr(dict(self.items()))

  #-----------------------------------------------------------------
  def get(self,                                # type: ignore
          key: DbDictKey,
          default: Optional[DbDictValue] = None) -> Optional[DbDictValue]:
    try:
      return self[key]
    except KeyError:
      return default

  #-----------------------------------------------------------------
  def setdefault(self,                         # type: ignore
                 key: DbDictKey,
                 default: Optional[DbDictValue] = None) -> Optional[DbDictValue]:
    if key not in self.keyIndexes:
      self[key] = default                      # type: ignore
      return default

    else:
      return self[key]

  #-----------------------------------------------------------------
  def keys(self) -> AbstractSet[DbDictKey]:
    # A live view of the index table, which is kept in list order.
    return self.keyIndexes.keys()

  #-----------------------------------------------------------------
  def values(self) -> Iterator[DbDictValue]:   # type: ignore

    # Note: a generator in list order, not a dict view.
    for obj in self.srcList:
      value = getattr(obj, self.valueAttrName)
      yield value

  #-----------------------------------------------------------------
  def items(self) -> Iterator[Tuple[DbDictKey, DbDictValue]]:  # type: ignore

    # Note: a generator in list order, not a dict view.
    for obj in self.srcList:
      key = getattr(obj, self.keyAttrName)
      value = getattr(obj, self.valueAttrName)
      yield key, value

  #-----------------------------------------------------------------
  # Sentinel default for pop(), meaning "no default was passed".
  # A unique object() is used rather than a string: pop() tests for it
  # with 'is', and an equal string passed by a caller could be the very
  # same interned object, making pop() raise KeyError instead of
  # returning that default.
  defaultMagicValue: Any = object()
  #-----------------------------------------------------------------
  def pop(self,                                      # type: ignore
          key: DbDictKey,
          default: DbDictValue = defaultMagicValue) -> DbDictValue:

    if key not in self.keyIndexes:
      if default is DbDictProxy.defaultMagicValue:
        raise KeyError(key)
      else:
        return default

    else:
      value: DbDictValue = self[key]
      del self[key]
      return value

  #-----------------------------------------------------------------
  def popitem(self) -> Tuple[DbDictKey, DbDictValue]:

    if not self.keyIndexes:
      raise KeyError('popitem(): dictionary is empty')

    else:
      # Remove the last item (LIFO, like dict).  Removing the last item
      # does not shift any other item, so no index or position rebuild
      # is needed.
      obj = self.srcList[-1]
      key = getattr(obj, self.keyAttrName)
      value = getattr(obj, self.valueAttrName)

      del self.srcList[-1]
      del self.keyIndexes[key]

      return key, value

  #-----------------------------------------------------------------
  def update(self, *args, **kwargs) -> None:

    if len(args) > 1:
      raise TypeError('update expected at most 1 arguments, got %i' % len(args))

    elif len(args) == 1:

      seqOrMap = args[0]
      # Like dict.update: an argument with a keys() method is treated as a
      # mapping and its keys are taken from keys(); anything else is
      # treated as an iterable of (key, value) pairs.
      if hasattr(seqOrMap, 'keys'):
        for key in seqOrMap.keys():
          self[key] = seqOrMap[key]

      else:
        for i, item in enumerate(seqOrMap):
          if len(item) != 2:
            msg = ("dictionary update sequence element #%s has length %s; " +
                   "2 is required")
            msg = msg % (i, len(item))
            raise ValueError(msg)
          else:
            key, value = item
            self[key] = value

    for key in kwargs:
      self[key] = kwargs[key]

  #-----------------------------------------------------------------
  def copy(self) -> Dict[DbDictKey, DbDictValue]:
    # Returns a plain dict snapshot; the copy is not backed by the DB list.
    return dict(self.items())

  #-----------------------------------------------------------------
  def __hash__(self) -> int:
    raise TypeError("%s objects are unhashable" % type(self).__name__)

  #-----------------------------------------------------------------
  def dumpAll(self) -> str:
    '''Returns a debugging dump of the key index table and the item objects.
    '''

    lines = ['Keys: ' + ', '.join(str(x) for x in self.keyIndexes),
             'Objects: ']
    lines.extend(str(obj) for obj in self.srcList)

    return '\n'.join(lines) + '\n'

  #-----------------------------------------------------------------
  # Add the dict docstrings to DbDictProxy's methods.  This class mirrors
  # the dict API, so the docstrings come from dict (not list).
  ReflectionUtil.copyDocstrings(dict, locals())
