#-------------------------------------------------------------------
#  CommandRunner.py
#
#  The CommandRunner class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the CommandRunner class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from .BatchCommand import BatchCommand
from .AxeCommand import AxeCommand
from .AxeCommandBase import AxeCommandBase
from collections import OrderedDict
from datetime import datetime
from ai.axe.web.config import Config
from ai.axe.web.config import DbConfig
from ai.axe.web.core import AxeException
from ai.axe.web.core import Log
from ai.axe.web.app import AppSetup
from sqlalchemy.orm.session import Session
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
import inspect
import shlex
import sys
import traceback
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class CommandRunner:
  '''Runs a command-line action.
  '''

  #-----------------------------------------------------------------
  def __init__(self,
               commandClass: Type[AxeCommandBase],
               args: List[str],
               wantsTabCompletion: bool) -> None:
    '''Creates a new CommandRunner.

    @param commandClass The class of the command object to be run.
    @param args The arguments to pass into the command.
    @param wantsTabCompletion If True, the command should print a list of
                              command-line tab-complete suggestions
                              rather than actually running the command.

    '''

    # The class of the command object to be run.
    self.commandClass: Type[AxeCommandBase] = commandClass

    # The command to be run.
    self.command: AxeCommandBase = commandClass()

    # The command-line arguments.
    self.args: List[str] = args

    # Whether the user has asked for tab auto-complete suggestions.
    self.wantsTabCompletion: bool = wantsTabCompletion

    # The time this command started running.
    self.beginTime: Optional[datetime] = None

    # The time the current batch started running.
    self.batchBeginTime: Optional[datetime] = None

    # Which batch we're currently running.  Zero means no batch.
    self.batchNum: int = 0

    # The labels to use in logging each batch.
    self.batchLabels: List[str] = []

    # Batches that failed.
    self.failedBatchLabels: List[str] = []

  #-----------------------------------------------------------------
  @staticmethod
  def run() -> None:
    '''Runs a command.
    '''

    appSetup = AppSetup.get()
    appName = appSetup.appName()


    # Parse command-line arguments.

    commandName: Optional[str] = None

    argDbUrl: Optional[str] = None
    argTestDb: bool = False
    wantsHelp: bool = False
    wantsTabCompletion: bool = False
    completionLine: Optional[str] = None

    args = sys.argv[1:]

    while args:

      if args[0] in ('-h', '--help', 'help'):
        wantsHelp = True
        args.pop(0)

      elif args[0] == '-testDb':

        argTestDb = True
        args.pop(0)

      elif args[0] == '-connect':

        args.pop(0)

        # It's an error if the connect string is missing.
        # If the user is asking for complete suggestions, we just
        # return blank.
        if not args:
          if not wantsTabCompletion:
            print('Error: -connect flag not followed by URL.')
          sys.exit(1)

        argDbUrl = args.pop(0)

      elif args[0] == '-complete':

        args.pop(0)
        wantsTabCompletion = True

        # The user has asked for auto complete suggestions, so replace
        # the args with what's been typed so far and re-parse.
        completionLine = args.pop(0)
        args = shlex.split(completionLine)[1:]

        # If the user has typed a blank space and hit tab, he's asking
        # about the next argument.  Add that extra implied 'blank'
        # argument.
        if completionLine and completionLine[-1].isspace():
          args.append('')

      else:
        commandName = args.pop(0)
        break

    # wantsHelp will be True whether the user has asked for global help
    # or help for a single command.
    wantsHelp = wantsHelp or '-h' in args or '--help' in args


    # Perform app initialization.  All commands will be registered.

    # A function that will modify a DbConfig to replace its URL
    # with the URL from the command line if appropriate.
    #
    # Note that this is a closure on argDbUrl and argTestDb.

    def updateDbConfig(dbConfig: DbConfig) -> DbConfig:
      return CommandRunner.chooseDbUrl(dbConfig, argDbUrl, argTestDb)

    appSetup.initCommandLineFull(updateDbConfig)

    commands: Dict[str,
                   Type[AxeCommandBase]] = CommandRunner.createCommandDict()

    # Find the command that the user has requested.

    commandClass: Optional[type] = None

    if commandName and commandName in commands:
      commandClass = commands[commandName]

    # Write out command tab-completion suggestions if it was requested,
    # and the user hasn't finished typing a command yet.

    if wantsTabCompletion and (not commandName or not commandClass):

      options = "-h --help help -testDb -connect"
      options += " " + " ".join(commands.keys())
      print(options)
      sys.exit(0)

    # Handle an invalid command name.

    if commandName and not commandClass:
      print('Error: "%s" is an invalid command name.' % (commandName))
      sys.exit(1)

    # Write out help if it was requested.

    if not wantsHelp and not commandClass:
      msg = 'Error: no command given.  Run "%sAdmin help" for a list.'
      msg = msg % appName
      print(msg)
      sys.exit(1)

    if wantsHelp and not commandClass:
      usage = """Usage: %sAdmin [-testDb] [-connect url] command arg1 arg2 ...

Runs a %s admin command.

Arguments:

  -testDb Runs the command against the test database.
  -connect Runs the command against the specified database URL.

Commands:
"""
      usage = usage % (appName, appName)
      print(usage)
      CommandRunner.printCommands(commands)
      print()

      sys.exit(1)

    elif wantsHelp:

      assert commandClass is not None

      print(commandClass.usage) # type: ignore
      sys.exit(1)


    # Run the chosen command.

    assert commandClass is not None

    runner = CommandRunner(commandClass, args, wantsTabCompletion)
    runner.runCommand()

  #-----------------------------------------------------------------
  @staticmethod
  def createCommandDict() -> Dict[str, Type[AxeCommandBase]]:
    '''Collects every registered command class into a dict sorted by name.

    The commands registered on BatchCommand are gathered first, followed
    by those registered on AxeCommand.

    @return An OrderedDict of {commandName: commandClass}, in alphabetical
            order of command name.
    @throws AxeException If two registered commands share the same name.
    '''

    found: Dict[str, Type[AxeCommandBase]] = {}

    # Both registries are treated alike, so walk them with a single loop.
    for registry in (BatchCommand, AxeCommand):
      for commandClass in registry.registeredCommands:

        commandName = commandClass.name

        if commandName in found:
          raise AxeException("Duplicate command name: " + str(commandName))

        found[commandName] = commandClass

    # Rebuild the dictionary in alphabetical order of command name.
    return OrderedDict((name, found[name]) for name in sorted(found))

  #-----------------------------------------------------------------
  @staticmethod
  def printCommands(commands: Dict[str, Type[AxeCommandBase]]) -> None:
    '''Prints all the commands and descriptions.

    @param commands A {commandName: commandClass} dictionary.
    '''

    names = commands.keys()
    descriptions = [clazz.description for clazz in commands.values()]

    maxName = max([len(x) for x in names])
    maxDesc = max([len(x) for x in descriptions])

    formatStr = ' %-' + str(maxName) + 's %-' + str(maxDesc) + 's'

    for name, description in zip(names, descriptions):
      print(formatStr % (name, description))

  #-----------------------------------------------------------------
  def runCommand(self) -> None:
    '''Runs this runner's command.
    '''

    self.beginTime = datetime.utcnow()

    if self.command.isBatchCommand:
      self.executeInit()

      while self.batchLabels:
        self.batchNum += 1
        self.executeRun()

      self.batchNum = 0
      self.logAction('', False)

    else:
      self.executeRun()

    if self.command.isBatchCommand:
      self.executeFini()

  #-----------------------------------------------------------------
  def executeInit(self) -> None:
    '''Executes a batch command's init() method.
    '''

    databaseMgr = AppSetup.get().databaseMgr()

    session = databaseMgr.create()
    try:

      initFunc = getattr(self.command, 'init')
      funcArgs = CommandRunner.createArgs(initFunc,
                                          session,
                                          self.args,
                                          self.wantsTabCompletion)
      batchLabels = initFunc(**funcArgs)
      session.commit()

      # The self.batchLabels list controls how many batches will be run.
      # If commit succeeded, it's safe to assign to batchLabels.
      self.batchLabels = batchLabels

    except Exception:

      self.doPostFailure(traceback.format_exc())
      session.rollback()
      raise

    finally:

      session.close()

  #-----------------------------------------------------------------
  def executeFini(self) -> None:
    '''Executes a batch command's fini() method.
    '''

    databaseMgr = AppSetup.get().databaseMgr()

    session = databaseMgr.create()
    try:

      finiFunc = getattr(self.command, 'fini')
      funcArgs = CommandRunner.createArgs(finiFunc,
                                          session,
                                          self.args,
                                          self.wantsTabCompletion)
      finiFunc(**funcArgs)
      session.commit()

    except Exception:

      self.doPostFailure(traceback.format_exc())
      session.rollback()
      raise

    finally:

      session.close()

  #-----------------------------------------------------------------
  def executeRun(self) -> None:
    '''Executes this command's run() method.
    '''

    self.batchBeginTime = datetime.utcnow()

    databaseMgr = AppSetup.get().databaseMgr()

    session = databaseMgr.create()
    try:

      runFunc = getattr(self.command, 'run')
      funcArgs = CommandRunner.createArgs(runFunc, session, self.args,
                                          self.wantsTabCompletion)

      runFunc(**funcArgs)

      session.commit()
      self.doPostSuccess()

    # pylint: disable = W0703
    except Exception:

      self.doPostFailure(traceback.format_exc())
      session.rollback()

      # Batch actions don't re-raise, so further batches can continue.
      if not self.command.isBatchCommand:
        raise

    finally:

      session.close()

  #-----------------------------------------------------------------
  def doPostSuccess(self) -> None:
    '''Called after a successful action.
    '''

    self.logAction('', False)

  #-----------------------------------------------------------------
  def doPostFailure(self, exceptionStr: str) -> None:
    '''Called after a failed action.

    @param ex The exception that caused the failure.
    '''

    # Save the failed batch label.
    if self.batchLabels:
      self.failedBatchLabels.append(self.batchLabels[0])

    # Log the action failure.

    msg = "  Action failed because of exception:\n"
    msg += "-------------------------------------------------\n"
    msg += str(exceptionStr)

    self.logAction(msg, True)

  #-----------------------------------------------------------------
  def logAction(self, msg: str, wasError: bool) -> None:

    assert self.beginTime is not None

    userName = 'ROOT'

    beginTime = self.beginTime
    if self.batchNum > 0:
      assert self.batchBeginTime is not None
      beginTime = self.batchBeginTime

    now = datetime.utcnow()
    duration = now - beginTime

    actionName = self.commandClass.name.split('.')[-1]

    if self.batchNum > 0:
      actionName += " (batch " + str(self.batchNum) + ")"

    logMsg = ''
    logMsg += actionName + '\n'
    logMsg += "  " + " ".join(self.args) + '\n'
    logMsg += "  " + userName + '\n'
    logMsg += "  Duration: " + str(duration) + "\n"

    if not self.batchLabels and self.failedBatchLabels:
      batchMsg = "  Failed batches:\n"
      for batchLabel in self.failedBatchLabels:
        batchMsg += "    %s\n" % batchLabel
      logMsg += batchMsg
      print(batchMsg)

    if self.batchLabels:
      logMsg += "  Batch: " + str(self.batchLabels.pop(0)) + "\n"

    logMsg += msg

    Log.logAction(actionName,
                  logMsg,
                  wasError,
                  userName,
                  'localhost')

  #-----------------------------------------------------------------
  @staticmethod
  def createArgs(actionFunction: Callable[..., Any],
                 session: Session,
                 cmdArgs: List[str],
                 wantsTabCompletion: bool) -> Dict[str, object]:
    '''Builds the keyword arguments with which to call a command function.

    Each parameter in the function's signature is matched by name:

      session             The given DB session.
      args                A copy of the given command-line arguments.
      wantsTabCompletion  The given tab-completion flag.
      config              The configuration from Config.readConfig().
      self                Ignored; no value is supplied.

    @param actionFunction The command function that is going to be called.
    @param session The DB session to hand to the function.
    @param cmdArgs The command-line arguments to hand to the function.
    @param wantsTabCompletion Whether tab-complete suggestions are wanted.

    @return An OrderedDict of {parameterName: value}, in signature order.
    @throws AxeException If the function has any other parameter.
    '''

    parameterNames = inspect.signature(actionFunction,
                                       follow_wrapped=True).parameters

    config = Config.readConfig()

    # Everything a command function may ask for, keyed by parameter name.
    # The args are copied so any parsing in the command doesn't change
    # the original list.
    available: Dict[str, object] = {
      'session': session,
      'args': cmdArgs[:],
      'wantsTabCompletion': wantsTabCompletion,
      'config': config,
    }

    funcArgs: Dict[str, object] = OrderedDict()

    for name in parameterNames:

      if name in available:
        funcArgs[name] = available[name]

      elif name != 'self':
        template = "Unknown command function argument '%s' in function: %s"
        raise AxeException(template % (name, str(actionFunction)))

    return funcArgs

  #-----------------------------------------------------------------
  @staticmethod
  def chooseDbUrl(dbConfig: DbConfig,
                  argDbUrl: Optional[str],
                  argTestDb: bool) -> DbConfig:
    '''Returns the DB config information.

    @param dbConfig The dbConfig from the config file.
    @param argDbUrl The URL the user passed in on the command line, or None.
    @param argTestDb True if the test DB should be used.
    '''

    if argTestDb:
      dbConfig = dbConfig.unitTestConfig()

    if argDbUrl:
      dbConfig.url = argDbUrl

    return dbConfig
