#-------------------------------------------------------------------
#  DbList.py
#
#  The DbList class.
#
#  Copyright 2014 Applied Invention, LLC.
#-------------------------------------------------------------------

'''The module containing the DbList class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from .Column import Column
from .Column import manyToOne
from .Column import PrimaryKeyColumn
from .DbListProxy import DbListProxy
from collections import OrderedDict
from ai.axe.util import StringUtil
from sqlalchemy import ForeignKey
from sqlalchemy.orm.relationships import RelationshipProperty
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import overload
from typing import Type
from typing import TypeVar
from typing import Tuple
from typing import Union
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
# Type variable for the values stored in a DbList.
DbListItem = TypeVar("DbListItem")

#===================================================================
# Type variable for the generated child-table object that wraps
# each stored value.
DbListItemObject = TypeVar("DbListItemObject")

#===================================================================
# Type variable for the class that owns a DbList (the list is one of
# that class's attributes).
DbListOwner = TypeVar("DbListOwner")

#===================================================================
# Alias used to annotate variables that really hold Column objects.
#
# sqlalchemy.Column cannot be used as the annotation in this module:
# mypy stub trickery gives Columns different types depending on whether
# they are accessed as class members or instance members of DB object
# classes.  Until there is a better approach, column objects are typed
# as 'Any', and this alias at least marks which 'Any' values are
# really Columns.
columnAny = Any

#===================================================================
class DbList(Generic[DbListItem]):
  '''A list attribute that is managed by SQL Alchemy.

  You can create a data attribute like so:

    class Foo(SqlBase):

      tools = DbList(str)

  Now 'tools' is a list of strings (in Python) that gets saved to
  a child table whose schema is generated and managed for you.

  By default the table has a name that is your class name plus the
  list attribute name, and it has a data column, a list position column,
  and a foreign key column back to the parent table.
  For example, the above mapping code would result in this child table:

    table foo_tools (
      id serial,            -- Primary key.
      foo_id int,           -- Foreign key to Foo table.
      db_list_position int, -- The index in the list.
      tool varchar          -- The data.
    );

  When the parent class is processed, the generated child class is added
  to the parent's attributes as '<attrName>Class' (e.g. 'toolsClass'), and
  the underlying SQL Alchemy relationship as '<attrName>Objects'
  (e.g. 'toolsObjects').

  To customize the behavior of the list, there are several optional
  arguments you can provide.  You can specify the Column to use for the
  data instead of having one generated from the item type:

      tools = DbList(str,
                     column=Column("tool",
                                   String,
                                   values=('fertility', 'water')))

  You can specify the column to sort the list by:

      tools = DbList(str,
                     orderBy="tool")

  Ordering the list this way only affects the order the items are in
  immediately after a database read.  The list itself does not maintain
  this ordering.

  By default, if no orderBy parameter is provided, the list will be ordered
  using the order of the list itself.  When the list is read from the
  database, it will be in the same order as when it was last written.

  You can specify the tableArgs to be added to the generated child table:

      tools = DbList(str,
                     tableArgs=(UniqueConstraint('cropZoneId', 'tool'),))

  You can specify the table name you'd like the generated child table to have:

      tools = DbList(str, tableName="foobar_tools")
  '''

  #-----------------------------------------------------------------
  def __init__(self,
               itemType: Type[DbListItem],
               column: Optional[columnAny] = None,
               tableName: Optional[str] = None,
               joinColumn: str = 'id',
               tableArgs: Tuple[Any, ...] = (),
               orderBy: str = 'dbListPosition',
               foreignKeyName: Optional[str] = None) -> None:
    '''Creates a new DbList.

    @param itemType The Python type of the list items.  Only used to
                    create the data column when 'column' is not provided.
    @param column The Column object to use as the data column of the
                  child table.  If None, one is created from itemType.
    @param tableName The name of the child table.  If None, a name is
                     generated from the parent class and attribute names.
    @param joinColumn The name of the parent table column that the child
                      table's foreign key refers to.
    @param tableArgs The __table_args__ to set on the child table class.
    @param orderBy The name of the child class attribute to order the
                   list by when it is loaded.
    @param foreignKeyName The name of the foreign key constraint back to
                          the parent table.  If None, no name is passed.
    '''

    # Must import here to avoid circular include with parent.
    # pylint: disable=import-outside-toplevel
    import ai.axe.db.alchemy

    if column is None:

      # Create a Column from a type.
      column = ai.axe.db.alchemy.Column(itemType)

    # The column object to create on the child table for the list data.
    self.column: columnAny = column

    # The name of the child table to create.
    #
    # If not provided, a name will be auto-generated.
    self.tableName: Optional[str] = tableName

    # Name of the parent table column to join to.  Defaults to 'id'.
    self.joinColumn: str = joinColumn

    # Any __table_args__ to be applied to the child class.
    self.tableArgs: Tuple[Any, ...] = tableArgs

    # An order by constraint to use when loading.
    self.orderBy: str = orderBy

    # The name of the foreign key constraint back to the parent table.
    self.foreignKeyName: Optional[str] = foreignKeyName

    # The generated child object class that is the child table.
    # Set by process().
    self.childClass: Optional[type] = None

    # The name of the data attribute on the child table.
    # Set by process().
    self.childAttrName: Optional[str] = None

    # The name of the parent attribute holding the relationship to the
    # child collection.  Set by process().
    self.relationshipAttrName: Optional[str] = None

    # The relationship from the parent to the child collection.
    # Set by process().
    self.relationship: Optional[RelationshipProperty] = None

  #-----------------------------------------------------------------
  def contains(self, target: Any) -> Any:
    '''Used in a SQL query to test whether this list contains an item.

    The call is forwarded to the 'contains' operator of the data attribute
    on the generated child class.

    NOTE(review): for SQL Alchemy string columns, 'contains' generates a
    substring (LIKE '%target%') match rather than element equality, and the
    query must include the child table.  Confirm this is the intended
    behavior for callers.

    @param target The target value to be tested for.

    @return The SQL expression produced by the child attribute.
    '''

    assert self.childClass is not None
    assert self.childAttrName is not None

    # Forward the query to the child class attribute.
    return getattr(self.childClass, self.childAttrName).contains(target)

  #-----------------------------------------------------------------
  def process(self,
              parentClassName: str,
              baseClasses: List[type],
              parentAttrs: Dict[str, Any],
              attrName: str) -> None:
    '''Processes the specified class, setting up list attributes.

    This generates the child table class.  It then adds the relationship
    ('<attrName>Objects') and the child class ('<attrName>Class') to
    parentAttrs.

    @param parentClassName The name of the class that has the DbList as
                           a member.
    @param baseClasses Base classes of the parent class.  Must contain
                       exactly one class, which is also used as the base
                       of the generated child class.
    @param parentAttrs The map of attributes of the parent class.  Must
                       contain '__tablename__'.  It is modified in place.
    @param attrName The name of the attribute that this DbList is assigned to.

    @raise ValueError If orderBy names an attribute that the generated
                      child class does not have.
    '''

    # Set up the table and class names.

    if self.tableName is not None:
      childTableName = self.tableName
    else:
      childTableName = (StringUtil.camelCaseToUnderscores(parentClassName) +
                        '_' +
                        StringUtil.camelCaseToUnderscores(attrName))
    childClassName = parentClassName + StringUtil.capitalize(attrName)

    # Set up the data column name.  An explicit name on the column wins
    # over the name derived from the (singularized) attribute name.

    columnName = self.column.name or self.singular(attrName)
    childAttrName = StringUtil.underscoresToCamelCase(columnName)

    ######################################
    # Create the child table class.

    attrs: Dict[str, Any] = OrderedDict()

    # The DB table name.
    attrs['__tablename__'] = childTableName

    # The Database ID column.
    attrs['id'] = PrimaryKeyColumn()

    # The list position column.
    attrs['dbListPosition'] = Column(int, index=True)

    # The data column.
    attrs[childAttrName] = self.column

    # The foreign key column, which points at the parent table's join column.
    # Child rows are deleted along with their parent row (ondelete='cascade').

    fkAttrName = StringUtil.initialLower(parentClassName + "Id")
    joinColumn = parentAttrs['__tablename__'] + "." + self.joinColumn

    foreignKeyArgs: Dict[str, Any] = {}
    if self.foreignKeyName:
      foreignKeyArgs['name'] = self.foreignKeyName

    foreignKey = ForeignKey(joinColumn, ondelete='cascade', **foreignKeyArgs)
    attrs[fkAttrName] = manyToOne(foreignKey)

    # Add a __repr__  method.
    attrs['__repr__'] = self.createRepr(childAttrName)

    # Table args.
    if self.tableArgs:
      attrs['__table_args__'] = self.tableArgs

    assert len(baseClasses) == 1, baseClasses

    childClass = type(childClassName, tuple(baseClasses), attrs)

    ######################################
    # Modify the parent class.

    # Create the relationship with attribute name attrName + 'Objects'.
    relationshipAttrName = attrName + 'Objects'
    relationshipArgs: Dict[str, Any] = {}
    if self.orderBy:
      if not hasattr(childClass, self.orderBy):
        msg = "Error creating %s.%s.  " % (parentClassName, attrName)
        msg += "You passed in orderBy='%s', "
        msg += "but there's no static data attribute called %s "
        msg += "on the list table class."
        msg = msg % (self.orderBy, self.orderBy)
        raise ValueError(msg)
      relationshipArgs['order_by'] = getattr(childClass, self.orderBy)

    # Must import here to avoid circular include with parent.
    # pylint: disable=import-outside-toplevel
    import ai.axe.db.alchemy.orm

    relationship = ai.axe.db.alchemy.orm.relationship(childClassName,
                                                       **relationshipArgs)
    parentAttrs[relationshipAttrName] = relationship

    # Make the child class available to users.
    parentAttrs[attrName + 'Class'] = childClass

    self.childClass = childClass
    self.childAttrName = childAttrName
    self.relationshipAttrName = relationshipAttrName
    self.relationship = relationship

  #-----------------------------------------------------------------
  def createObject(self, value: DbListItem) -> DbListItemObject:
    '''Returns a new list object containing the specified list value.

    @param value The list value to create an object for.

    @return A new object of type self.childClass.
    '''

    assert self.childClass is not None
    assert self.childAttrName is not None

    obj = self.childClass()
    setattr(obj, self.childAttrName, value)
    return obj

  #-----------------------------------------------------------------
  def createRepr(self, attrName: str) -> Callable[[Any], str]:
    '''Returns a __repr__ method for an object with a single attribute.

    @param attrName The name of the attribute to include in the repr.

    @return A function suitable for use as a class's __repr__ method.
    '''

    def newReprMethod(childObj: Any) -> str:
      '''Formats this object into a string.
      '''

      attrs = [attrName]
      return StringUtil.formatRepr(childObj, attrs)

    return newReprMethod

  # pylint: disable=function-redefined

  #-----------------------------------------------------------------
  @overload
  def __get__(self,
              obj: None,
              objType: Type[DbListOwner]) -> 'DbList':
    pass

  #-----------------------------------------------------------------
  @overload
  def __get__(self,
              obj: DbListOwner,
              objType: Type[DbListOwner]) -> DbListProxy[DbListItem]:
    pass

  #-----------------------------------------------------------------
  def __get__(self,
              obj: Union[None, DbListOwner],
              objType: Type[DbListOwner]) -> Union['DbList',
                                                   DbListProxy[DbListItem]]:
    '''Descriptor get method.

    This method is called in two places.  Assume there's a class Foo
    with a DbList called bar.

      class Foo:
        bar = DbList(String)

      foo = Foo()

      values = foo.bar   # Case 1: object
      relationship = Foo.bar  # Case 2: class

    In case 1, obj will be 'foo' and objType will be 'Foo'.
    In case 2, obj will be None, and objType will be 'Foo'.
    '''

    # Compare against None explicitly.  A truthiness test would wrongly
    # treat an owner instance that defines a falsy __bool__ or __len__
    # as class-level access.
    if obj is None:

      # This is being called on the class, so just return myself.
      # This allows the user to call 'contains' on me in queries.

      return self

    # This is being called on an object, so return a list proxy
    # that wraps the object's relationship collection.

    assert self.childAttrName is not None
    assert self.childClass is not None
    assert self.relationshipAttrName is not None

    objectList = getattr(obj, self.relationshipAttrName)
    return DbListProxy[DbListItem](self.childClass,
                                   self.childAttrName,
                                   objectList)

  # pylint: enable=function-redefined

  #-----------------------------------------------------------------
  def __set__(self, obj: DbListOwner, values: List[DbListItem]) -> None:
    '''Descriptor set method.

    This method is called when the user assigns values to the child
    list.  Assume there's a class Foo with a DbList called bar.

      class Foo:
        bar = DbList(String)

      foo = Foo()

      foo.bar = [1, 2, 3] # Called here.

    Existing child objects are reused (their data attribute is overwritten)
    for as many values as possible.  Surplus child objects are removed
    from the end, and new ones are created for extra values.  Finally,
    every child's dbListPosition is set to its index.

    @param obj The object that owns the list.
    @param values The new list values.  Any iterable is accepted; it is
                  copied before the owner's list is modified.
    '''
    assert self.childAttrName is not None
    assert self.relationshipAttrName is not None

    # Snapshot the new values first.  This accepts any iterable (such as a
    # generator), and the values stay stable even if they are a live view
    # of the list being modified below.
    newValues = list(values)

    objectList: List[Any] = getattr(obj, self.relationshipAttrName)

    # Drop surplus child objects from the end.
    while len(objectList) > len(newValues):
      objectList.pop()

    # Reuse the existing child objects for the leading values.
    for item, value in zip(objectList, newValues):
      setattr(item, self.childAttrName, value)

    # Create child objects for any values beyond the existing ones.
    for value in newValues[len(objectList):]:
      objectList.append(self.createObject(value))

    # Update the dbListPosition values to be correct.
    for position, item in enumerate(objectList):
      item.dbListPosition = position

  #-----------------------------------------------------------------
  def __delete__(self, obj: DbListOwner) -> None:
    '''This method is called when the user deletes the list.

      class Foo:
        bar = DbList(String)

      foo = Foo()

      del foo.bar # Called here.

    Deleting the list is not supported.

    @raise NotImplementedError Always.
    '''
    assert self.childAttrName is not None
    assert self.childClass is not None

    name = self.childClass.__name__ + '.' + self.childAttrName
    msg = "You are not allowed to delete the list " + name + "."
    raise NotImplementedError(msg)

  #----------------------------------------------------------------
  def singular(self, word: str) -> str:
    '''Returns the singular of a plural word.

    This is a simple heuristic: a single trailing 's' is removed, and any
    other word is returned unchanged (so 'class' becomes 'clas').  The
    result determines generated column names, so changing this rule would
    change existing schemas.
    '''

    if word and word[-1] == 's':
      return word[:-1]
    else:
      return word

  #----------------------------------------------------------------
  def __repr__(self) -> str:
    '''Returns a string representation of this object.
    '''
    attrs = ['column', 'tableName', 'joinColumn', 'tableArgs',
             'orderBy', 'foreignKeyName']

    return StringUtil.formatRepr(self, attrs)
