#-------------------------------------------------------------------
#  WkbRasterReader.py
#
#  The WkbRasterReader class.
#
#  Copyright 2016 Applied Invention, LLC
#-------------------------------------------------------------------

'''The module containing the WkbRasterReader class.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
import binascii
import math
import struct

from . import Raster
from . import RasterBand
#
# Import statements go above this line.
#-------------------------------------------------------------------


#===================================================================
class WkbRasterReader:
  '''Reads Well-Known Binary (WKB) raster data into a Raster object.

  The data consists of a fixed 61-byte raster header followed by one
  record per band.  Only in-database bands are supported; bands flagged
  as offline raise ValueError.  The reader keeps no state between calls.
  '''

  # Size in bytes of the fixed raster header: byte order (1), version (2),
  # band count (2), six float64 geo-transform values (48), srid (4),
  # width (2) and height (2).
  HEADER_SIZE = 61

  # Map of WKB value type names to struct format codes.
  VALUE_TYPE_CODES = {
    'bool' : 'B',
    'int2' : 'B',
    'int4' : 'B',
    'int8' : 'b',
    'uint8' : 'B',
    'int16' : 'h',
    'uint16' : 'H',
    'int32' : 'i',
    'uint32' : 'I',
    'float32' : 'f',
    'float64' : 'd',
  }

  #-----------------------------------------------------------------
  def __init__(self):
    '''Creates a new WkbRasterReader.
    '''

    pass

  #-----------------------------------------------------------------
  def read(self, dataStr):
    '''Reads hex-encoded WKB raster data and returns a Raster object.

    @param dataStr A string of hex-encoded binary data.

    @return a Raster built from the header fields and the decoded bands.

    @raise ValueError if the data is too short or truncated, has an
           invalid byte order, an unsupported version, or an offline band.
    '''

    data = binascii.unhexlify(dataStr)
    dataIndex = 0

    if len(data) < self.HEADER_SIZE:
      msg = ("WKB data too short. Expected at least %d bytes. Data: %s" %
             (self.HEADER_SIZE, str(data)))
      raise ValueError(msg)

    # Read the byte order.  It governs every multi-byte value after it.

    byteOrderData, dataIndex = self.readByte(data, dataIndex)
    if byteOrderData == 0:
      byteOrder = '>' # Big endian
    elif byteOrderData == 1:
      byteOrder = '<' # Little endian
    else:
      # str() is required: concatenating the int directly is a TypeError.
      msg = "Invalid byte order: " + str(byteOrderData)
      raise ValueError(msg)

    # Read the version.

    version, dataIndex = self.readValue('uint16', byteOrder, data, dataIndex)

    if version > 0:
      msg = "Unsupported WKB Raster version: " + str(version)
      raise ValueError(msg)

    numBands, dataIndex = self.readValue('uint16', byteOrder, data, dataIndex)
    scaleX, dataIndex = self.readValue('float64', byteOrder, data, dataIndex)
    scaleY, dataIndex = self.readValue('float64', byteOrder, data, dataIndex)
    upperLeftX, dataIndex = self.readValue('float64', byteOrder, data,
                                           dataIndex)
    upperLeftY, dataIndex = self.readValue('float64', byteOrder, data,
                                           dataIndex)
    skewX, dataIndex = self.readValue('float64', byteOrder, data, dataIndex)
    skewY, dataIndex = self.readValue('float64', byteOrder, data, dataIndex)

    srid, dataIndex = self.readValue('int32', byteOrder, data, dataIndex)

    width, dataIndex = self.readValue('uint16', byteOrder, data, dataIndex)
    height, dataIndex = self.readValue('uint16', byteOrder, data, dataIndex)

    # Bands are stored back to back; each read advances dataIndex.
    bands = []
    for _ in range(numBands):
      band, dataIndex = self.readBand(byteOrder, data, dataIndex, width, height)
      bands.append(band)

    raster = Raster(width, height, upperLeftX, upperLeftY, scaleX, scaleY,
                    skewX, skewY, srid, bands)
    return raster

  #-----------------------------------------------------------------
  def readBand(self, byteOrder, data, dataIndex, width, height):
    '''Reads one band record and returns a RasterBand object.

    @param byteOrder The struct byte order code ('<' or '>').
    @param data The decoded (binary, not hex) WKB data.
    @param dataIndex The index of the band's first byte in the data.
    @param width The number of pixels wide the data is.
    @param height The number of pixels high the data is.

    @return a tuple: (RasterBand object, new dataIndex).  The band's
            values are a list of `height` rows of `width` values each,
            with nodata pixels replaced by None.

    @raise ValueError if the pixel type is unknown, the band is offline,
           or the data is truncated.
    '''

    # The first byte packs three flags (high bits) and the pixel type
    # code (low four bits).
    packedData, dataIndex = self.readByte(data, dataIndex)

    isOffline = (packedData & 0b10000000) > 0
    hasNoData = (packedData & 0b01000000) > 0
    isAllNoData = (packedData & 0b00100000) > 0
    pixelTypeInt = (packedData & 0b00001111)

    try:
      pixelType = RasterBand.pixelTypeCodes[pixelTypeInt]
    except (IndexError, KeyError):
      raise ValueError("Invalid pixel type: " + str(pixelTypeInt))

    # The nodata slot is always consumed so dataIndex stays aligned; its
    # value is discarded when the band has no nodata flag.
    noData, dataIndex = self.readValue(pixelType, byteOrder, data, dataIndex)

    if not hasNoData:
      noData = None

    if isOffline:
      msg = "Offline data storage not supported by this reader."
      raise ValueError(msg)

    numValues = height * width
    flatValues, dataIndex = self.readValue(pixelType, byteOrder, data,
                                           dataIndex, numValues=numValues)

    # readValue returns a bare scalar for a single value; normalize so a
    # 1x1 band is sliced like any other.
    if numValues == 1:
      flatValues = (flatValues, )

    # Split the row-major pixel values into rows.  Iterating over the row
    # count (not the value count) also handles width == 0.
    values = [list(flatValues[row * width:(row + 1) * width])
              for row in range(height)]

    if hasNoData:
      # NaN never compares equal to itself, so a NaN nodata value needs an
      # explicit isnan() test to be recognized.
      if isinstance(noData, float) and math.isnan(noData):
        isNoData = math.isnan
      else:
        def isNoData(value):
          return value == noData

      values = [[None if isNoData(value) else value for value in row]
                for row in values]

    band = RasterBand(pixelTypeInt, noData, isAllNoData, values)

    return band, dataIndex

  #-----------------------------------------------------------------
  def readByte(self, data, dataIndex):
    '''Reads a single unsigned byte from the data.

    @param data The decoded (binary) WKB data.
    @param dataIndex The index of the byte to read.

    @return a tuple: (int, new dataIndex)
    '''

    # Don't need a byte order since we're reading a single byte.
    byteOrder = ''

    return self.readValueCode(byteOrder, 'B', data, dataIndex, numValues=1)

  #-----------------------------------------------------------------
  def readValue(self, valueType, byteOrder, data, dataIndex, numValues=1):
    '''Reads one or more values of a named WKB type from the data.

    @param valueType A type string such as 'uint16' (see VALUE_TYPE_CODES).
    @param byteOrder The struct byte order code.
    @param data The decoded (binary) WKB data.
    @param dataIndex The index of the first byte to read.
    @param numValues The number of values to read.  This will be one for
                     a scalar or another count for a tuple.

    @return a tuple: (value or tuple of values, new dataIndex)

    @raise ValueError if valueType is unknown or the data is truncated.
    '''

    if valueType not in self.VALUE_TYPE_CODES:
      raise ValueError('Invalid valueType: %s' % valueType)

    code = self.VALUE_TYPE_CODES[valueType]

    return self.readValueCode(byteOrder, code, data, dataIndex, numValues)

  #-----------------------------------------------------------------
  def readValueCode(self, byteOrder, code, data, dataIndex, numValues):
    '''Unpacks one or more values of a single struct type from the data.

    @param byteOrder The struct.unpack byte order prefix ('<', '>' or '').
    @param code The struct.unpack type format character, such as 'H'.
    @param data The decoded (binary) WKB data.
    @param dataIndex The index of the first byte to read.
    @param numValues The number of values to read.  When one, the value
                     is returned as a scalar; otherwise as a tuple (empty
                     when zero).

    @return a tuple: (value or tuple of values, new dataIndex)

    @raise ValueError if numValues is negative or the data is too short.
    '''

    if numValues < 0:
      raise ValueError('Invalid numValues: %s' % numValues)

    # A repeat count keeps the format short even for very large bands.
    # str() is needed on Python 2 to convert a unicode format to str.
    fmt = str('%s%d%s' % (byteOrder, numValues, code))

    endIndex = dataIndex + struct.calcsize(fmt)
    if endIndex > len(data):
      msg = ("WKB data truncated: need %d bytes at index %d but only %d "
             "remain." % (endIndex - dataIndex, dataIndex,
                          len(data) - dataIndex))
      raise ValueError(msg)

    value = struct.unpack_from(fmt, data, dataIndex)

    # If it's a single value, unpack it from the tuple.
    if numValues == 1:
      (value, ) = value # pylint: disable=self-assigning-variable

    return (value, endIndex)

  #-----------------------------------------------------------------
  def dump(self, data, dataIndex):
    '''Returns up to 20 bytes starting at dataIndex as a hex string.

    Intended as a debugging aid.

    @param data The decoded (binary) WKB data.
    @param dataIndex The index of the first byte to show.

    @return a lowercase hex string.
    '''

    return binascii.hexlify(data[dataIndex:dataIndex + 20]).decode('ascii')
