#-------------------------------------------------------------------
#  NumberBinMap.py
#
#  The NumberBinMap class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the NumberBinMap class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from ai.axe.util import StringUtil
from typing import List
from typing import Optional
from typing import Tuple
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class NumberBinMap:
  '''Handles mapping bins of floating-point values to numbers.

  You specify a range of values to be mapped to each number.

  Start by passing the number for the lowest bin to __init__(), then
  call addBin() with successively higher border values and their numbers.
  '''

  #-----------------------------------------------------------------
  def __init__(self, noDataNumber: float, lowestNumber: float) -> None:
    '''Creates a new NumberBinMap.

    @param noDataNumber The number used for NODATA points in the raster.
    @param lowestNumber The number used for values below the first bin border.
    '''

    # The number used for NODATA points in the raster.
    self.noDataNumber: float = noDataNumber

    # The numbers used for binned values.  Always one longer than
    # self.borders: numbers[i] is the bin just below borders[i], and the
    # last entry is the bin above the highest border.
    self.numbers: List[float] = [lowestNumber]

    # The bin border values, in strictly increasing order.
    self.borders: List[float] = []

    # Which bin a value that matches a border value should go into.
    #
    # A list of booleans that corresponds to the list of self.borders.
    #
    # If true, the value will go into the higher bin.
    # If false, the value will go into the lower bin.
    #
    self.borderGoesHighs: List[bool] = []

  #-----------------------------------------------------------------
  def addBin(self, border: float, goesHigh: bool, number: float) -> None:
    '''Adds a new bin to this map.

    Each call to addBin() must have a higher border value than the last
    call.

    @param border The minimum value for this bin.  A float.
    @param goesHigh A boolean.  If true, values that match the 'border' value
                    exactly will go into the higher bin.  If false,
                    they will go into the lower.
    @param number The number for the new bin.

    @raise ValueError If 'border' is not greater than the previous border.
    '''

    # Bins are searched in order, so borders must be strictly increasing.
    if self.borders and border <= self.borders[-1]:
      msg = "Error.  Each call to addBin() must add a higher border.  "
      msg += "New border:  " + str(border) + "  "
      msg += "Current borders: " + str(self.borders)
      raise ValueError(msg)

    self.borders.append(border)
    self.borderGoesHighs.append(goesHigh)
    self.numbers.append(number)

  #-----------------------------------------------------------------
  def numberForValue(self, value: Optional[float]) -> float:
    '''Returns the number for the specified value.

    @param value The value to convert to a number.  None means NODATA.

    @return The number that the value has been converted to:
            noDataNumber for None, otherwise the number of the bin
            the value falls into.
    '''

    assert len(self.borders) == len(self.borderGoesHighs)
    assert len(self.borders) + 1 == len(self.numbers)

    if value is None:
      return self.noDataNumber

    # The first border the value lies below (or equals, when ties go
    # low) closes the value's bin.  zip() stops at len(self.borders),
    # so the last number is only reached via the fall-through below.
    for border, goesHigh, number in zip(self.borders,
                                        self.borderGoesHighs,
                                        self.numbers):
      if value < border or (value == border and not goesHigh):
        return number

    # The value is above every border: use the highest bin.
    return self.numbers[-1]

  #-----------------------------------------------------------------
  def getBinSql(self) -> str:
    '''Returns SQL to convert data values to bin numbers.

    This SQL can be used in the ST_MapAlgebra() function to convert
    data values to bin numbers.  It is a single CASE expression with one
    WHEN clause per bin, evaluated from the lowest bin to the highest.

    @return A SQL string.'''

    # A NODATA number of 0 is legitimate, so only check that it is set.
    assert self.noDataNumber is not None
    assert self.borders
    assert len(self.borders) == len(self.borderGoesHighs)
    assert len(self.borders) + 1 == len(self.numbers)

    clauses = ['case']

    # Add the low one-sided bin.

    sign = self.upperComparison(self.borderGoesHighs[0])
    clauses.append('when [rast.val] %s %s then %s' %
                   (sign, self.borders[0], self.numbers[0]))

    # Add the middle bins, each bounded by a pair of adjacent borders.

    edges = list(zip(self.borders, self.borderGoesHighs))
    for (lowBorder, lowGoesHigh), (highBorder, highGoesHigh), number in \
        zip(edges, edges[1:], self.numbers[1:]):

      lowSign = self.lowerComparison(lowGoesHigh)
      highSign = self.upperComparison(highGoesHigh)

      clauses.append('when [rast.val] %s %s and [rast.val] %s %s then %s' %
                     (lowSign, lowBorder, highSign, highBorder, number))

    # Add the high one-sided bin.

    sign = self.lowerComparison(self.borderGoesHighs[-1])
    clauses.append('when [rast.val] %s %s then %s end' %
                   (sign, self.borders[-1], self.numbers[-1]))

    return ' '.join(clauses)

  #-----------------------------------------------------------------
  def upperComparison(self, goesHigh: bool) -> str:
    '''Returns the comparison operator for the high side of a bin.

    @param goesHigh Whether values that match the border value exactly
                    go into the next higher bin or lower bin.

    @return A string comparison operator: '<' if ties go high
            (the border is excluded from this bin), else '<='.
    '''

    return '<' if goesHigh else '<='

  #-----------------------------------------------------------------
  def lowerComparison(self, goesHigh: bool) -> str:
    '''Returns the comparison operator for the low side of a bin.

    @param goesHigh Whether values that match the border value exactly
                    go into the next higher bin or lower bin.

    @return A string comparison operator: '>=' if ties go high
            (the border is included in this bin), else '>'.
    '''

    return '>=' if goesHigh else '>'

  #-----------------------------------------------------------------
  def __repr__(self) -> str:
    '''Format this object into a string.

    @return The formatted string for this object.
    '''

    lines = ['', str(self.numbers[0])]

    # Guard against inconsistent list lengths: each row below reads
    # borders[i], borderGoesHighs[i] and numbers[i + 1].
    num = min(len(self.borders),
              len(self.borderGoesHighs),
              len(self.numbers) - 1)

    for i in range(num):
      side = 'goes high' if self.borderGoesHighs[i] else 'goes low'
      lines.append('               ' + str(self.borders[i]) + ' ' + side)
      lines.append(str(self.numbers[i + 1]))

    binStr = '\n'.join(lines) + '\n'

    attrs = ['noDataNumber',
             ]
    extraMembers: List[Tuple[str, object]] = [('bins', binStr)]
    return StringUtil.formatRepr(self, attrs, extraMembers)
