#-------------------------------------------------------------------
#  AxeSimpleTestCase.py
#
#  The AxeSimpleTestCase class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the AxeSimpleTestCase class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from typing import cast
from typing import List
from typing import Optional
import unittest
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class AxeSimpleTestCase(unittest.TestCase):
  '''Unittest TestCase with AXE-specific functionality added.

  This is a 'simple' test case because it does not include any
  help for testing database-related functionality.  You should use
  ai.axe.db.unittest.AxeUnitTest for that.

  Additions over unittest.TestCase:

  - assertEqual() always diffs strings (of any length), and switches to
    an almost-equal comparison (scalar, 1-D or 2-D) when 'places' or
    'delta' is given.
  - assertEquals() always fails, pointing the caller at assertEqual().
  - assertAlmostEqual() accepts the failure message as the first
    positional argument after the two values.
  - assertAlmostEqualList2d() / assertListAlmostEqual() compare lists
    element by element and report the index of the failing element.
  '''

  # pylint: disable=W0221

  #-----------------------------------------------------------------
  # Make assertEqual() messages be added to the default message,
  # rather than replacing it.
  longMessage = True

  #-----------------------------------------------------------------
  def __init__(self, *args, **kwargs) -> None:
    '''Creates a new AxeSimpleTestCase.

    All arguments are passed through to unittest.TestCase.__init__().
    '''

    unittest.TestCase.__init__(self, *args, **kwargs)

    # In assertEqual() error messages, show string diffs no matter
    # how long the strings are.
    #
    # Inherited from TestCase.
    self.maxDiff = None

  #-----------------------------------------------------------------
  # pylint: disable=signature-differs
  def assertEqual(self,
                  first: object,
                  second: object,
                  *args,
                  **kwargs) -> None:
    '''Override the built-in assertEqual() for Axe-specific behavior.

    - Two strings are always compared with assertMultiLineEqual().
    - If 'places' or 'delta' is passed as a keyword, the values are
      compared approximately: 2-D lists/tuples with
      assertAlmostEqualList2d(), 1-D lists/tuples with
      assertListAlmostEqual(), anything else with the built-in
      unittest assertAlmostEqual().
    - Otherwise the built-in assertEqual() is used.

    As with the built-in assertEqual(), an optional failure message may
    be given as the third positional argument or as 'msg='.  In the
    tolerance case a positional message replaces any 'msg' keyword.
    '''

    # The built-in type-specific dispatch to assertMultiLineEqual() only
    # happens when both arguments have exactly the same type (so, for
    # example, str subclasses are missed).  Route every pair of strings
    # there explicitly.
    if isinstance(first, str) and isinstance(second, str):
      self.assertMultiLineEqual(first, second, *args, **kwargs)

    # If the user is checking with a tolerance, redirect to an
    # almost-equal comparison.
    elif 'places' in kwargs or 'delta' in kwargs:

      # In assertEqual() the first positional arg after the values is the
      # message, but in the almost-equal helpers that position is
      # 'places'.  Convert it to a keyword so it can't be misread.
      if args:
        kwargs['msg'] = args[0]
        args = args[1:]

      if (isinstance(first, (list, tuple)) and first and
          isinstance(first[0], (list, tuple)) and
          isinstance(second, (list, tuple)) and second and
          isinstance(second[0], (list, tuple))):
        self.assertAlmostEqualList2d(cast(List[List[float]], first),
                                     cast(List[List[float]], second),
                                     *args,
                                     **kwargs)

      elif (isinstance(first, (list, tuple)) and
            isinstance(second, (list, tuple))):
        self.assertListAlmostEqual(cast(List[float], first),
                                   cast(List[float], second),
                                   *args,
                                   **kwargs)

      else:
        unittest.TestCase.assertAlmostEqual(self,
                                            cast(float, first),
                                            cast(float, second),
                                            *args,
                                            **kwargs)

    else:
      unittest.TestCase.assertEqual(self, first, second, *args, **kwargs)

  #-----------------------------------------------------------------
  def assertEquals(self,
                   first: object,
                   second: object,
                   *args,
                   **kwargs) -> None:
    '''Override the built-in assertEquals() for Axe-specific behavior.

    Any use of assertEquals() fails the test with a message telling the
    caller to use assertEqual() instead.  The arguments are accepted
    only so that existing calls reach that failure rather than a
    TypeError.

    Raises:
      self.failureException, always.
    '''

    # The arguments are intentionally unused.
    del first, second, args, kwargs

    errMsg = "Don't use 'assertEquals()'.  Use 'assertEqual()' (no final 's'!) "
    errMsg += "instead."

    # msg=None (rather than '') keeps _formatMessage() from appending an
    # empty ' : ' suffix to the message.
    raise self._makeFailure(None, errMsg)

  #-----------------------------------------------------------------
  def assertRectangularList2d(self,
                              list2d: object,
                              msg: Optional[str] = None) -> None:
    '''Fail if the 2-D list is ragged.

    An empty (falsy) list2d passes.  Otherwise list2d must be a list or
    tuple whose first row is a list or tuple, and every row must have
    the same length as the first row.

    Args:
      list2d: The 2-D list (or tuple) to check.
      msg: Optional message added to the failure message.

    Raises:
      self.failureException if list2d is not a 2-D list/tuple or is
      ragged.
    '''

    if not list2d:
      return

    # Fail via failureException rather than 'assert' so these checks
    # are not stripped when running under 'python -O'.
    if not isinstance(list2d, (list, tuple)):
      errMsg = 'Expected a 2-D list, but got a %s' % type(list2d).__name__
      raise self._makeFailure(msg, errMsg)

    if not isinstance(list2d[0], (list, tuple)):
      errMsg = 'Expected a 2-D list, but row 0 is a %s'
      raise self._makeFailure(msg, errMsg % type(list2d[0]).__name__)

    numColumns = len(list2d[0])

    for i, row in enumerate(list2d):
      if len(row) != numColumns:
        errMsg = "Row %s has %s columns, but expected %s"
        raise self._makeFailure(msg, errMsg % (i, len(row), numColumns))

  #-----------------------------------------------------------------
  def assertAlmostEqual(self,
                        first: object,
                        second: object,
                        *args,
                        **kwargs) -> None:
    '''Like the built-in unittest.assertAlmostEqual(), but first arg is a msg.

    The built-in assertAlmostEqual requires you to say:

      self.assertAlmostEqual(1.0, 1.1, msg="values")

    This one allows you to drop the 'msg' keyword:

      self.assertAlmostEqual(1.0, 1.1, "values")

    A positional message replaces any 'msg' keyword.  Any further
    positional args are passed through to the built-in, where the first
    of them becomes 'places'.  To pass 'places' or 'delta', use keywords.
    '''

    # Assume the first extra positional arg is a 'msg'.
    if args:
      kwargs['msg'] = args[0]
      args = args[1:]

    # Pass through to the built-in assertAlmostEqual().
    unittest.TestCase.assertAlmostEqual(self,
                                        cast(float, first),
                                        cast(float, second),
                                        *args,
                                        **kwargs)

  #-----------------------------------------------------------------
  def assertAlmostEqualList2d(self,
                              first: List[List[float]],
                              second: List[List[float]],
                              places: Optional[int] = None,
                              msg: Optional[str] = None,
                              delta: Optional[float] = None) -> None:
    """Fail if the two 2-D lists of values are not almost equal.

    Both lists must be rectangular and have the same number of rows and
    columns.  Corresponding elements are then compared with
    assertAlmostEqual().  Either their difference is rounded to the
    given number of decimal places (default 7) and compared to zero, or,
    if delta is given, the difference must be no more than delta.

    Args:
      first: The expected 2-D list.
      second: The actual 2-D list.
      places: Decimal places to round each difference to.
      msg: Optional message added to any failure message.  For element
           comparisons, the failing element's index is appended to it.
      delta: Maximum allowed difference between elements.
    """

    if first == second:
      # shortcut
      return

    self.assertRectangularList2d(first, msg=msg)
    self.assertRectangularList2d(second, msg=msg)

    if len(first) != len(second):
      errMsg = 'Wrong num rows: expected %s, actual %s.'
      raise self._makeFailure(msg, errMsg % (len(first), len(second)))

    # Row counts match here, so a non-empty 'first' implies a non-empty
    # 'second'.
    if first and len(first[0]) != len(second[0]):
      errMsg = 'Wrong num columns: expected %s, actual %s.'
      raise self._makeFailure(msg, errMsg % (len(first[0]), len(second[0])))

    for i, (firstRow, secondRow) in enumerate(zip(first, second)):
      for j, (a, b) in enumerate(zip(firstRow, secondRow)):

        # Build a fresh message for each element.  Appending to 'msg'
        # itself would accumulate the indices of every earlier element.
        elemMsg = '(array element [%s][%s])' % (i, j)
        if msg is not None:
          elemMsg = msg + ' ' + elemMsg

        self.assertAlmostEqual(a, b, places=places, msg=elemMsg, delta=delta)

  # -----------------------------------------------------------------
  def assertListAlmostEqual(self,
                            list1: List[float],
                            list2: List[float],
                            places: int = 7,
                            msg: Optional[str] = None,
                            delta: Optional[float] = None) -> None:
    '''Helper function to check if two lists are almost equal.

    Fails if the lists have different lengths, or if any pair of
    corresponding values is not almost equal according to the built-in
    unittest assertAlmostEqual().  If delta is given, it is used instead
    of places.

    Args:
      list1: The first list of values.
      list2: The second list of values.
      places: Decimal places to round each difference to.  Ignored if
              delta is given.
      msg: Optional message appended to any failure message.
      delta: Maximum allowed difference per element.
    '''

    msgSuffix = ': ' + msg if msg else ''

    self.assertEqual(len(list1), len(list2), 'bad list length' + msgSuffix)

    # The built-in assertAlmostEqual() rejects places and delta together.
    # Since places always has a value here, drop it when a delta is given.
    effectivePlaces = None if delta is not None else places

    for index, (a, b) in enumerate(zip(list1, list2)):

      indexMsg = "list index " + str(index) + msgSuffix

      # Pass everything by keyword.  This class's assertAlmostEqual()
      # treats the first positional arg after the values as the message,
      # which previously swallowed 'places' and clobbered indexMsg.
      self.assertAlmostEqual(a, b, places=effectivePlaces, msg=indexMsg,
                             delta=delta)

  #-----------------------------------------------------------------
  def _makeFailure(self,
                   msg: Optional[str],
                   errMsg: str) -> BaseException:
    '''Returns a failureException whose message combines errMsg and msg.

    The message is built by unittest's _formatMessage(), so it honors
    longMessage.  The caller is expected to raise the returned exception.
    '''

    return self.failureException(self._formatMessage(msg, errMsg)) # type: ignore
