#-------------------------------------------------------------------
#  ColorBinMap.py
#
#  The ColorBinMap class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the ColorBinMap class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from ai.axe.util import StringUtil
from typing import List
from typing import Optional
from typing import Tuple
#
# Import statements go above this line.
#-------------------------------------------------------------------

RgbaTuple = Tuple[int, int, int, int]

#===================================================================
class ColorBinMap:
  '''Handles mapping bins of floating-point values to colors.

  Each bin covers a range of values and is assigned one color.  A color is
  specified as a sequence of 4 RGBA 1-byte integers.

  Build a map by passing the color for the lowest bin to __init__(), then
  call addBin() once per additional bin, with strictly increasing border
  values.

  A map with N borders has N + 1 bins.  In the SQL from getBinSql() and the
  color map from getGdalColorMap(), bin 1 holds values below the first
  border and bin N + 1 holds values above the last border.
  '''

  #-----------------------------------------------------------------
  def __init__(self, noDataColor: RgbaTuple, lowestColor: RgbaTuple) -> None:
    '''Creates a new ColorBinMap.

    @param noDataColor The color used for NODATA points in the raster
                       (returned by colorForValue() for a None value).
    @param lowestColor The color used for values below the first bin border.
    '''

    assert len(noDataColor) == 4
    assert len(lowestColor) == 4

    # The color used for NODATA points in the raster.
    self.noDataColor: RgbaTuple = noDataColor

    # The colors used for binned values, lowest bin first.
    # Always holds exactly one more entry than self.borders.
    self.colors: List[RgbaTuple] = [lowestColor]

    # The bin border values, in strictly increasing order.  A list of floats.
    self.borders: List[float] = []

    # Which bin a value that matches a border value should go into.
    #
    # A list of booleans that corresponds entry-for-entry to self.borders.
    #
    # If true, the value will go into the higher bin.
    # If false, the value will go into the lower bin.
    #
    self.borderGoesHighs: List[bool] = []

  #-----------------------------------------------------------------
  def addBin(self, border: float, goesHigh: bool, color: RgbaTuple) -> None:
    '''Adds a new bin above all existing bins.

    Each call to addBin() must have a higher border value than the last
    call.

    @param border The minimum value for this bin.  A float.
    @param goesHigh A boolean.  If true, values that match the 'border' value
                    exactly will go into the higher (new) bin.  If false,
                    they will go into the lower bin.
    @param color The RGBA color for the new bin.

    @raise ValueError If 'border' is not greater than the previous border.
    '''

    assert len(color) == 4

    if self.borders and border <= self.borders[-1]:
      msg = "Error.  Each call to addBin() must add a higher border.  "
      msg += "New border:  " + str(border) + "  "
      msg += "Current borders: " + str(self.borders)
      raise ValueError(msg)

    self.borders.append(border)
    self.borderGoesHighs.append(goesHigh)
    self.colors.append(color)

  #-----------------------------------------------------------------
  def colorForValue(self, value: Optional[float]) -> RgbaTuple:
    '''Returns the color for the specified value.

    @param value The value to convert to a color, or None for NODATA.

    @return The noDataColor if value is None, otherwise the color of the
            bin the value falls into.
    '''

    assert len(self.borders) == len(self.borderGoesHighs)
    assert len(self.borders) + 1 == len(self.colors)

    if value is None:
      return self.noDataColor

    for i, (border, goesHigh) in enumerate(zip(self.borders,
                                               self.borderGoesHighs)):

      # The value belongs below this border if it is strictly smaller, or
      # if it sits exactly on the border and exact matches go low.
      if value < border or (value == border and not goesHigh):
        return self.colors[i]

    # The value is above every border, so it is in the highest bin.
    return self.colors[-1]

  #-----------------------------------------------------------------
  def getBinSql(self) -> str:
    '''Returns SQL to convert data values to bin integers.

    This SQL can be used in the ST_MapAlgebra() function to convert
    data values to bin integers.  Bins are numbered from 1 (below the first
    border) to len(borders) + 1 (above the last border).

    The border values are interpolated directly into the SQL text with
    str(), so they are expected to be plain numbers.

    @return A SQL string.'''

    assert self.noDataColor
    assert self.borders
    assert len(self.borders) == len(self.borderGoesHighs)
    assert len(self.borders) + 1 == len(self.colors)

    borderPairs = list(zip(self.borders, self.borderGoesHighs))

    clauses = ['case']

    # The low one-sided bin:  everything below the first border.
    firstBorder, firstGoesHigh = borderPairs[0]
    clauses.append('when [rast.val] %s %s then 1' %
                   (self.upperComparison(firstGoesHigh), firstBorder))

    # The middle bins:  bin i + 1 lies between border i - 1 and border i.
    for i in range(1, len(borderPairs)):
      lowBorder, lowGoesHigh = borderPairs[i - 1]
      highBorder, highGoesHigh = borderPairs[i]
      clauses.append('when [rast.val] %s %s and [rast.val] %s %s then %s' %
                     (self.lowerComparison(lowGoesHigh), lowBorder,
                      self.upperComparison(highGoesHigh), highBorder, i + 1))

    # The high one-sided bin:  everything above the last border.
    lastBorder, lastGoesHigh = borderPairs[-1]
    clauses.append('when [rast.val] %s %s then %s end' %
                   (self.lowerComparison(lastGoesHigh), lastBorder,
                    len(borderPairs) + 1))

    return ' '.join(clauses)

  #-----------------------------------------------------------------
  def getGdalColorMap(self) -> str:
    '''Returns a GDAL color map string.

    After data values have been converted to integers via the getBinSql() SQL,
    this color mapping can be used with ST_ColorMap() to convert those
    integers to the correct color.

    @return A GDAL color map string, one "bin R G B A" line per bin from the
            highest bin down, followed by a line mapping 0 to all zeros.'''

    # Unpacking (rather than tuple concatenation) accepts any length-4
    # color sequence, including the lists that __init__()/addBin() allow.
    colorStrs = ['%s %s %s %s %s' % (binNum, *color)
                 for binNum, color in reversed(list(enumerate(self.colors,
                                                              start=1)))]

    # getBinSql() never produces 0; it is mapped to all zeros
    # (presumably transparent black -- confirm the intended use).
    colorStrs.append('0 0 0 0 0')

    return '\n'.join(colorStrs)

  #-----------------------------------------------------------------
  def upperComparison(self, goesHigh: bool) -> str:
    '''Returns the comparison operator for the high side of a bin.

    @param goesHigh Whether values that match the border value exactly
                    go into the next higher bin or lower bin.

    @return '<' if exact matches go to the higher bin (so they are excluded
            here), otherwise '<='.
    '''

    return '<' if goesHigh else '<='

  #-----------------------------------------------------------------
  def lowerComparison(self, goesHigh: bool) -> str:
    '''Returns the comparison operator for the low side of a bin.

    @param goesHigh Whether values that match the border value exactly
                    go into the next higher bin or lower bin.

    @return '>=' if exact matches go to the higher bin (so they are included
            here), otherwise '>'.
    '''

    return '>=' if goesHigh else '>'

  #-----------------------------------------------------------------
  def __repr__(self) -> str:
    '''Format this object into a string.

    @return The formatted string for this object.
    '''

    lines = ['', str(self.colors[0])]

    # zip() stops at the shortest list, so an inconsistent object (e.g.
    # mid-construction) still formats instead of raising IndexError.
    for border, goesHigh, color in zip(self.borders,
                                       self.borderGoesHighs,
                                       self.colors[1:]):
      direction = 'goes high' if goesHigh else 'goes low'
      lines.append('               ' + str(border) + ' ' + direction)
      lines.append(str(color))

    binStr = '\n'.join(lines) + '\n'

    attrs = ['noDataColor',
             ]
    extraMembers: List[Tuple[str, object]] = [('bins', binStr)]
    return StringUtil.formatRepr(self, attrs, extraMembers)
