#-------------------------------------------------------------------
#  AxeDbClass.py
#
#  The AxeDbClass class.
#
#  Copyright 2014 Applied Invention, LLC.
#-------------------------------------------------------------------

'''The module containing the AxeDbClass class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from ai.axe.db.alchemy.orm import DbDict
from ai.axe.db.alchemy.orm import DbList
from ai.axe.util import StringUtil
from sqlalchemy import Column
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.schema import UniqueConstraint
from typing import Any
from typing import Dict
from typing import List
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class AxeDbClass(DeclarativeMeta):
  '''Meta-class to use to create all user DB classes.

  Hooks the creation of each DB class (such as the 'CropZone' class) so
  that setUpDbClass can adjust the class dictionary before SQL Alchemy's
  DeclarativeMeta sees it.  This covers the default table name, column
  names, and foreign key and unique constraint names.  DeclarativeMeta
  then builds the table and mapper metadata from the adjusted attributes.

  The adjustments are made in __new__, because that is where the class
  dictionary can still be edited before the class object is built.
  '''

  #-----------------------------------------------------------------
  def __new__(metacls,
              className: str,
              bases: List[type],
              attrs: Dict[str, Any]) -> 'AxeDbClass':
    '''Creates a new DB class after adjusting its attributes.

    @param className The name of the class being created.
    @param bases The base classes of the class being created.
    @param attrs The class dictionary; modified in place by setUpDbClass.

    @return The newly created class.
    '''

    # metacls is the meta-class itself.  The new class does not exist yet,
    # so setUpDbClass does its work on the 'attrs' dictionary.
    setUpDbClass(metacls, className, bases, attrs)

    newClass = super().__new__(metacls, className, bases, attrs)
    return newClass

  #-----------------------------------------------------------------
  def __init__(cls,
               className: str,
               bases: List[type],
               attrs: Dict[str, Any]) -> None:
    '''Initializes a newly created DB class.

    Defers entirely to DeclarativeMeta.__init__.
    '''

    super().__init__(className, bases, attrs)

#===================================================================
# Private Functions
#
# These must be module-level functions.
# They cannot be methods on AxeDbClass, as they would then be
# reachable as attributes of every class that AxeDbClass creates.

#-------------------------------------------------------------------
def setAttr(clazz: type,
            attrs: Dict[str, Any],
            attrName: str,
            attrValue: Any) -> None:
  '''Sets an attribute both in a class dictionary and on a class object.

  @param clazz The class object to set the attribute on.
  @param attrs The class dictionary to record the attribute in.
  @param attrName Name of the attribute to set.
  @param attrValue Value of the attribute to set.
  '''

  attrs[attrName] = attrValue
  setattr(clazz, attrName, attrValue)

#-------------------------------------------------------------------
def setUpDbClass(clazz: type,
                 className: str,
                 bases: List[type],
                 attrs: Dict[str, Any]) -> None:
  '''Modifies a DB class's attributes as the class is being created.

  Fills in a default table name, auto-generates column and foreign key
  names, and names unnamed unique constraints, mostly by editing 'attrs'
  in place.

  Does nothing if the class dictionary (or clazz's own __dict__) contains
  '_decl_class_registry'.  NOTE(review): this is presumably the declarative
  base class, whose class dictionary SQL Alchemy's declarative_base()
  fills in; confirm for the SQL Alchemy version in use.

  @param clazz The class passed through to setUpTableName.  When called
    from AxeDbClass.__new__ this is the meta-class, not the new class
    (which does not exist yet).
  @param className The name of the class being created.
  @param bases The base classes of the class being created.
  @param attrs The dictionary of the class being created.
  '''

  # When called from AxeDbClass.__new__, clazz is the meta-class, whose
  # __dict__ never holds this key, so the new class's own dictionary
  # ('attrs') must be checked.  The clazz.__dict__ check is kept for any
  # caller that passes an already-built class.
  if ('_decl_class_registry' in attrs or
      '_decl_class_registry' in clazz.__dict__):

    # Class was already instrumented.  Do nothing.
    return

  setUpTableName(clazz, className, attrs)
  setUpColumns(className, bases, attrs)
  setUpUniqueConstraints(attrs)

#-------------------------------------------------------------------
def setUpTableName(clazz: type, className: str, attrs: Dict[str, Any]) -> None:
  '''Supplies a default '__tablename__' for a DB class that lacks one.

  A '__tablename__' already present in 'attrs' is left untouched.
  Otherwise the class name is converted with
  StringUtil.camelCaseToUnderscores and an 's' is appended.

  @param clazz The class handed to setAttr along with 'attrs'.  When
    reached from AxeDbClass.__new__ this is the meta-class.
    NOTE(review): setAttr therefore also sets '__tablename__' on the
    meta-class itself; confirm that side effect is intended.
  @param className The name of the class being created.
  @param attrs The dictionary of the class being created.
  '''

  if '__tablename__' in attrs:
    return

  # Default table name is underscores and plural.
  pluralName = StringUtil.camelCaseToUnderscores(className) + 's'
  setAttr(clazz, attrs, '__tablename__', pluralName)

#-------------------------------------------------------------------
def setUpColumns(className: str,
                 bases: List[type],
                 attrs: Dict[str, Any]) -> None:
  '''Auto-names the columns and foreign keys of a DB class being created.

  Visits every attribute in sorted name order:

  - DbList / DbDict attributes get their process() hook called.
  - A Column (or the first column of a ColumnProperty) with no name is
    named after the attribute, converted from camel case to underscores.
  - Each unnamed foreign key on such a column is named
    'fk_<table>_<column>_<target>', with dots replaced by underscores.
    Every foreign key on the column gets use_alter = True.

  @param className The name of the class being created.
  @param bases The base classes of the class being created.
  @param attrs The dictionary of the class being created.  Must already
    contain '__tablename__'.
  '''

  tableName = attrs['__tablename__']

  # sorted() takes a snapshot of the items, so process() hooks may add
  # entries to attrs without disturbing this loop.  Entries added that way
  # are not visited.
  for attrName, attrValue in sorted(attrs.items()):

    if isinstance(attrValue, (DbList, DbDict)):
      attrValue.process(className, bases, attrs, attrName)

    column = None
    if isinstance(attrValue, ColumnProperty):
      column = attrValue.columns[0]

    if isinstance(attrValue, Column):
      column = attrValue

    if column is None:
      continue

    # Auto-generate column names based on the attribute name.

    if not column.name:
      column.name = StringUtil.camelCaseToUnderscores(attrName)

    for foreignKey in column.foreign_keys:

      # Auto-generate constraint names from the table, column, and target.
      # Use the public target_fullname rather than the private _colspec.
      # _colspec is the Column object itself when the ForeignKey was built
      # from a Column instead of a string, which broke the concatenation.
      # target_fullname is a string in both cases.

      if not foreignKey.name:
        name = ('fk_' + tableName + '_' + column.name + '_' +
                foreignKey.target_fullname)
        name = name.replace('.', '_')
        foreignKey.name = name

      # Put foreign key constraint DDL at the end so we can sort tables
      # in alphabetical order.

      foreignKey.use_alter = True

#-------------------------------------------------------------------
def setUpUniqueConstraints(attrs: Dict[str, Any]) -> None:
  '''Names the unnamed unique constraints of a DB class being created.

  Each UniqueConstraint in '__table_args__' that has no name is named
  'un_<table>_<col1>_<col2>...'.  The column names are the constraint's
  attribute names, converted from camel case to underscores.

  Only a tuple or list '__table_args__' is scanned.  A dict holds only
  keyword arguments, not constraints.  An object such as a declared_attr
  cannot be inspected before the class exists, and None holds nothing.
  All of these are left alone.

  @param attrs The dictionary of the class being created.  Must contain
    '__tablename__' if any unique constraint needs a name.
  '''

  tableArgs = attrs.get('__table_args__')
  if not isinstance(tableArgs, (tuple, list)):
    return

  for tableArg in tableArgs:
    if not isinstance(tableArg, UniqueConstraint) or tableArg.name:
      continue

    constraint = tableArg

    # The constraint.columns member isn't initialized yet, so have to use
    # _pending_colargs.
    attrNames = constraint._pending_colargs
    columnNames = [StringUtil.camelCaseToUnderscores(x) for x in attrNames]

    tableName = attrs['__tablename__']
    constraint.name = 'un_' + tableName + '_' + '_'.join(columnNames)
