#-------------------------------------------------------------------
#  ConnectionUrl.py
#
#  The ConnectionUrl module.
#
#  Copyright 2017 Applied Invention, LLC
#-------------------------------------------------------------------

'''Functions to parse and manipulate a DB connection string.
'''

#-------------------------------------------------------------------
# Import statements go here.
#
from collections import OrderedDict
from typing import Dict
from typing import List
import urllib.parse
#
# Import statements go above this line.
#-------------------------------------------------------------------


#-------------------------------------------------------------------
def unittestDbName(url: str, remoteUnitTests: bool) -> str:
  '''Returns a copy of the passed-in URL with the DB name set to test.

  The DB name is expected to be foowebapp.

  The returned DB name will be footest.

  @param url A connection URL string.
  @param remoteUnitTests Whether running unit tests on remote machines is
                         allowed.  Only used for standard SQL Alchemy URLs;
                         the MS-SQL passthrough form does not receive it.

  @return A connection URL string.
  '''

  # SQL Alchemy URLs come in two shapes.  The common one is:
  #
  #   dialect+driver://username:password@host:port/database
  #
  # The other is the MS-SQL 'passthrough' form, where the DB name lives
  # inside an ODBC connect string rather than in the URL path:
  #
  #   mssql+pyodbc:///?odbc_connect=CONNECT_STRING

  isPassthrough = url.startswith('mssql') and 'odbc_connect=' in url

  if not isPassthrough:
    return unittestDbNameAlchemy(url, remoteUnitTests)

  return unittestDbNameMssql(url)

#-------------------------------------------------------------------
def unittestDbNameAlchemy(url: str, remoteUnitTests: bool) -> str:
  '''Returns a copy of the passed-in URL with the DB name set to test.

  The DB name is expected to be foowebapp.

  The returned DB name will be footest.  A DB name that does not end in
  'webapp' simply has 'test' appended.  Any query string after the DB
  name (e.g. '?sslmode=require') is kept unchanged.

  @param url A SQL Alchemy connection URL string of the form
             dialect+driver://[username[:password]@]host[:port]/database
  @param remoteUnitTests Whether running unit tests on remote machines is
                         allowed.  If False, the host is replaced with
                         localhost:5432.

  @return A connection URL string.

  @raise ValueError If the URL has no '://' separator or no '/database'
                    part.
  '''

  # Note the URL should be a PostgreSQL URL along the lines of:
  #
  # postgresql+psycopg2://ramwebapp:ramwebapp@localhost:5432/ramtest

  # Break up the URL into components.  The error messages deliberately
  # do not echo the URL, since it may contain a password.

  scheme, sep, rest = url.partition('://')
  if not sep:
    raise ValueError('Connection URL is missing the "://" separator.')

  netloc, sep, dbPart = rest.partition('/')
  if not sep:
    raise ValueError('Connection URL is missing the "/database" part.')

  # Keep the query string out of the DB name so the 'webapp' suffix
  # check below sees only the name itself.
  dbName, querySep, query = dbPart.partition('?')

  # Credentials are optional.  Split on the last '@' so an unencoded '@'
  # in the password does not break parsing.  With no '@', userInfo and
  # atSep are both '' and the whole netloc is the host.
  userInfo, atSep, host = netloc.rpartition('@')

  if not remoteUnitTests:

    # Unit tests are always run on localhost.
    host = 'localhost:5432'

  # The name of the test database is the name of the webapp database
  # with 'webapp' replaced with 'test'.

  if dbName.endswith('webapp'):
    dbName = dbName[:-6] + 'test'
  else:
    dbName += 'test'

  # Put the URL back together.

  return (scheme + '://' + userInfo + atSep + host + '/' + dbName +
          querySep + query)

#-------------------------------------------------------------------
def unittestDbNameMssql(url: str) -> str:
  '''Returns a copy of the passed-in URL with the DB name set to test.

  The DB name is expected to be foowebapp.

  The returned DB name will be footest.  A DB name that does not end in
  'webapp' simply has 'test' appended.  A braced DB name such as
  {foowebapp} keeps its braces.  If the connect string has no Database
  item, the items are only re-encoded via formatMssqlConnect().

  @param url A connection URL string containing 'odbc_connect='.

  @return A connection URL string.
  '''

  # Note the URL should be a MS SQL Server URL along the lines of:
  #
  # mssql+pyodbc:///?odbc_connect=CONNECT_STRING
  #
  # where CONNECT_STRING is a list of KEY=VALUE pairs joined with semi-colons
  # such as:
  #
  # Database=seamonkeywebapp;Uid=mbrady@seamonkeydev;

  prefix, encodedConnectStr = url.split('odbc_connect=', 1)

  connectItems = parseMssqlConnect(encodedConnectStr)

  # ODBC keywords are case-insensitive (SQL Alchemy's own examples write
  # 'DATABASE='), so match the key regardless of case and keep the
  # caller's spelling of it.
  dbKey = next((key for key in connectItems
                if key.strip().lower() == 'database'), None)

  if dbKey is not None:

    dbName = connectItems[dbKey]

    # ODBC allows a value to be wrapped in braces; edit the name inside.
    braced = (len(dbName) >= 2 and dbName.startswith('{') and
              dbName.endswith('}'))
    if braced:
      dbName = dbName[1:-1]

    # The name of the test database is the name of the webapp database
    # with 'webapp' replaced with 'test'.

    if dbName.endswith('webapp'):
      dbName = dbName[:-6] + 'test'
    else:
      dbName += 'test'

    if braced:
      dbName = '{' + dbName + '}'

    connectItems[dbKey] = dbName

  return prefix + 'odbc_connect=' + formatMssqlConnect(connectItems)

#-------------------------------------------------------------------
def parseMssqlConnect(connectStr: str) -> Dict[str, str]:
  '''Parse a URL-encoded ODBC connect string into an ordered dictionary.

  The connect string is a URL-encoded list of KEY=VALUE pairs joined
  with semi-colons such as:

    Database=seamonkeywebapp;Uid=mbrady@seamonkeydev;

  Empty items (e.g. from a trailing ';' or ';;') are ignored.  A value
  enclosed in braces, e.g. Pwd={a;b}, may contain ';' and is stored with
  its braces so formatMssqlConnect() writes it back unchanged.

  @param connectStr The URL-encoded connect string (the value of the
                    odbc_connect query parameter).

  @return An ordered dictionary mapping each KEY to its VALUE, in the
          order they appear.

  @raise ValueError If a non-empty item contains no '='.
  '''

  # The connect string is normally encoded with urllib.parse.quote_plus()
  # (as formatMssqlConnect() does), which turns spaces into '+'.  Decoding
  # with unquote_plus() undoes that, so e.g. a
  # 'Driver={ODBC Driver 17 for SQL Server}' item survives a
  # parse/format round trip instead of becoming 'ODBC+Driver+17...'.
  decodedStr = urllib.parse.unquote_plus(connectStr)

  connectItemDict: Dict[str, str] = OrderedDict()

  for itemStr in _splitMssqlItems(decodedStr):

    if not itemStr.strip():
      continue

    key, sep, value = itemStr.partition('=')
    if not sep:
      # Don't echo the item: it could be part of a credential.
      raise ValueError('Malformed ODBC connect string item: missing "=".')

    connectItemDict[key] = value

  return connectItemDict

#-------------------------------------------------------------------
def _splitMssqlItems(connectStr: str) -> List[str]:
  '''Splits a decoded ODBC connect string into its KEY=VALUE items.

  Items are separated by ';', except that a value starting with '{'
  runs to the matching '}' and may contain ';'.  Inside braces, '}}'
  is an escaped '}'.  Items are returned verbatim, braces included.
  '''

  items: List[str] = []
  current: List[str] = []
  valueStart = -1    # Length of 'current' at the start of the value.
  inBraces = False
  i = 0

  while i < len(connectStr):
    ch = connectStr[i]

    if inBraces:
      current.append(ch)
      if ch == '}':
        if connectStr[i + 1:i + 2] == '}':
          # '}}' is an escaped '}' and does not close the braced value.
          current.append('}')
          i += 1
        else:
          inBraces = False

    elif ch == ';':
      items.append(''.join(current))
      current = []
      valueStart = -1

    else:
      if ch == '{' and len(current) == valueStart:
        # A '{' right after the first '=' opens a braced value.
        inBraces = True
      current.append(ch)
      if ch == '=' and valueStart < 0:
        valueStart = len(current)

    i += 1

  items.append(''.join(current))
  return items

#-------------------------------------------------------------------
def formatMssqlConnect(connectDict: Dict[str, str]) -> str:
  '''Formats the specified dictionary into a URL-encoded connect string.

  Each entry becomes a KEY=VALUE item.  The items are joined with ';',
  and a trailing ';' is always appended, so an empty dictionary yields
  just ';'.  The result is encoded with urllib.parse.quote_plus(), e.g.

    {'Database': 'footest', 'Uid': 'mbrady@seamonkeydev'}
      -> Database%3Dfootest%3BUid%3Dmbrady%40seamonkeydev%3B

  @param connectDict Mapping of connect-string keys to values, emitted in
                     iteration order.

  @return The URL-encoded connect string.
  '''

  itemsText = ';'.join('%s=%s' % (key, value)
                       for key, value in connectDict.items())

  return urllib.parse.quote_plus(itemsText + ';')
