#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function

import argparse
import base64
import collections
import csv
import errno
import json
import numbers
import os
import re
import readline
import ssl
import sys
import time
import unicodedata
import urllib2

class DruidSqlException(Exception):
  """Error raised when a Druid SQL request fails; can print itself in red."""

  def friendly_message(self):
    """Return the text to show the user, or "Query failed" if there is none.

    The text is taken from ``self.args`` instead of the ``message`` attribute,
    which is deprecated since Python 2.6 and does not exist in Python 3. As
    with the old attribute, only an exception built with exactly one argument
    carries a message; anything else falls back to the generic text.
    """
    message = self.args[0] if len(self.args) == 1 else ''
    return message if message else "Query failed"

  def write_to(self, f):
    """Write the friendly message to file-like ``f`` wrapped in ANSI red, then flush."""
    f.write('\x1b[31m')  # switch the terminal to red
    f.write(self.friendly_message())
    f.write('\x1b[0m')   # reset terminal attributes
    f.write('\n')
    f.flush()

def do_query_with_args(url, sql, context, args):
  """Call do_query() with the connection settings held by parsed arguments.

  ``url``, ``sql`` and ``context`` are forwarded unchanged; the timeout, user
  and SSL options are read from ``args``. Returns whatever do_query() returns.
  """
  return do_query(
    url,
    sql,
    context,
    timeout=args.timeout,
    user=args.user,
    ignore_ssl_verification=args.ignore_ssl_verification,
    ca_file=args.cafile,
    ca_path=args.capath
  )

def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path):
  """Send a SQL query over HTTP and stream back the elements of the JSON response.

  This is a generator: the request is only issued once iteration starts, and
  the response array is decoded incrementally in 8 KB reads rather than being
  loaded into memory as a whole.

  Args:
    url: URL the query is POSTed to.
    sql: SQL text to execute.
    context: Dict of query context options. It is never modified; a copy is
      made when the timeout has to be adjusted.
    timeout: Client-side timeout in seconds; zero or negative disables it.
      When positive and the context's own 'timeout' (milliseconds) is shorter,
      the request is sent with 'timeout' raised to ``timeout * 1000``.
    user: "user:password" string for HTTP basic authentication, or a false value.
    ignore_ssl_verification: If true, skip certificate and hostname checks.
    ca_file: CA file handed to load_verify_locations(), or None.
    ca_path: CA directory handed to load_verify_locations(), or None.

  Yields:
    Each element of the top-level JSON array, in order. JSON objects are
    decoded as collections.OrderedDict so key order is preserved.

  Raises:
    DruidSqlException: via raise_friendly_error() when urllib2 raises URLError.
    ValueError: if the response cannot be parsed as JSON.
  """
  json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
  try:
    if timeout <= 0:
      timeout = None
      query_context = context
    elif int(context.get('timeout', 0)) // 1000 < timeout:
      query_context = context.copy()
      query_context['timeout'] = timeout * 1000
    else:
      # The context timeout is already at least as long as the client timeout.
      # Previously this case left query_context unbound and crashed.
      query_context = context

    sql_json = json.dumps({'query' : sql, 'context' : query_context})

    # Only build a custom SSL context when the caller asked for one.
    ssl_context = None
    if ignore_ssl_verification or ca_file is not None or ca_path is not None:
      ssl_context = ssl.create_default_context()
      if ignore_ssl_verification:
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
      else:
        ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)

    req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'})

    if user:
      req.add_header("Authorization", "Basic %s" % base64.b64encode(user))

    response = urllib2.urlopen(req, None, timeout, context=ssl_context)
    try:
      first_chunk = True
      eof = False
      buf = ''

      while not eof or len(buf) > 0:
        # Decode as many complete values as the buffer currently holds.
        while True:
          # Drop the ',' separating this value from the previous one.
          buf = buf.lstrip(',')
          try:
            obj, sz = json_decoder.raw_decode(buf)
          except ValueError:
            # Maybe invalid JSON, maybe a partial value; the decoder cannot tell us which.
            if eof and buf.rstrip() == ']':
              # Stream done and all values read.
              buf = ''
              break
            elif eof or len(buf) > 256 * 1024:
              # At eof, or after buffering more than 256KB, report the parse error.
              raise
            else:
              # Stop decoding and fetch more data from the stream instead.
              break
          yield obj
          buf = buf[sz:]

        # Read more from the http stream.
        if not eof:
          chunk = response.read(8192)
          if chunk:
            buf = buf + chunk
            if first_chunk:
              # Remove only the '[' that opens the top-level array, and only once.
              if buf.startswith('['):
                buf = buf[1:]
              first_chunk = False
          else:
            # Stream done. Keep decoding what is left in buf though.
            eof = True
    finally:
      # Runs on exhaustion, error, or when the generator is closed early.
      response.close()

  except urllib2.URLError as e:
    raise_friendly_error(e)

def raise_friendly_error(e):
  """Raise a DruidSqlException describing a urllib2 error in readable form.

  For an HTTPError the response body is read; if it is a JSON object and the
  status is 500 with an 'errorMessage' field, the message is assembled from
  the 'error', 'errorClass', 'errorMessage' and 'host' fields that are
  present. A 405 gets a hint about the broker URL and druid.sql.enabled.
  Any other HTTP status reports code, reason and body. Non-HTTP errors are
  reported via str(e).

  Args:
    e: The exception caught from urllib2.

  Raises:
    DruidSqlException: always.
  """
  if not isinstance(e, urllib2.HTTPError):
    raise DruidSqlException(str(e))

  text = e.read().strip()
  try:
    error_obj = json.loads(text)
  except ValueError:
    # Body is not JSON; fall through to the generic HTTP message.
    error_obj = {}
  if not isinstance(error_obj, dict):
    error_obj = {}

  if e.code == 500 and 'errorMessage' in error_obj:
    # Every field is converted with to_utf8() so that non-ASCII text cannot
    # trigger an implicit ASCII encode/decode while the parts are joined.
    parts = []
    error = error_obj.get('error')
    if error and error != 'Unknown exception':
      parts.append(to_utf8(error) + ': ')
    error_class = error_obj.get('errorClass')
    if error_class:
      parts.append(to_utf8(error_class) + ': ')
    parts.append(to_utf8(error_obj['errorMessage']))
    host = error_obj.get('host')
    if host:
      parts.append(' (' + to_utf8(host) + ')')
    raise DruidSqlException(''.join(parts))
  elif e.code == 405:
    error_text = 'HTTP Error {0}: {1}\n{2}'.format(e.code, e.reason + " - Are you using the correct broker URL and " +\
    "is druid.sql.enabled set to true on your broker?", text)
    raise DruidSqlException(error_text)
  else:
    raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))

def to_utf8(value):
  """Convert a value to a byte string, encoding unicode text as UTF-8.

  None becomes the empty string; any non-unicode value goes through str().
  """
  if value is None:
    return ""
  if isinstance(value, unicode):
    return value.encode("utf-8")
  return str(value)

def to_tsv(values, delimiter):
  """Join values into one delimited line.

  Each value is converted with to_utf8(), and any occurrence of the delimiter
  inside a value is removed so it cannot be mistaken for a field separator.
  """
  fields = []
  for item in values:
    fields.append(to_utf8(item).replace(delimiter, ''))
  return delimiter.join(fields)

def print_csv(rows, header):
  """Write rows to standard output as CSV.

  Args:
    rows: Iterable of mappings, one per result row.
    header: If true, the keys of the first row are written as a header line.
  """
  writer = csv.writer(sys.stdout)
  header_pending = header
  for row in rows:
    if header_pending:
      # Column names come from the first row; nothing is written for an empty result.
      writer.writerow([to_utf8(name) for name in row.keys()])
      header_pending = False

    writer.writerow([to_utf8(cell) for _, cell in row.iteritems()])

def print_tsv(rows, header, tsv_delimiter):
  """Print rows to standard output as delimiter-separated lines.

  Args:
    rows: Iterable of mappings, one per result row.
    header: If true, the keys of the first row are printed as a header line.
    tsv_delimiter: Separator placed between fields.
  """
  header_pending = header
  for row in rows:
    if header_pending:
      print(to_tsv(row.keys(), tsv_delimiter))
      header_pending = False

    cells = [cell for _, cell in row.iteritems()]
    print(to_tsv(cells, tsv_delimiter))

def print_json(rows):
  """Print every row as a JSON document on its own line."""
  for encoded in (json.dumps(row) for row in rows):
    print(encoded)

def table_to_printable_value(value):
  """Return the unicode text shown for a value in table output.

  None is shown as NULL. Everything else is converted with to_utf8(), trimmed
  of surrounding whitespace, decoded back to unicode, and stripped of the
  control characters with code points 0 through 31.
  """
  if value is None:
    return u"NULL"

  control_characters = dict.fromkeys(range(32))
  trimmed = to_utf8(value).strip()
  return trimmed.decode('utf-8').translate(control_characters)

def table_compute_string_width(v):
  """Return the number of terminal columns needed to display unicode string ``v``.

  The string is NFC-normalized first so that base characters and combining
  marks merge where a precomposed form exists. Each remaining character then
  counts as:
    0 columns for format controls (category Cf) and for combining marks
      (categories Mn and Me), which are drawn on top of the previous
      character -- counting them would misalign the table borders;
    2 columns for East Asian fullwidth (F) and wide (W) characters;
    1 column for everything else.
  """
  width = 0
  for c in unicodedata.normalize('NFC', v):
    if unicodedata.category(c) in ('Cf', 'Mn', 'Me'):
      # Zero-width: contributes nothing to the column count.
      continue
    if unicodedata.east_asian_width(c) in ('F', 'W'):
      # Double-wide character, prints in two columns.
      width += 2
    else:
      width += 1
  return width

def table_compute_column_widths(row_buffer):
  """Return the widest display width per column across all buffered rows.

  Each entry of ``row_buffer`` is a list of unicode cell values. Returns a
  list with one width per column, or None when ``row_buffer`` is empty.
  """
  widest = None
  for cells in row_buffer:
    cell_widths = [table_compute_string_width(cell) for cell in cells]
    if not widest:
      # First row seen so far: its widths are the starting point.
      widest = cell_widths
      continue
    for column, cell_width in enumerate(cell_widths):
      widest[column] = max(widest[column], cell_width)
  return widest

def table_print_row(values, column_widths, column_types):
  """Print one table row, padding each cell to its column width.

  Args:
    values: Unicode cell values for the row.
    column_widths: Display width of each column.
    column_types: Per-column type codes, or a false value. Columns marked
      'n' are right-aligned; all others are left-aligned.
  """
  bar = u'\u2502'.encode('utf-8')
  for index, value in enumerate(values):
    gap = ' ' * max(0, column_widths[index] - table_compute_string_width(value))
    right_aligned = column_types and column_types[index] == 'n'
    text = value.encode('utf-8')
    cell = gap + text if right_aligned else text + gap
    print(bar + ' ' + cell + ' ', end="")
  print(bar)

def table_print_header(values, column_widths):
  """Print the top border, the header row, and the separator beneath it.

  Args:
    values: Unicode column titles.
    column_widths: Display width of each column.
  """
  dash = u'\u2500'.encode('utf-8')
  # Each column is drawn two columns wider than its content, for the spaces
  # that table_print_row() puts around every cell.
  segments = [dash * max(0, width + 2) for width in column_widths]

  def rule(left, joint, right):
    # One horizontal border line: corner, segments joined by tees, corner.
    return left.encode('utf-8') + joint.encode('utf-8').join(segments) + right.encode('utf-8')

  # Line 1: top border
  print(rule(u'\u250C', u'\u252C', u'\u2510'))

  # Line 2: column titles
  table_print_row(values, column_widths, None)

  # Line 3: separator between titles and data
  print(rule(u'\u251C', u'\u253C', u'\u2524'))

def table_print_bottom(column_widths):
  """Print the bottom border of a table whose columns have the given widths."""
  dash = u'\u2500'.encode('utf-8')
  tee = u'\u2534'.encode('utf-8')
  # Two extra columns per segment match the spaces around each cell.
  segments = [dash * max(0, width + 2) for width in column_widths]
  print(u'\u2514'.encode('utf-8') + tee.join(segments) + u'\u2518'.encode('utf-8'))

def table_print_row_buffer(row_buffer, column_widths, column_types):
  """Print buffered rows, treating the first entry as the header row.

  The first entry goes through table_print_header(); every later entry is
  printed as a data row with table_print_row().
  """
  for position, cells in enumerate(row_buffer):
    if position == 0:
      table_print_header(cells, column_widths)
    else:
      table_print_row(cells, column_widths, column_types)

def print_table(rows):
  """Print rows as a boxed table followed by a row count and elapsed time.

  Up to 500 rows are buffered so column widths can be computed before
  anything is printed. Rows after that are printed as they arrive, using the
  widths derived from the buffered rows. Column alignment is decided from the
  first row: columns whose first value is a number are marked 'n', all
  others 's'.
  """
  started = time.time()
  row_count = 0

  # Buffer some rows before printing.
  buffer_slots = 500
  pending = []
  column_types = []
  column_widths = None

  for row in rows:
    row_count += 1

    if row_count == 1:
      # The first buffered entry is the header, built from the first row's keys.
      pending.append([table_to_printable_value(name) for name in row.keys()])
      for name in row.keys():
        column_types.append('n' if isinstance(row[name], numbers.Number) else 's')

    cells = [table_to_printable_value(cell) for _, cell in row.iteritems()]
    if buffer_slots > 0:
      pending.append(cells)
      buffer_slots -= 1
      continue

    # Buffer is full: flush it once, then stream this and later rows directly.
    if pending:
      column_widths = table_compute_column_widths(pending)
      table_print_row_buffer(pending, column_widths, column_types)
      del pending[:]
    table_print_row(cells, column_widths, column_types)

  # Result ended before the buffer was flushed.
  if pending:
    column_widths = table_compute_column_widths(pending)
    table_print_row_buffer(pending, column_widths, column_types)

  if column_widths:
    table_print_bottom(column_widths)

  plural = '' if row_count == 1 else 's'
  print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(row_count, plural, time.time() - started))
  print("")

def display_query(url, sql, context, args):
  """Run a query and print its rows in the format selected by ``args.format``.

  Recognized formats are 'csv', 'tsv', 'json' and 'table'; for any other
  value nothing is printed.
  """
  rows = do_query_with_args(url, sql, context, args)

  printers = {
    'csv': lambda: print_csv(rows, args.header),
    'tsv': lambda: print_tsv(rows, args.header, args.tsv_delimiter),
    'json': lambda: print_json(rows),
    'table': lambda: print_table(rows),
  }
  printer = printers.get(args.format)
  if printer is not None:
    printer()

def sql_literal_escape(s):
  """Return ``s`` as a SQL unicode-escaped string literal (U&'...').

  Letters, numbers and spaces are emitted as-is; every other character is
  written as a backslash escape so that quotes and backslashes in the input
  cannot terminate or alter the literal.

  The backslash escape form takes exactly four hexadecimal digits, so a code
  point above U+FFFF is written as its UTF-16 surrogate pair (two escapes).
  Formatting it with '%04x' directly would produce five or more digits, and
  the trailing digits would be read as literal text.
  NOTE(review): assumes the server-side parser combines surrogate-pair
  escapes into one character -- confirm against the SQL parser in use.

  Args:
    s: Value to escape. None yields an empty literal; non-string values are
      converted with str(), and byte strings are decoded as UTF-8.

  Returns:
    The literal as text.
  """
  if s is None:
    return "''"

  text_type = type(u'')
  if isinstance(s, text_type):
    ustr = s
  else:
    ustr = str(s)
    if not isinstance(ustr, text_type):
      ustr = ustr.decode('utf-8')

  escaped = [u"U&'"]

  for c in ustr:
    ccategory = unicodedata.category(c)
    if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
      escaped.append(c)
      continue

    code_point = ord(c)
    if code_point > 0xFFFF:
      # Split into a UTF-16 high/low surrogate pair, each four hex digits.
      offset = code_point - 0x10000
      code_units = (0xD800 + (offset >> 10), 0xDC00 + (offset & 0x3FF))
    else:
      code_units = (code_point,)
    for code_unit in code_units:
      escaped.append(u'\\%04x' % code_unit)

  escaped.append(u"'")
  return u''.join(escaped)

def make_readline_completer(url, context, args):
  """Build the readline completer for SQL keywords and announce the connection.

  As a side effect this prints "Connected to [<host>]." followed by a blank
  line. ``url`` and ``context`` are accepted but not used.

  Args:
    url: Unused.
    context: Unused.
    args: Parsed arguments; only ``args.host`` is read, for the banner.

  Returns:
    A ``completer(text, state)`` function suitable for readline.set_completer().
  """
  # Keywords offered when completing at the very start of the line.
  starters = [
    'EXPLAIN PLAN FOR',
    'SELECT'
  ]

  # Keywords offered anywhere else in the line.
  middlers = [
    'FROM',
    'WHERE',
    'GROUP BY',
    'ORDER BY',
    'LIMIT'
  ]

  def readline_completer(text, state):
    """Return the ``state``-th keyword matching ``text`` plus a space, or None."""
    candidates = starters if readline.get_begidx() == 0 else middlers
    prefix = text.upper()
    matches = [keyword for keyword in candidates if keyword.startswith(prefix)]
    if state < len(matches):
      return matches[state] + " "
    # Tell readline explicitly that there are no more matches. The previous
    # code evaluated None + " " here and relied on readline discarding the
    # resulting TypeError.
    return None

  print("Connected to [" + args.host + "].")
  print("")

  return readline_completer

def main():
  """Parse command-line options, then run one query or an interactive prompt.

  With --execute the given SQL is run once and its result displayed.
  Otherwise an interactive loop reads input with readline support, keeping
  history in ~/.dsql_history. Input starting with a backslash on an empty
  statement is treated as a command (\\d, \\dS, \\h, \\q); anything else is
  accumulated until it ends with a semicolon and is then executed.

  Raises:
    ValueError: if a --context-option is not of the form key=value.
    SystemExit: on \\q (status 0) or end of input (status 1).
  """
  parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
  parser_cnn = parser.add_argument_group('Connection options')
  parser_fmt = parser.add_argument_group('Formatting options')
  parser_oth = parser.add_argument_group('Other options')
  parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Druid query host or url, like https://localhost:8282/')
  parser_cnn.add_argument('--user', '-u', type=str, help='HTTP basic authentication credentials, like user:password')
  parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds')
  parser_cnn.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser_cnn.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
  parser_fmt.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
  parser_fmt.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
  parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
  parser_oth.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection, see https://docs.imply.io/on-prem/query-data/sql for options')
  parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
  args = parser.parse_args()

  # Build broker URL: append the SQL endpoint path and default to http://
  # when the host was given without a scheme.
  url = args.host.rstrip('/') + '/druid/v2/sql/'
  if not url.startswith('http:') and not url.startswith('https:'):
    url = 'http://' + url

  # Build context from repeated key=value options; values made only of
  # digits are stored as numbers, everything else as strings.
  context = {}
  if args.context_option:
    for opt in args.context_option:
      kv = opt.split("=", 1)
      if len(kv) != 2:
        raise ValueError('Invalid context option, should be key=value: ' + opt)
      if re.match(r"^\d+$", kv[1]):
        context[kv[0]] = long(kv[1])
      else:
        context[kv[0]] = kv[1]

  if args.execute:
    # One-shot mode: run the given statement and return.
    display_query(url, args.execute, context, args)
  else:
    # interactive mode
    print("Welcome to dsql, the command-line client for Druid SQL.")

    readline_history_file = os.path.expanduser("~/.dsql_history")
    readline.parse_and_bind('tab: complete')
    readline.set_history_length(500)
    readline.set_completer(make_readline_completer(url, context, args))

    try:
      readline.read_history_file(readline_history_file)
    except IOError:
      # IOError can happen if the file doesn't exist.
      pass

    print("Type \"\\h\" for help.")

    # Outer loop: one iteration per executed statement.
    while True:
      sql = ''
      # Inner loop: keep reading lines until the statement ends with ';'
      # or a backslash command produces a statement and breaks out.
      while not sql.endswith(';'):
        prompt = "dsql> " if sql == '' else 'more> '
        try:
          more_sql = raw_input(prompt)
        except EOFError:
          # End of input (e.g. Ctrl-D): finish the prompt line and exit.
          sys.stdout.write('\n')
          sys.exit(1)
        if sql == '' and more_sql.startswith('\\'):
          # backslash command
          # \d with optional 'S' (include system schemas), optional '+'
          # and an optional table name.
          dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
          if dmatch:
            include_system = dmatch.group(1)
            # NOTE(review): extra_info is captured but not used below.
            extra_info = dmatch.group(2)
            arg = dmatch.group(3).strip()
            if arg:
              # Describe one table; the name is escaped as a SQL literal.
              sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg)
              if not include_system:
                sql = sql + " AND TABLE_SCHEMA = 'druid'"
              # break to execute sql
              break
            else:
              # List tables.
              sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES";
              if not include_system:
                sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
              # break to execute sql
              break

          hmatch = re.match(r'^\\h\s*$', more_sql)
          if hmatch:
            print("Commands:")
            print("  \\d             show tables")
            print("  \\dS            show tables, including system tables")
            print("  \\d table_name  describe table")
            print("  \\h             show this help")
            print("  \\q             exit this program")
            print("Or enter a SQL query ending with a semicolon (;).")
            continue

          qmatch = re.match(r'^\\q\s*$', more_sql)
          if qmatch:
            sys.exit(0)

          # Unknown command: sql is still empty, so the prompt repeats.
          print("No such command: " + more_sql)
        else:
          # Regular input: append to the statement being built.
          sql = (sql + ' ' + more_sql).strip()

      try:
        # History is saved before running, so it is kept even if the query fails.
        readline.write_history_file(readline_history_file)
        display_query(url, sql.rstrip(';'), context, args)
      except DruidSqlException as e:
        # Report the failure and go back to the prompt.
        e.write_to(sys.stdout)
      except KeyboardInterrupt:
        # Ctrl-C while a query runs cancels it without leaving the prompt.
        sys.stdout.write("Query interrupted\n")
        sys.stdout.flush()

# Script entry point: run main() and turn expected failures into exit status 1.
try:
  main()
except DruidSqlException as e:
  # Query errors are printed on stderr in friendly form.
  e.write_to(sys.stderr)
  sys.exit(1)
except KeyboardInterrupt:
  # Interrupted outside of a running query: exit without a traceback.
  sys.exit(1)
except IOError as e:
  # A closed output pipe (EPIPE) is not worth a traceback; any other
  # I/O error is re-raised unchanged.
  if e.errno != errno.EPIPE:
    raise
  sys.exit(1)