#-------------------------------------------------------------------
#  AxeTestCase.py
#
#  The AxeTestCase class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the AxeTestCase class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
import unittest
from .AxeSimpleTestCase import AxeSimpleTestCase
from ai.axe.db.DatabaseMgr import DatabaseMgr
from ai.axe.web.app import AppSetup
from sqlalchemy.orm.session import Session
from typing import Optional
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class AxeTestCase(AxeSimpleTestCase):
  '''Unittest TestCase with AXE-specific functionality added.

  Each test is run by runInSession(), which commits any open transaction
  once the test finishes and then discards the test's session.  Sessions
  are created by the class-wide DatabaseMgr set with setDatabaseMgr().
  The first session a test creates is passed to DatabaseMgr.eraseAll()
  before being handed to the test.
  '''

  #-----------------------------------------------------------------
  # Make assertEqual() messages be added to the default message,
  # rather than replacing it.
  longMessage: bool = True

  # The DB manager to be used to create sessions.  Shared by every
  # AxeTestCase; must be set with setDatabaseMgr() before tests run.
  databaseMgr: Optional[DatabaseMgr] = None

  #-----------------------------------------------------------------
  def __init__(self, *args, **kwargs) -> None:
    '''Creates a new AxeTestCase.

    All arguments are passed unchanged to AxeSimpleTestCase.__init__().
    '''

    AxeSimpleTestCase.__init__(self, *args, **kwargs)

    # The current DB session to be used by the test, or None if the
    # test does not currently have one.
    self.sessionObj: Optional[Session] = None

    # Make sure the unit test DB system has been set up.
    AppSetup.get().initUnitTestFull()

  #-----------------------------------------------------------------
  def session(self) -> Session:
    '''Returns a SqlAlchemy DB session for use in testing.

    If no session currently exists, creates a new one.

    This should be called within a unit test to get the session.
    '''

    if self.sessionObj is None:
      return self.startNewSession()

    return self.sessionObj

  #-----------------------------------------------------------------
  def startNewSession(self) -> Session:
    '''Commits the current transaction and starts a new session.

    If there was no current session (normally the first call in a test),
    the new session is passed to DatabaseMgr.eraseAll() and committed
    before it is returned.

    @return The newly created session, which also becomes the current one.
    '''

    assert AxeTestCase.databaseMgr is not None

    databaseMgr = AxeTestCase.databaseMgr

    # If this is the first time this test has created a session,
    # we'll clear out the database.
    firstTime = self.sessionObj is None

    self.endTransaction(True)
    self.closeSession()

    newSession = databaseMgr.create()
    self.sessionObj = newSession

    if firstTime:
      databaseMgr.eraseAll(newSession)
      newSession.commit()

    return newSession

  #-----------------------------------------------------------------
  def endTransaction(self, commit: bool) -> None:
    '''Ends the current transaction, either by commit or rollback.

    This is safe to call when there is no current session.

    @param commit  If true, the transaction will be committed.  If false,
                   it will be rolled back.
    '''

    if self.sessionObj is not None:
      if commit:
        self.sessionObj.commit()
      else:
        self.sessionObj.rollback()

  #-----------------------------------------------------------------
  def closeSession(self) -> None:
    '''Discards the current session.

    After this call the test has no current session, so the next
    session() call creates a new one.  This is safe to call when there is
    no current session.
    '''

    # NOTE(review): Session.close() is not called here -- the call was
    # commented out in the original code, so the old session is only
    # dropped.  Confirm whether that is still intended before closing it.
    self.sessionObj = None

  #-----------------------------------------------------------------
  def run(self,
          result: Optional[unittest.TestResult] = None
          ) -> Optional[unittest.TestResult]:
    '''Called by the unittest framework to run a test.

    @param result  The TestResult to collect results into, or None to let
                   unittest.TestCase.run() create a default one.

    @return The TestResult returned by unittest.TestCase.run(), matching
            the base-class contract.
    '''

    # Call the 'TestCase.run' method inside a session.
    return self.runInSession(result)

  #-----------------------------------------------------------------
  def runInSession(self, result: Optional[unittest.TestResult]
                   ) -> Optional[unittest.TestResult]:
    '''Private method to run the test method inside a Session.

    After the test runs, any open transaction is committed.  An exception
    raised while running the test or committing is passed to
    DatabaseMgr.handleException().  The session is discarded in all cases.

    @param result  The TestResult passed on to unittest.TestCase.run().

    @return The TestResult returned by unittest.TestCase.run(), or
            'result' unchanged if that call did not complete.
    '''

    assert AxeTestCase.databaseMgr is not None
    assert self.sessionObj is None

    databaseMgr = AxeTestCase.databaseMgr

    # Fallback return value in case the run below raises and
    # handleException() does not re-raise.
    runResult = result

    try:

      # Run the wrapped method.  Keep its return value: when 'result' is
      # None, TestCase.run() creates and returns its own TestResult.
      runResult = unittest.TestCase.run(self, result)

      self.endTransaction(True)

    # pylint: disable = W0703
    except Exception as ex:
      databaseMgr.handleException(self.sessionObj, ex)

    finally:
      self.closeSession()

    return runResult

  #-----------------------------------------------------------------
  @staticmethod
  def setDatabaseMgr(databaseMgr: DatabaseMgr) -> None:
    '''Sets the DatabaseMgr object to be used to create DB sessions.

    The manager is stored on the AxeTestCase class, so it is shared by
    all test cases.

    @param databaseMgr A DatabaseMgr object.
    '''

    AxeTestCase.databaseMgr = databaseMgr
