#-------------------------------------------------------------------
#  ActionProcessor.py
#
#  The ActionProcessor class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the ActionProcessor class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from .AuthMgr import AuthMgr
from .ParamValueException import ParamValueException
from .ResponseEncoder import ResponseEncoder
from .WebWorkerContext import WebWorkerContext
from .WebWorkerRunner import WebWorkerRunner
from .desc.WebDesc import WebDesc
from .desc.WebParam import WebParam
from .desc.WebVerbDesc import WebVerbDesc
from ai.axe.db.DatabaseMgr import DatabaseMgr
from ai.axe.web.config import Config
from ai.axe.web.core import AxeException
from ai.axe.web.core import Log
from ai.axe.util import StringUtil
from collections import OrderedDict
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.datastructures import FileStorage
from werkzeug.wrappers import Request
from werkzeug.wrappers import Response
from datetime import datetime
from sqlalchemy.orm.session import Session
import os
import traceback
import urllib
import urllib.parse
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class ActionProcessor(WebWorkerContext):
  '''Runs a single action.

  An ActionProcessor wraps one HTTP request and the action function that
  handles it.  process() validates and converts the request params, builds
  the context params the action declares, calls the action inside a DB
  session, encodes the response, and logs the outcome.
  '''

  #-----------------------------------------------------------------
  def __init__(self,
               isApiRequest: bool,
               request: Request,
               authMgr: AuthMgr,
               actionFunc: Callable[..., Any],
               databaseMgr: DatabaseMgr) -> None:
    '''Creates a new ActionProcessor.

    @param isApiRequest True for a web API action call, False for an app
                        action call.  Controls how bad params are reported.
    @param request The HTTP request being handled.
    @param authMgr The Authentication/Authorization manager.
    @param actionFunc The action function to run.
    @param databaseMgr Creates the DB session the action runs in.
    '''

    WebWorkerContext.__init__(self)

    # The Authentication/Authorization manager.
    self.authMgr: AuthMgr = authMgr

    # True if this is a web API action call, False if it is an app action call.
    self.isApiRequest: bool = isApiRequest

    # The HTTP request.
    self.request: Request = request

    # The DB manager.
    self.databaseMgr: DatabaseMgr = databaseMgr

    # The DB session.
    self.session: Optional[Session] = None

    # Multi-dict of request params.
    self.requestParams: Dict[str, List[str]] = self.extractParams(request)

    # The action to be executed.
    self.actionFunc: Callable[..., Any] = actionFunc

    # The user who is executing this action.
    self.userInfo: object = self.authMgr.userInfo()

    # The userName at the start of the action.
    self.initialUserName: str = self.authMgr.userName()

    # The name of the web worker that is currently running.
    #
    # This will only be set in a web worker sub-process.
    # During normal action processing, this will be None.
    self.webWorkerName: Optional[str] = None

    # When the action began.
    self.beginTime: Optional[datetime] = None

  #-----------------------------------------------------------------
  def process(self) -> Response:
    '''Process an HTTP request from Flask.

    Runs the action in a fresh DB session.  The session is committed if
    the action succeeds, rolled back if it raises, and always closed.

    @return The HTTP response.  Extra, missing, or invalid params give a
            400 with an explanation for API requests and a 500 for app
            requests; any other failure gives a 500.
    '''

    # pylint:  disable = R0912
    # pylint:  disable = R0915

    self.beginTime = datetime.utcnow()

    appConfig: Config = Config.readConfig()

    # The DB session for this request.
    self.session = self.databaseMgr.create()

    try:

      # Prepare the parameters.

      webVerb: WebVerbDesc = WebDesc.getForFunction(self.actionFunc)
      paramDescs: List[WebParam] = webVerb.requestParams
      contextParamDescs: List[Tuple[str, str]] = webVerb.contextParams
      isDynamicParams = webVerb.dynamicParams

      # Dynamic-param actions take whatever query params are sent, so
      # there is no declared list to check against.
      if not isDynamicParams:

        extra, missing = self.checkParams(paramDescs,
                                          self.requestParams,
                                          self.request.files)
        if extra or missing:
          msg, response = self.extraMissingResponse(extra, missing)
          self.logAction(msg, True)
          return response

      # Set up the request params.

      try:
        funcArgs: Dict[str, object] = self.createRequestParams(isDynamicParams,
                                                               paramDescs)

      except ParamValueException:
        # Always let this reach the outer handler, which reports it to API
        # callers as a 400, even if ParamValueException derives from
        # ValueError.
        raise

      except ValueError as ex:
        self.logAction(str(ex), True)
        return Response('Server Error', 500)

      # Set up the context params.

      for paramName, contextParamType in contextParamDescs:
        funcArgs[paramName] = self.createContextParam(self.session,
                                                      appConfig,
                                                      contextParamType)

      responseClass = webVerb.responseClass

      responseObj = self.actionFunc(**funcArgs)

      # Have to encode inside the DB transaction, so lazy objects can load.
      response = ResponseEncoder.encode(responseClass, responseObj)

      self.session.commit()
      self.doPostSuccess()

      return response

    # pylint: disable = W0703
    except Exception as ex:

      self.doPostFailure(traceback.format_exc())
      self.session.rollback()

      if self.isApiRequest and isinstance(ex, ParamValueException):
        msg = "Param value error: %s" % ex
        return Response(msg, 400)
      else:
        return Response('Server error.', 500)

    finally:

      self.session.close()

  #-----------------------------------------------------------------
  def doPostSuccess(self) -> None:
    '''Called after a successful action.

    Logs the action as a success.
    '''

    self.logAction('', False)

  #-----------------------------------------------------------------
  def doPostFailure(self, exceptionStr: str) -> None:
    '''Called after a failed action.

    Logs the action as an error, including the exception text.

    @param exceptionStr The formatted traceback of the exception that
                        caused the failure.
    '''

    # Log the action failure.

    msg = "  Action failed because of exception:\n"
    msg += "-------------------------------------------------\n"
    msg += str(exceptionStr)

    self.logAction(msg, True)

  #-----------------------------------------------------------------
  def logAction(self, msg: str, wasError: bool) -> None:
    '''Writes a log entry for this action via Log.logAction.

    The entry contains the action name, the request info, the user name,
    and the time elapsed since the action began, followed by msg.

    @param msg Extra text to append to the log entry (may be empty).
    @param wasError True if the action failed.
    '''

    # process() sets beginTime before anything can be logged.
    assert self.beginTime is not None

    userName = self.authMgr.userName()

    # If the user name changed while the action ran, record both names.
    if userName != self.initialUserName:
      userName = self.initialUserName + '/' + userName

    now = datetime.utcnow()
    duration = now - self.beginTime

    actionName = self.actionFunc.__name__

    # Tag log entries from a web worker with the worker name and PID.
    if self.webWorkerName is not None:
      actionName += '/' + self.webWorkerName + '/' + str(os.getpid())

    logMsg = ''
    logMsg += actionName + '\n'
    logMsg += "  " + self.getRequestInfo() + '\n'
    logMsg += "  " + userName + '\n'
    logMsg += "  Duration: " + str(duration) + "\n"
    logMsg += msg

    remoteAddr = self.request.remote_addr
    remoteAddr = remoteAddr if remoteAddr else ''

    Log.logAction(actionName,
                  logMsg,
                  wasError,
                  userName,
                  remoteAddr)

  #-----------------------------------------------------------------
  def setWebWorkerName(self, webWorkerName: str) -> None:
    '''Sets the web worker name to be included in any log messages.
    '''

    self.webWorkerName = webWorkerName

  #-----------------------------------------------------------------
  def getDatabaseMgr(self) -> DatabaseMgr:
    '''Returns the database under which this is running.
    '''

    return self.databaseMgr

  #-----------------------------------------------------------------
  def getRequestInfo(self) -> str:
    '''Returns a description of the request that called this action.

    The description contains the HTTP method, the URL, and the params.
    '''
    return self.getRequestUrl(self.request, self.requestParams)

  #-----------------------------------------------------------------
  @staticmethod
  def getRequestUrl(request: Request,
                    paramMultiDict: Dict[str, List[str]]) -> str:
    '''Returns a description of the request (method, URL, and params).

    @param request The HTTP request.
    @param paramMultiDict The request params, as name -> list of values.

    @return A single-line string suitable for logging.
    '''

    # The pattern is handed to StringUtil.formatMultiDict, presumably so
    # that password-like param values are hidden in the logs.  Note that a
    # character class needs no '|': the old '[p|P]' also matched '|'.
    params = StringUtil.formatMultiDict(paramMultiDict, ".*[pP]assword")

    ret = ("Method: " + request.method + " " +
      "URL: " + request.host_url + ' ' + request.path  + " Params: " + params)

    return ret

  #-----------------------------------------------------------------
  def createContextParam(self,
                         session: Session,
                         config: Config,
                         paramType: str) -> object:
    '''Returns a context parameter of the specified type.

    @param session The DB session for the current request.
    @param config The application config.
    @param paramType The declared type name of the context param.  A name
                     ending in 'Config' (other than 'Config' itself)
                     selects the config section named by the prefix, as
                     converted by StringUtil.initialLower.

    @return The object to pass to the action for this param.

    @raise AxeException If paramType is not a known context param type.
    '''
    # pylint:  disable = R0911

    if paramType == 'Session':
      return session

    elif paramType == 'AuthMgr':
      return self.authMgr

    elif paramType == 'UserInfo':
      return self.authMgr.userInfo()

    # Must be tested before the generic '...Config' case below.
    elif paramType == 'Config':
      return config

    elif paramType.endswith('Config'):
      sectionName = StringUtil.initialLower(paramType[:-len('Config')])
      return config.getConfigByName(sectionName)

    elif paramType == 'Request':
      return self.request

    elif paramType == 'WebWorkerRunner':
      return WebWorkerRunner(self)

    else:
      msg = "Unknown context parameter type: " + paramType
      raise AxeException(msg)

  #-----------------------------------------------------------------
  def checkParams(self,
                  paramDescList: List[WebParam],
                  paramDict: Dict[str, List[str]],
                  fileDict: ImmutableMultiDict[str,
                                               FileStorage]) -> Tuple[List[str],
                                                                     List[str]]:
    '''Returns an ([extra], [missing]) tuple of param name lists.

    Params are matched by their effective URL name.

    @param paramDescList The params the action declares.
    @param paramDict The GET/POST params sent, as name -> list of values.
    @param fileDict The uploaded files sent.

    @return Names that were sent but not declared (GET/POST params first,
            then files), and names of required params that were not sent.
    '''

    missing: List[str] = []
    for paramDesc in paramDescList:

      # Optional, defaulted, and list params may be absent: in that case
      # createRequestParam supplies None, the default, or [].
      if (paramDesc.optional or
          paramDesc.hasDefaultValue() or
          paramDesc.paramType.isList()):
        continue

      paramName = paramDesc.effectiveUrlName()

      # File params arrive as uploads; all others as GET/POST params.
      if paramDesc.paramType.isFile():
        isPresent = paramName in fileDict
      else:
        isPresent = paramName in paramDict

      if not isPresent:
        missing.append(paramName)

    allNames = {paramDesc.effectiveUrlName() for paramDesc in paramDescList}

    extra: List[str] = [name for name in paramDict if name not in allNames]
    extra.extend(name for name in fileDict if name not in allNames)

    return extra, missing

  #-----------------------------------------------------------------
  def extraMissingResponse(self,
                           extra: List[str],
                           missing: List[str]) -> Tuple[str, Response]:
    '''Returns a Response for extra or missing params from the request.

    @param extra Names of params sent but not declared.
    @param missing Names of required params not sent.

    @return A (log message, response) tuple.
    '''

    # Only call this function if there's at least one extra or missing.
    assert extra or missing

    msg = 'Error processing request.  '
    if extra:
      msg += 'The following extra param(s) were found: ' + ', '.join(extra)
    if extra and missing:
      msg += '\n'
    if missing:
      msg += 'The following param(s) were missing: ' + ', '.join(missing)

    # An API call will return a 400 Bad Request with a message
    # explaining what the user's mistake was.
    # An app call with bad arguments is simply a bug: return 500.
    if self.isApiRequest:
      response = Response(msg, 400)
    else:
      response = Response('Server error.', 500)

    return msg, response

  #-----------------------------------------------------------------
  def createRequestParams(self,
                          isDynamicParams: bool,
                          paramDescs: List[WebParam]) -> Dict[str, object]:
    '''Builds the keyword arguments for the action from the request params.

    @param isDynamicParams True if the action takes dynamic params.
    @param paramDescs The params the action declares (unused for dynamic
                      params).

    @return An ordered map of action argument name -> value.  For dynamic
            params, this holds a single 'params' entry: the query string
            parsed into a list of (name, value) pairs.
    '''

    requestParams: Dict[str, object] = OrderedDict()

    if isDynamicParams:

      # Dynamic params are all just returned as a big list called 'params'.

      queryStr: str = self.request.query_string.decode()
      paramsList: List[Tuple[str, str]] = urllib.parse.parse_qsl(queryStr)
      requestParams['params'] = paramsList

    else:

      for paramDesc in paramDescs:

        paramValue = self.createRequestParam(paramDesc)
        requestParams[paramDesc.name] = paramValue

    return requestParams

  #-----------------------------------------------------------------
  def createRequestParam(self, paramDesc: WebParam) -> object:
    '''Converts one declared request param into the action argument value.

    The param is read from the request under its effective URL name.

    @param paramDesc The description of the param.

    @return The converted value.  An absent param gives its default value
            (if it has one), None (if optional), or [] (if a list).

    @raise ParamValueException If a required param's value converts to
                               None (or if the conversion itself raises
                               one).  The message is prefixed with the
                               param's URL name.
    @raise ValueError If a required, non-list param is absent; checkParams
                      normally rejects such a request first.
    '''

    urlName = paramDesc.effectiveUrlName()

    try:

      paramType = paramDesc.paramType
      paramObjClass = paramDesc.objClass
      paramValue: object

      if paramType.isFile():
        # Look uploads up by URL name, the same key checkParams validated.
        paramValueFileList = self.request.files.getlist(urlName)
        paramValue = paramType.extractUploadedFiles(paramValueFileList)

      elif urlName in self.requestParams:
        paramValueStrs = self.requestParams[urlName]
        paramValue = paramType.stringToObject(paramValueStrs, paramObjClass)

        if paramValue is None and not paramDesc.optional:
          if paramDesc.hasDefaultValue():
            paramValue = paramDesc.defaultValue
          else:
            raise ParamValueException('value missing.')

      elif paramDesc.optional or paramDesc.hasDefaultValue():
        paramValue = (paramDesc.defaultValue
                      if paramDesc.hasDefaultValue() else None)

      elif paramType.isList():
        # If a list key isn't present, it's an empty list.
        paramValue = []

      else:
        raise ValueError('Invalid param desc: ' + str(paramDesc))

      return paramValue

    except ParamValueException as ex:
      msg = 'Param "%s":  ' % urlName
      msg += str(ex)
      raise ParamValueException(msg) from ex

  #-----------------------------------------------------------------
  def extractParams(self, request: Request) -> Dict[str, List[str]]:
    '''Extract the parameters from an HTTP request.

    GET params come from the query string, POST params from the form body.

    @param request A Werkzeug Request object.

    @return A dictionary mapping each param name to the list of all values
            sent for it.
    '''

    # Only GET and POST requests are expected to reach this class.
    assert request.method in ('GET', 'POST')
    if request.method == 'GET':
      paramDict = request.args
    else:
      paramDict = request.form

    return paramDict.to_dict(flat=False)
