#-------------------------------------------------------------------
#  DbUpdater.py
#
#  The DbUpdater class.
#
#  Copyright 2017 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the DbUpdater class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from ai.axe.util import StringUtil
from ai.axe.web.app import AppSetup
from ai.axe.web.dbVersion import AppVersionClass
from ai.axe.web.dbVersion import AppVersionDb
from .VersionDirReader import VersionDirReader
from .VersionTreeNode import VersionTreeNode
from importlib import import_module
from sqlalchemy.orm.session import Session
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
import os
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class DbUpdater:
  '''Updates the database schema to the latest version.

  Reads the current schema version from the app version table, determines
  which SQL upgrade files lie between it and the requested target version,
  prints them, and optionally executes them through the session.
  '''

  #-----------------------------------------------------------------
  def __init__(self, session: Session, isVerbose: bool) -> None:
    '''Creates a new DbUpdater.

    @param session The database session to update.
    @param isVerbose True if verbose output will be written.
    '''

    # The database to update.
    self.session: Session = session

    # True if verbose output will be written.
    # NOTE(review): this flag is stored but not read anywhere in this class.
    self.isVerbose: bool = isVerbose

  #-----------------------------------------------------------------
  def update(self,
             useRelease: bool,
             doUpdate: bool,
             dbVersionToSet: Optional[str],
             maxDbVersion: Optional[str]) -> None:
    '''Updates the database to the latest version.

    @param useRelease If True, the database will be updated to the version
                      of the Python library.
                      If False, the database will be updated to the latest
                      version file found in the SQL directory.
    @param doUpdate If False, information about files to be executed will be
                    printed, but no schema changes will be made.  If True,
                    the SQL files will be executed and schema changed.
    @param dbVersionToSet The DB version string to set in the DB before doing
                          any upgrade.  If None, the version will be left as-is.
    @param maxDbVersion The version to update the database to.
                        If None, will be updated to the latest version.
                        Takes precedence over useRelease when choosing the
                        target version.
    '''

    # Make sure there is a version row to read before reading it.
    self.initTable(self.session)
    appVersionObj = AppVersionDb.read(self.session)

    # Optionally override the recorded version so that upgrades start from
    # a version the caller asserts the schema is already at.
    if dbVersionToSet is not None:
      print("Recording current schema as version: %s" % dbVersionToSet)
      print()
      appVersionObj.dbVersion = dbVersionToSet

    currentNode, targetNode = self.findVersions(appVersionObj,
                                                useRelease,
                                                maxDbVersion)

    if doUpdate:
      self.updateDb(appVersionObj, currentNode, targetNode)
    else:
      print("No SQL files executed because '-run' flag not provided.")

  #-----------------------------------------------------------------
  def initTable(self, session: Session) -> None:
    '''Creates and initializes the version table if necessary.

    If the version table does not exist, it is created.  If the table was
    missing, or exists but has no rows, a version record built from
    ("", "000") is written.  Presumably this is (app version "",
    DB version "000"); confirm against the AppVersionClass constructor.

    @param session The database to write to.
    '''

    # If the table doesn't exist, assume we have a completely empty database.
    # Create the table and initialize to version zero.
    if not AppVersionDb.tableExists(session):

      AppVersionDb.createTable(session)
      print("Creating app_version table...")
      print()

      # Initialize to version zero.
      appVersion = AppVersionClass.get()
      AppVersionDb.write(session, appVersion("", "000"))

    # If the table does exist, but is empty, assume that the user has
    # recently created the table with 'createTables'.
    elif AppVersionDb.count(session) == 0:

      # NOTE(review): the original intent (per the old comments) was to set
      # the version to the *latest* version here, since 'createTables'
      # builds the current schema.  The code instead writes version "000",
      # identical to the branch above.  That would cause every SQL upgrade
      # file to be selected on the next update.  Confirm which is intended.
      appVersion = AppVersionClass.get()
      AppVersionDb.write(session, appVersion("", "000"))

  #-----------------------------------------------------------------
  def findVersions(self,
                   appVersionObj: Any,
                   useRelease: bool,
                   maxDbVersion: Optional[str]) -> Tuple[VersionTreeNode,
                                                         VersionTreeNode]:
    '''Find the from- and to- version nodes.

    The target defaults to the leaf reached from the current node.  If
    useRelease is True, it becomes the node matching the root package's
    'version' attribute.  If maxDbVersion is given, it becomes the node for
    that DB version, overriding useRelease.

    Prints the chosen versions and the SQL files that lie between them.

    @param appVersionObj The version record read from the database.
    @param useRelease Whether to target the Python library's version.
    @param maxDbVersion An explicit target DB version, or None.

    @return A (currentNode, targetNode) tuple.
    '''

    versionTree = VersionDirReader(self.sqlDir()).read()

    # Look up the app's current version from the Version.py file in the
    # root package.
    appSetup = AppSetup.get()
    rootPackage = import_module(appSetup.rootPackageName())
    currentPyVersion = rootPackage.version           # type: ignore

    currentDbVersion = appVersionObj.dbVersion

    currentNode = versionTree.findDbVersion(currentDbVersion)

    # Later assignments take precedence: maxDbVersion > useRelease > leaf.
    targetNode = currentNode.leafNode()
    if useRelease:
      targetNode = currentNode.findAppVersion(currentPyVersion)
    if maxDbVersion is not None:
      targetNode = currentNode.findDbVersion(maxDbVersion)

    print("Current schema version: ", currentNode.dbVersion)
    print("Upgrading to schema version: ", targetNode.dbVersion)
    print("Current app version: ", currentNode.appVersion)

    if useRelease:
      print("Upgrading to app version: ", targetNode.appVersion)

    sqlFileNames = currentNode.sqlFileNames(targetNode)
    print()
    print("SQL files to be executed:")
    if sqlFileNames:
      print("  " + "\n  ".join(sqlFileNames))
    else:
      print("  <None>")
    print()

    return currentNode, targetNode

  #-----------------------------------------------------------------
  def updateDb(self,
               appVersionObj: Any,
               currentNode: VersionTreeNode,
               targetNode: VersionTreeNode) -> None:
    '''Updates the DB to the specified version node.

    Executes the SQL file of each node stepped through from currentNode to
    targetNode.  It then records the target's DB version, and its app
    version if it has one, on appVersionObj.

    @param appVersionObj The version record to update in place.
    @param currentNode The node for the schema's current version.
    @param targetNode The node for the version to upgrade to.
    '''

    # NOTE(review): this walk always follows the first child.  If targetNode
    # is not on that first-child path, files from an unrelated branch would
    # run, and presumably an IndexError would be raised at a leaf.  Confirm
    # VersionTreeNode guarantees the path matches sqlFileNames(targetNode).
    node = currentNode
    while node != targetNode:

      node = node.children[0]

      # Every node past the starting one is expected to carry an upgrade
      # file; the assert also narrows Optional[str] for the type checker.
      assert node.sqlFileName is not None

      print("Executing", node.sqlFileName, "...")
      self.executeSqlFile(node.sqlFileName)
      print("  ...OK")

    # Update the version table.
    # Only the in-memory record is modified here; presumably the caller's
    # session commit persists it — confirm.
    appVersionObj.dbVersion = targetNode.dbVersion
    if targetNode.appVersion is not None:
      appVersionObj.appVersion = targetNode.appVersion

  #----------------------------------------------------------------
  def executeSqlFile(self, sqlFileName: str) -> None:
    '''Executes the specified SQL file.

    @param sqlFileName The file name, relative to the SQL directory.
    '''

    fileName = os.path.join(self.sqlDir(), sqlFileName)

    # Close the file deterministically instead of relying on garbage
    # collection to release the handle.
    with open(fileName) as sqlFile:
      sqls = sqlFile.read()

    # The entire file contents are passed as one string to Session.execute.
    # NOTE(review): plain-string execute is deprecated in SQLAlchemy 1.4 and
    # rejected in 2.0 (sqlalchemy.text() is required there); confirm against
    # the installed SQLAlchemy version.
    self.session.execute(sqls)

  #----------------------------------------------------------------
  def sqlDir(self) -> str:
    '''Returns the name of the directory that SQL files are kept in.

    The directory is read from the environment variable named
    <APP NAME IN CAPS>_SQL; a KeyError is raised if it is not set.
    '''

    sqlDirKey = AppSetup.get().appNameAllCaps() + "_SQL"
    return os.environ[sqlDirKey]

  #----------------------------------------------------------------
  def __repr__(self) -> str:
    '''Returns a string representation of this object.
    '''
    attrs: List[str] = []

    return StringUtil.formatRepr(self, attrs)
