#-------------------------------------------------------------------
#  DatabaseMgr.py
#
#  The DatabaseMgr class.
#
#  Copyright 2017 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the DatabaseMgr class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from .DbConfig import DbConfig
from .DdlCollector import DdlCollector
from ai.axe.db.alchemy import AxeMapping
from ai.axe.db.alchemy.declarative.AxeDbClass import AxeDbClass
from ai.axe.util import StringUtil
from datetime import date
from datetime import datetime
from datetime import timedelta
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy.schema import MetaData
from sqlalchemy.schema import Table
from typing import Callable
from typing import List
from typing import NoReturn
from typing import Optional
from typing import TypeVar
import ai.axe.db.alchemy
import ai.axe.math
import ai.axe.wkt
import sqlalchemy
import sqlalchemy.orm.mapper
import traceback
#
# Import statements go above this line.
#-------------------------------------------------------------------

#-------------------------------------------------------------------
# Type variable tying the return type of the function handed to
# executeInSession() / executeInCreatedSession() to their own return type.
ReturnType = TypeVar("ReturnType")

#===================================================================
class DatabaseMgr:
  '''Manager that sets up and manages a SQL Alchemy database.

  To use this class in your project, you should subclass it
  and override the importAllClasses() method.

  During app startup, you should first import this class and
  call createSqlBase() to generate a SqlBase class.  All your
  DB classes should extend this base class.

  When you are ready to connect to the database, you should create
  a single DatabaseMgr object, passing in your SqlBase class,
  and use the DatabaseMgr to create sessions.
  '''

  #-----------------------------------------------------------------
  @staticmethod
  def createSqlBase() -> type:
    '''Returns a base class that all your DB classes should extend.

    The class is created with SQLAlchemy's declarative_base(), using
    AxeDbClass as its metaclass.

    @return A new declarative base class.
    '''

    sqlBase = declarative_base(metaclass=AxeDbClass)
    return sqlBase

  #-----------------------------------------------------------------
  def __init__(self,
               dbConfig: DbConfig,
               sqlBase: type,
               logFunction: Callable[[str, str], None],
               dumpSql: bool = False) -> None:
    '''Creates a new DatabaseMgr.

    This will call importAllClasses() to import all your classes,
    and then create the engine and session factory.

    @param dbConfig The configuration to use to contact the database.
    @param sqlBase The base class for all your DB classes in this database.
                   It must have been created by calling the
                   DatabaseMgr.createSqlBase() static method.
    @param logFunction A function that will be called to log any exception
                       thrown during executeInSession().  The function
                       must take 2 arguments:  a string error message,
                       and a string exception stack trace.
    @param dumpSql If True, every SQL statement will be written to stdout.
    '''

    # The configuration to use to contact the database.
    self.dbConfig: DbConfig = dbConfig

    # The base class for all DB classes in this database.
    self.sqlBase: type = sqlBase

    # If True, every SQL statement will be written to stdout.
    self.dumpSql: bool = dumpSql

    # A list of all database classes.
    self.dbClasses: List[type] = []

    # The DB session factory.  Set by recreateSessionMaker() below.
    self.sessionMaker: sessionmaker = None

    # The DB engine used by sessions created by this object.
    # Set by recreateSessionMaker() below.
    self.engine: Engine = None

    # A log function to call to log exceptions in executeInSession().
    self.logException: Callable[[str, str], None] = logFunction

    # Make sure all classes have been imported.
    classes = self.importAllClasses()
    self.dbClasses.extend(classes)

    self.recreateSessionMaker()

  #-----------------------------------------------------------------
  def recreateSessionMaker(self) -> None:
    '''Recreates the engine and the session maker.

    After fork-ing a new process, you must re-create the engine to
    prevent the old and new processes' connection pools both trying
    to talk to the database on the same TCP/IP ports.

    This function accomplishes that.  It builds a new engine from
    self.dbConfig and a new session factory bound to that engine.
    '''

    dbConfig = self.dbConfig

    # NOTE(review): the previous engine (if any) is replaced but not
    # disposed.  See SQLAlchemy's guidance on connection pools with
    # multiprocessing / os.fork() (Engine.dispose()) -- confirm intended.
    self.engine = sqlalchemy.create_engine(dbConfig.url,
                                           pool_size=dbConfig.poolSize,
                                           max_overflow=dbConfig.poolOverflow,
                                           pool_recycle=dbConfig.poolRecycle,
                                           pool_timeout=dbConfig.poolTimeout,
                                           echo=self.dumpSql)
    self.sessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine,
                                                    autocommit=False)

  #-----------------------------------------------------------------
  def importAllClasses(self) -> List[type]:
    '''Imports all DB classes and returns the list of classes.

    This method should import all classes that will be read/written
    to the database.  All these classes must extend 'self.sqlBase'
    as their base class.

    Subclasses must override this; the base implementation raises
    NotImplementedError.

    @return A list of class objects that is all the classes to be
            used in the database.
    '''

    # Your implementation should look something like this.

    #   # All DB classes must be imported so the session knows about them.
    #   # pylint: disable=W0621
    #   import ram.account.Account
    #   import ram.account.UserSession
    #   import ram.account.UserOperation
    #   import ram.field
    #   import ram.clusterView
    #   # pylint: enable=W0621
    #
    #   classes = [
    #     ram.account.Account,
    #     ram.account.UserSession,
    #     ram.account.UserOperation,
    #     ram.clusterView.ClusterSet,
    #     ram.clusterView.FirstItem,
    #     ram.clusterView.SecondItem,
    #     ram.clusterView.FirstSecondValue,
    #     ram.field.CropZone,
    #     ram.field.Eru,
    #     ram.field.Farm,
    #     ram.field.Field,
    #     ram.field.Operation,
    #     ram.field.TimeZone,
    #   ]
    #
    #   return classes

    raise NotImplementedError()

  #-----------------------------------------------------------------
  @staticmethod
  def setUpTypeAliases() -> None:
    '''Sets up basic type aliases.

    This allows users to create columns like Column(bool), Column(Point)
    instead of Column(Boolean), Column(PointType).
    '''

    addTypeMapping = ai.axe.db.alchemy.addTypeMapping
    types = ai.axe.db.alchemy.types                     # type: ignore

    addTypeMapping(bool, sqlalchemy.Boolean)
    addTypeMapping(date, sqlalchemy.Date)
    addTypeMapping(datetime, sqlalchemy.DateTime)
    addTypeMapping(timedelta, sqlalchemy.Interval(second_precision=3))
    addTypeMapping(float, sqlalchemy.Float)
    addTypeMapping(int, sqlalchemy.Integer)
    addTypeMapping(str, sqlalchemy.String)
    addTypeMapping(ai.axe.math.Angle, types.AngleType)
    addTypeMapping(ai.axe.math.Rectangle, types.RectangleType)
    addTypeMapping(ai.axe.wkt.Point, types.PointType)
    addTypeMapping(ai.axe.wkt.Point3d, types.Point3dType)
    addTypeMapping(ai.axe.wkt.MultiPolygon, types.MultiPolygonType)
    addTypeMapping(ai.axe.wkt.Polygon, types.PolygonType)

  #-------------------------------------------------------------------
  def eraseAll(self, session: Session) -> None:
    '''Deletes all rows of every class in self.dbClasses.

    One bulk delete is issued per class, in the order the classes
    appear in self.dbClasses (the order importAllClasses() returned).
    Nothing is committed here; that is left to the caller.

    NOTE(review): if the tables have foreign keys to each other, the
    order returned by importAllClasses() presumably matters -- confirm.

    @param session The session to issue the deletes in.
    '''

    for clazz in self.dbClasses:
      session.query(clazz).delete()

  #-------------------------------------------------------------------
  def create(self) -> Session:
    '''Creates a Sqlalchemy session for the current config.

    @return A new Sqlalchemy session from self.sessionMaker.
    '''

    session = self.sessionMaker()
    return session

  #-------------------------------------------------------------------
  def generateDdl(self, tables: Optional[List[Table]] = None) -> str:
    '''Returns a string with the DDL for the database.

    MetaData.create_all() is run against a mock engine whose executor
    is a DdlCollector, so the statements are captured by the collector
    instead of being sent to a database.

    @param tables Optional list of Table objects.  If None, all tables will
                  be included.

    @return The string built by DdlCollector.createString().
    '''

    ddl = DdlCollector()
    # NOTE(review): strategy='mock' is deprecated in SQLAlchemy 1.4 in
    # favor of create_mock_engine(); confirm the supported SQLAlchemy version.
    engine = sqlalchemy.create_engine(self.dbConfig.url,
                                      strategy='mock',
                                      executor=ddl.execute)
    ddl.dialect = engine.dialect

    metadata: MetaData = self.sqlBase.metadata # type: ignore
    # checkfirst=False: emit every CREATE without querying for existing
    # tables (there is no real database behind the mock engine).
    metadata.create_all(engine, tables=tables, checkfirst=False)

    return ddl.createString()

  #-------------------------------------------------------------------
  def executeInSession(self,
                       function: Callable[..., ReturnType],
                       *args) -> ReturnType:
    '''Executes the specified function inside a new database session.

    The session is created with self.create().  See
    executeInCreatedSession() for the commit / rollback / close behavior.

    @param function The function to execute.  Its first argument should be
                    a database session.
    @param args The remaining arguments to pass into the function, after
                the session.

    @return The value that was returned by the function.
    '''

    return self.executeInCreatedSession(self.create, function, *args)

  #-------------------------------------------------------------------
  def executeInCreatedSession(self,
                              createSessionFunc: Callable[[], Session],
                              function: Callable[..., ReturnType],
                              *args) -> ReturnType:
    '''Executes the specified function inside the specified database session.

    If the function returns normally, the session is committed and the
    function's return value is returned.  If the function or the commit
    raises, the exception is logged with self.logException, the session
    is rolled back, and the original exception is re-raised.  The
    session is always closed.

    @param createSessionFunc Function to create the DB Session to run
                             the function inside.
    @param function The function to execute.  Its first argument should be
                    a database session.
    @param args The remaining arguments to pass into the function, after
                the session.

    @return The value that was returned by the function.
    '''

    session: Session = createSessionFunc()

    try:

      ret = function(session, *args)

      session.commit()

      return ret

    # pylint: disable = W0703
    except Exception as ex:

      msg = "Exception while running function in DB session."
      try:
        self.logException(msg, traceback.format_exc())
      finally:
        # Roll back and re-raise the original exception even if the
        # user-supplied log function itself fails, so the DB error is
        # never masked by a logging error.
        DatabaseMgr.handleException(session, ex)

    finally:

      session.close()

  #-------------------------------------------------------------------
  @staticmethod
  def handleException(session: Session, ex: Exception) -> NoReturn:
    '''Handles an exception thrown during a session.

    Rolls back the session (if there is one) and re-raises 'ex'.  Any
    exception raised by the rollback itself is ignored, so that 'ex' is
    the exception that propagates.

    This is intended to be called inside an 'except' clause.

    @param session The session to roll back.  May be None.
    @param ex The exception to re-raise.
    '''

    # pylint: disable=W0703

    if session is not None:
      try:
        session.rollback()
      except Exception:
        # Since there's already an exception in progress, ignore
        # any further exception.
        pass

    raise ex

  #----------------------------------------------------------------
  def __repr__(self) -> str:
    '''Returns a string representation of this object
    '''
    attrs = ['dbClasses', 'sessionMaker']

    return StringUtil.formatRepr(self, attrs)

#-------------------------------------------------------------------
# Static setup done at module import time.

# Register AxeMapping.handleMapperConfigured as a listener for SQLAlchemy's
# 'instrument_class' mapper event.  This must happen before any DB classes
# are imported, so the listener is already in place when they are mapped.
sqlalchemy.event.listen(sqlalchemy.orm.mapper,
                        'instrument_class',
                        AxeMapping.handleMapperConfigured)

# Register the Python-type -> SQLAlchemy-type aliases (so declarations like
# Column(bool) work) before the DB classes are imported.
DatabaseMgr.setUpTypeAliases()
