#-------------------------------------------------------------------
#  DdlCollector.py
#
#  The DdlCollector class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the DdlCollector class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from sqlalchemy.schema import CreateTable
from sqlalchemy.engine.interfaces import Dialect
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class DdlCollector:
  '''Collects DDL statements.

  Should be used as the 'executor' of a SQL Alchemy create_engine()
  call to collect all the DDL SQL for a database.  The 'dialect'
  attribute must be set before createString() is called; it is the
  dialect the collected statements are compiled with.
  '''

  #-----------------------------------------------------------------
  def __init__(self) -> None:
    '''Creates a new DdlCollector.
    '''

    # The dialect used to compile the collected statements.  Must be set
    # by the caller before createString() is called.
    self.dialect: Optional[Dialect] = None

    # List of CreateTable objects.
    self.createTables: List[CreateTable] = []

    # Other SQL that doesn't appear to be a 'create table': either plain
    # strings or objects that have a compile(dialect=...) method.
    self.statements: List[Any] = []

  #-----------------------------------------------------------------
  def execute(self, sql: Any, *dummyMultiparams, **dummyParams) -> None:
    '''Saves the specified SQL.

    Called by SqlAlchemy during a 'create_all' DDL string build.

    'sql' is a DDL element or SQL string.  CreateTable elements are
    stored in 'createTables'; everything else is stored in
    'statements'.  Any extra positional or keyword arguments are
    ignored.
    '''

    if isinstance(sql, CreateTable):
      self.createTables.append(sql)
    else:
      self.statements.append(sql)

  #-----------------------------------------------------------------
  def createString(self) -> str:
    '''Builds a single SQL string from all the collected DDL.

    The CREATE TABLE statements come first.  They are ordered by table
    name, and each is terminated with a semicolon.  The other statements
    follow.  They are ordered by their SQL text, each is terminated with
    a semicolon, and they are separated by a blank line.

    Returns the combined SQL string ('' if nothing was collected).
    '''

    assert self.dialect, 'DdlCollector.dialect must be set first.'

    # Sort on the table name alone.  sorted() is stable, so tables that
    # share a name (e.g. the same name in different schemas) are all
    # kept, in the order they were collected.  Keying a dict on the name
    # would silently drop all but the last of them.
    createTables = sorted(self.createTables,
                          key=lambda createTable: createTable.element.name)

    parts: List[str] = [
      self._terminateCreate(str(createTable.compile(dialect=self.dialect)))
      for createTable in createTables
    ]

    statementSqls = sorted(self._statementSql(statement)
                           for statement in self.statements)

    if statementSqls:
      # Terminate the last statement as well, not just the ones
      # between separators.
      parts.append(';\n\n'.join(statementSqls) + ';\n')

    return ''.join(parts)

  #-----------------------------------------------------------------
  @staticmethod
  def _terminateCreate(tableSql: str) -> str:
    '''Adds a semicolon after the closing parenthesis of a CREATE TABLE.
    '''

    head, sep, tail = tableSql.rpartition('\n)\n')

    if not sep:
      # The closing ')' isn't on a line of its own.  rpartition() returns
      # ('', '', tableSql) in that case, so splicing would put a stray
      # ');' in front of the statement; terminate the text as-is instead.
      return tableSql.rstrip() + ';\n\n'

    return head + '\n);\n' + tail

  #-----------------------------------------------------------------
  def _statementSql(self, statement: Any) -> str:
    '''Returns the SQL text of one collected non-CREATE TABLE statement.
    '''

    if isinstance(statement, str):
      return statement

    return str(statement.compile(dialect=self.dialect))
