#-------------------------------------------------------------------
#  AxeMapping.py
#
#  The AxeMapping module.
#
#  Copyright 2013 Applied Invention, LLC.
#-------------------------------------------------------------------

'''AXE-specific mapping alterations.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
import re
import sqlalchemy
from sqlalchemy import Table
from sqlalchemy.sql.elements import ColumnElement
from typing import Any
#
# Import statements go above this line.
#-------------------------------------------------------------------


#-------------------------------------------------------------------
# Names of the extra, non-standard keyword attributes that the Column()
# constructor should accept.
columnAttrs = (
  'minValue',
  'maxValue',
  'values',
  'positive',
  'nonNegative',
)

#-------------------------------------------------------------------
def handleMapperConfigured(mapper: Any, clazz: Any) -> None:
  '''Event hook invoked once a class has been mapped.

  Passes the table behind the mapped class to update().

  @param mapper The mapper for the class.  Not used.
  @param clazz The mapped class, read through its __table__ attribute.
  '''

  # The mapper argument is not needed; bind it to a throwaway name so
  # that it counts as used.
  _ = mapper

  mappedTable = clazz.__table__
  update(mappedTable)

#-------------------------------------------------------------------
def update(table: Table) -> None:
  '''Runs every check-constraint helper over each column of a table.

  @param table The SQL Alchemy Table whose columns are examined.
  '''

  # Applied to each column in this order.
  constraintAdders = (
    addMinMaxConstraint,
    addValueConstraint,
    addNonNegativeConstraint,
    addPositiveConstraint,
  )

  for column in table.columns:
    for addConstraint in constraintAdders:
      addConstraint(table, column)

#-------------------------------------------------------------------
def addMinMaxConstraint(table: Table,
                        column: ColumnElement) -> None:
  '''Add a range check constraint when a min and/or max value is set.

  Reads 'minValue' and 'maxValue' from the column's info dictionary.
  When both are present the two bounds are combined with AND into a
  single constraint.  When neither is present nothing is added.

  @param table The table to add the constraint to.
  @param column The column to add the constraint to.
  '''

  info = column.info
  bounds = []

  if 'minValue' in info:
    bounds.append(column >= info['minValue'])

  if 'maxValue' in info:
    bounds.append(column <= info['maxValue'])

  # No bound requested for this column.
  if not bounds:
    return

  condition = bounds[0] if len(bounds) == 1 else sqlalchemy.and_(*bounds)

  addConditionConstraint(table, column, condition)

#-------------------------------------------------------------------
def addValueConstraint(table: Table, column: ColumnElement) -> None:
  '''Add a check constraint restricting a column to a list of values.

  Reads 'values' from the column's info dictionary; nothing is added
  when that key is absent.

  @param table The table to add the constraint to.
  @param column The column to add the constraint to.
  '''

  if 'values' not in column.info:
    return

  allowedValues = column.info['values']

  addConditionConstraint(table, column, column.in_(allowedValues))

#-------------------------------------------------------------------
def addNonNegativeConstraint(table: Table, column: ColumnElement) -> None:
  '''Add a 'column >= 0' check constraint when nonNegative is set.

  Nothing is added when 'nonNegative' is missing from the column's info
  dictionary or its value is false.

  @param table The table to add the constraint to.
  @param column The column to add the constraint to.
  '''

  if not column.info.get('nonNegative'):
    return

  addConditionConstraint(table, column, column >= 0)

#-------------------------------------------------------------------
def addPositiveConstraint(table: Table, column: ColumnElement) -> None:
  '''Add a 'column > 0' check constraint when positive is set.

  Nothing is added when 'positive' is missing from the column's info
  dictionary or its value is false.

  @param table The table to add the constraint to.
  @param column The column to add the constraint to.
  '''

  isPositive = column.info.get('positive')

  if not isPositive:
    return

  addConditionConstraint(table, column, column > 0)

#-------------------------------------------------------------------
def addConditionConstraint(table: Table,
                           column: ColumnElement,
                           condition: ColumnElement) -> None:
  '''Add a check constraint for the specified condition.

  The condition is compiled to SQL text with its bound values rendered
  inline, the 'table.column' qualifier is shortened to the bare column
  name, and the IN operator keyword is lower-cased.  The resulting text
  is attached to the table as a CheckConstraint named
  'ck_<table>_<column>'.

  NOTE(review): every constraint added for the same column gets the same
  name, so a column that declares more than one kind of restriction (for
  example both 'values' and 'minValue') ends up with duplicate constraint
  names -- confirm whether the target database accepts that.

  @param table The table to add the constraint to.
  @param column The column to add the constraint to.
  @param condition The query condition to enforce.
  '''

  name = "ck_" + table.name + "_" + column.name

  # Render the bound values inline so the text is self-contained.
  compiled = condition.compile(compile_kwargs={"literal_binds": True})
  conditionStr = str(compiled)

  # Refer to the column by its bare name rather than 'table.column'.
  tableColumn = table.name + "." + column.name
  conditionStr = conditionStr.replace(tableColumn, column.name)

  # Lower-case only the IN operator keyword.  Quoted string literals and
  # quoted identifiers are matched first and passed through untouched, so
  # an allowed value such as 'PENDING' is not rewritten to 'PENDing'.
  conditionStr = re.sub(
    r"""'(?:[^']|'')*'|"(?:[^"]|"")*"|\bIN\b""",
    lambda match: "in" if match.group(0) == "IN" else match.group(0),
    conditionStr)

  constraint = sqlalchemy.schema.CheckConstraint(conditionStr, name=name)
  table.append_constraint(constraint)
