druid/examples/bin/dsql-main

508 lines
16 KiB
Plaintext
Raw Normal View History

#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import argparse
import base64
import collections
import csv
import errno
import json
import numbers
import os
import re
import readline
import ssl
import sys
import time
import unicodedata
import urllib2
class DruidSqlException(Exception):
    """Error raised when a Druid SQL query cannot be completed.

    The exception is constructed with a single human-readable message, which
    is what gets shown to the user.
    """

    def friendly_message(self):
        """Return the message to show the user.

        Returns the sole constructor argument when exactly one truthy argument
        was given, otherwise the generic text "Query failed".
        """
        # BaseException.message is deprecated (and absent on Python 3); it was
        # defined as args[0] when exactly one argument was passed, so read
        # args directly to get the same value without the deprecated attribute.
        message = self.args[0] if len(self.args) == 1 else None
        return message if message else "Query failed"

    def write_to(self, f):
        """Write the friendly message to file object `f` in red, then flush.

        The message is wrapped in ANSI escape sequences (red on, reset) and
        followed by a newline.
        """
        f.write('\x1b[31m')
        f.write(self.friendly_message())
        f.write('\x1b[0m')
        f.write('\n')
        f.flush()
def do_query_with_args(url, sql, context, args):
    """Run `sql` against `url`, taking connection settings from `args`.

    `args` must expose `timeout`, `user`, `ignore_ssl_verification`, `cafile`
    and `capath`. Returns whatever `do_query` returns.
    """
    return do_query(
        url,
        sql,
        context,
        timeout=args.timeout,
        user=args.user,
        ignore_ssl_verification=args.ignore_ssl_verification,
        ca_file=args.cafile,
        ca_path=args.capath,
    )
def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path):
    """POST a SQL query and lazily yield result rows as they stream in.

    This is a generator: nothing is sent until the first row is requested.

    Parameters:
      url: full URL the query JSON is POSTed to.
      sql: SQL text to run.
      context: dict of query context options; it is not modified.
      timeout: HTTP timeout in seconds; a value <= 0 means no timeout.
      user: value Base64-encoded into a Basic Authorization header, or a
        falsy value for no authentication header.
      ignore_ssl_verification: when true, certificate and hostname checks are
        disabled (and ca_file / ca_path are not used).
      ca_file, ca_path: passed to SSLContext.load_verify_locations().

    Yields:
      One OrderedDict per decoded JSON object, keys in response order.

    Raises:
      DruidSqlException: via raise_friendly_error() on urllib2.URLError.
      ValueError: when the response body cannot be parsed as JSON objects.
    """
    json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
    response = None
    try:
        if timeout <= 0:
            timeout = None
            query_context = context
        elif int(context.get('timeout', 0)) // 1000 < timeout:
            # The query-context timeout (milliseconds) is shorter than the
            # HTTP timeout (seconds): raise it on a copy of the context.
            query_context = context.copy()
            query_context['timeout'] = timeout * 1000
        else:
            # The context already carries a timeout at least as long as the
            # HTTP one; use it unchanged. Without this branch query_context
            # was left unassigned and the next line failed.
            query_context = context

        sql_json = json.dumps({'query': sql, 'context': query_context})

        # SSL stuff
        ssl_context = None
        if ignore_ssl_verification or ca_file is not None or ca_path is not None:
            ssl_context = ssl.create_default_context()
            if ignore_ssl_verification:
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            else:
                ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)

        req = urllib2.Request(url, sql_json, {'Content-Type': 'application/json'})
        if user:
            req.add_header("Authorization", "Basic %s" % base64.b64encode(user))

        response = urllib2.urlopen(req, None, timeout, context=ssl_context)

        first_chunk = True
        eof = False
        buf = ''

        while not eof or len(buf) > 0:
            # Decode as many complete objects as the buffer currently holds.
            while True:
                try:
                    # Remove starting ','
                    buf = buf.lstrip(',')
                    obj, sz = json_decoder.raw_decode(buf)
                    yield obj
                    buf = buf[sz:]
                except ValueError:
                    # Maybe invalid JSON, maybe partial object; it's hard to
                    # tell with this library.
                    if eof and buf.rstrip() == ']':
                        # Stream done and all objects read.
                        buf = ''
                        break
                    elif eof or len(buf) > 256 * 1024:
                        # If we read more than 256KB or if it's eof then
                        # report the parse error.
                        raise
                    else:
                        # Stop reading objects, get more from the stream
                        # instead.
                        break

            # Read more from the http stream
            if not eof:
                chunk = response.read(8192)
                if chunk:
                    buf = buf + chunk
                    if first_chunk:
                        # Remove the '[' that opens the result array. Only
                        # the very first chunk can carry it, so do this once.
                        buf = buf.lstrip('[')
                        first_chunk = False
                else:
                    # Stream done. Keep reading objects out of buf though.
                    eof = True
    except urllib2.URLError as e:
        raise_friendly_error(e)
    finally:
        # Release the connection even if the consumer stops iterating early
        # or an error interrupts the stream.
        if response is not None:
            response.close()
def raise_friendly_error(e):
    """Convert a urllib2 error into a DruidSqlException; always raises.

    For an HTTPError with status 500 whose body is a JSON object containing
    'errorMessage', the message is assembled from the optional 'error',
    'errorClass' and 'host' fields plus 'errorMessage'. Any other HTTPError
    produces "HTTP Error <code>: <reason>" followed by the raw body. Any other
    exception is reported via str(e).

    Raises:
      DruidSqlException: always.
    """
    if not isinstance(e, urllib2.HTTPError):
        raise DruidSqlException(str(e))

    text = e.read().strip()
    error_obj = {}
    try:
        error_obj = dict(json.loads(text))
    except (TypeError, ValueError):
        # Body is not a JSON object; fall back to the generic HTTP message.
        pass

    if e.code == 500 and 'errorMessage' in error_obj:
        # Only 'errorMessage' is known to be present, so the other fields are
        # read with .get(). to_utf8() is used instead of str() so non-ASCII
        # text from the server cannot raise UnicodeEncodeError.
        parts = []
        error = error_obj.get('error')
        if error and error != 'Unknown exception':
            parts.append(to_utf8(error) + ': ')
        error_class = error_obj.get('errorClass')
        if error_class:
            parts.append(to_utf8(error_class) + ': ')
        parts.append(to_utf8(error_obj['errorMessage']))
        host = error_obj.get('host')
        if host:
            parts.append(' (' + to_utf8(host) + ')')
        raise DruidSqlException(''.join(parts))

    raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
def to_utf8(value):
    """Return `value` as a UTF-8 encoded byte string.

    None becomes the empty string, unicode objects are encoded as UTF-8, and
    anything else is converted with str().
    """
    if value is None:
        return ""
    if isinstance(value, unicode):
        return value.encode("utf-8")
    return str(value)
def to_tsv(values, delimiter):
    """Join `values` into one delimited line.

    Each value is converted with to_utf8() and any occurrence of `delimiter`
    inside it is removed before joining.
    """
    cells = []
    for item in values:
        text = to_utf8(item)
        cells.append(text.replace(delimiter, ''))
    return delimiter.join(cells)
def print_csv(rows, header):
    """Write `rows` to stdout as CSV.

    `rows` is an iterable of ordered mappings. When `header` is true, the keys
    of the first row are written as a header line before the first data line.
    """
    writer = csv.writer(sys.stdout)
    header_pending = header
    for row in rows:
        if header_pending:
            writer.writerow([to_utf8(name) for name in row.keys()])
            header_pending = False
        writer.writerow([to_utf8(cell) for cell in row.values()])
def print_tsv(rows, header, tsv_delimiter):
    """Print `rows` as delimiter-separated lines.

    `rows` is an iterable of ordered mappings. When `header` is true, the keys
    of the first row are printed as a header line before the first data line.
    Formatting of each line is delegated to to_tsv().
    """
    header_pending = header
    for row in rows:
        if header_pending:
            print(to_tsv(row.keys(), tsv_delimiter))
            header_pending = False
        print(to_tsv(list(row.values()), tsv_delimiter))
def print_json(rows):
    """Print each row as one JSON document per line.

    Rows are consumed one at a time, so a streaming iterable is printed as it
    arrives.
    """
    for record in rows:
        encoded = json.dumps(record)
        print(encoded)
def table_to_printable_value(value):
    """Return `value` as a unicode string suitable for a table cell.

    None is shown as NULL. Other values are converted with to_utf8(), trimmed
    of surrounding whitespace, decoded from UTF-8, and stripped of the control
    characters with code points 0 through 31.
    """
    if value is None:
        return u"NULL"
    trimmed = to_utf8(value).strip()
    text = trimmed.decode('utf-8')
    control_chars = dict.fromkeys(range(32))
    return text.translate(control_chars)
def table_compute_string_width(v):
    """Return the number of terminal columns the unicode string `v` occupies.

    The string is NFC-normalized first. Format characters (category Cf) and
    combining marks (categories Mn and Me) take no column, East Asian
    full-width and wide characters take two, and everything else takes one.
    """
    width = 0
    for c in unicodedata.normalize('NFC', v):
        if unicodedata.category(c) in ('Cf', 'Mn', 'Me'):
            # Formatting controls and combining marks that NFC could not fold
            # into their base character are drawn on top of it: zero width.
            continue
        if unicodedata.east_asian_width(c) in ('F', 'W'):
            # Double-wide character, prints in two columns
            width += 2
        else:
            # All other characters
            width += 1
    return width
def table_compute_column_widths(row_buffer):
    """Return the widest display width seen in each column of `row_buffer`.

    `row_buffer` is an iterable of rows, each a sequence of unicode strings.
    Returns a list of per-column widths, or None when there are no rows.
    """
    widths = None
    for values in row_buffer:
        row_widths = [table_compute_string_width(cell) for cell in values]
        if not widths:
            # First row seen (or nothing recorded yet): start from its widths.
            widths = row_widths
            continue
        for index, candidate in enumerate(row_widths):
            if candidate > widths[index]:
                widths[index] = candidate
    return widths
def table_print_row(values, column_widths, column_types):
    """Print one table row framed by vertical box-drawing lines.

    Each cell is padded to its entry in `column_widths`. Cells whose entry in
    `column_types` is 'n' are right-aligned; all others are left-aligned.
    `column_types` may be None, in which case every cell is left-aligned.
    """
    vertical_line = u'\u2502'.encode('utf-8')
    for index, value in enumerate(values):
        gap = column_widths[index] - table_compute_string_width(value)
        padding = ' ' * max(0, gap)
        text = value.encode('utf-8')
        if column_types and column_types[index] == 'n':
            cell = padding + text
        else:
            cell = text + padding
        print(vertical_line + ' ' + cell + ' ', end="")
    print(vertical_line)
def table_print_header(values, column_widths):
    """Print the top of a table: top border, header row, separator line.

    `values` are the column titles and `column_widths` the display width of
    each column; every border segment is two columns wider than its column to
    cover the single space on each side of a cell.
    """
    horizontal_line = u'\u2500'.encode('utf-8')
    segments = [horizontal_line * max(0, width + 2) for width in column_widths]

    def print_rule(left, joint, right):
        # One horizontal border: corner, segments separated by joints, corner.
        print(left.encode('utf-8') + joint.encode('utf-8').join(segments) + right.encode('utf-8'))

    # Line 1
    print_rule(u'\u250C', u'\u252C', u'\u2510')

    # Line 2
    table_print_row(values, column_widths, None)

    # Line 3
    print_rule(u'\u251C', u'\u253C', u'\u2524')
def table_print_bottom(column_widths):
    """Print the bottom border of a table.

    Each border segment is two columns wider than its entry in
    `column_widths`, and segments are joined by a bottom tee.
    """
    horizontal_line = u'\u2500'.encode('utf-8')
    bottom_tee = u'\u2534'.encode('utf-8')
    left_corner = u'\u2514'.encode('utf-8')
    right_corner = u'\u2518'.encode('utf-8')
    segments = [horizontal_line * max(0, width + 2) for width in column_widths]
    print(left_corner + bottom_tee.join(segments) + right_corner)
def table_print_row_buffer(row_buffer, column_widths, column_types):
    """Print buffered rows, treating the first entry as the header row.

    The first element of `row_buffer` goes to table_print_header(); every
    later element goes to table_print_row().
    """
    for position, values in enumerate(row_buffer):
        if position == 0:
            table_print_header(values, column_widths)
        else:
            table_print_row(values, column_widths, column_types)
def print_table(rows):
    """Print `rows` as a box-drawn table followed by a row-count summary.

    `rows` is an iterable of ordered mappings. The header plus up to 500 rows
    are buffered so column widths can be computed from them; later rows are
    printed as they arrive using those widths. Column alignment is decided
    from the first row: columns holding a number there are marked numeric.
    """
    started_at = time.time()
    row_count = 0

    # Buffer some rows before printing.
    buffer_slots = 500
    pending = []
    column_types = []
    column_widths = None

    def flush_pending():
        # Size the columns from what is buffered, print it, return the widths.
        widths = table_compute_column_widths(pending)
        table_print_row_buffer(pending, widths, column_types)
        return widths

    for row in rows:
        row_count += 1

        if row_count == 1:
            # The first row supplies the header titles and the column types.
            pending.append([table_to_printable_value(name) for name in row.keys()])
            column_types.extend(
                'n' if isinstance(row[name], numbers.Number) else 's'
                for name in row.keys()
            )

        printable = [table_to_printable_value(cell) for cell in row.values()]

        if buffer_slots > 0:
            pending.append(printable)
            buffer_slots -= 1
            continue

        if pending:
            column_widths = flush_pending()
            del pending[:]
        table_print_row(printable, column_widths, column_types)

    if pending:
        column_widths = flush_pending()

    if column_widths:
        table_print_bottom(column_widths)

    print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(row_count, 's' if row_count != 1 else '', time.time() - started_at))
    print("")
def display_query(url, sql, context, args):
    """Run `sql` and print the result in the format named by `args.format`.

    Recognized formats are 'csv', 'tsv', 'json' and 'table'; for any other
    value the query result is not printed.
    """
    rows = do_query_with_args(url, sql, context, args)

    # Each printer is wrapped so its arguments are read only if it is chosen.
    printers = {
        'csv': lambda: print_csv(rows, args.header),
        'tsv': lambda: print_tsv(rows, args.header, args.tsv_delimiter),
        'json': lambda: print_json(rows),
        'table': lambda: print_table(rows),
    }
    printer = printers.get(args.format)
    if printer is not None:
        printer()
def sql_literal_escape(s):
    """Return `s` as a SQL Unicode string literal of the form U&'...'.

    Letters, digits (Unicode categories L* and N*) and spaces are kept as-is.
    Every other character is written as a backslash followed by four hex
    digits. None yields the empty literal ''.
    """
    if s is None:
        return "''"
    # type(u'') is the unicode text type (`unicode` on Python 2).
    if isinstance(s, type(u'')):
        ustr = s
    else:
        ustr = str(s).decode('utf-8')

    escaped = [u"U&'"]
    for c in ustr:
        ccategory = unicodedata.category(c)
        if ccategory.startswith(('L', 'N')) or c == ' ':
            escaped.append(c)
            continue
        code_point = ord(c)
        if code_point > 0xFFFF:
            # The \XXXX escape holds exactly four hex digits, so a character
            # outside the Basic Multilingual Plane is written as its UTF-16
            # surrogate pair instead of one over-long escape.
            offset = code_point - 0x10000
            code_units = (0xD800 + (offset >> 10), 0xDC00 + (offset & 0x3FF))
        else:
            code_units = (code_point,)
        for unit in code_units:
            escaped.append(u'\\%04x' % unit)
    escaped.append(u"'")
    return u''.join(escaped)
def make_readline_completer(url, context, args):
    """Print the connection banner and return a readline completer function.

    The completer offers SQL keywords: statement starters when completion
    begins at the start of the line, clause keywords otherwise. Matching is
    case-insensitive on the typed prefix, and each completion is returned with
    a trailing space.

    `url` and `context` are accepted for interface compatibility and are not
    used; `args.host` is shown in the banner.
    """
    starters = [
        'EXPLAIN PLAN FOR',
        'SELECT'
    ]

    middlers = [
        'FROM',
        'WHERE',
        'GROUP BY',
        'ORDER BY',
        'LIMIT'
    ]

    def readline_completer(text, state):
        candidates = starters if readline.get_begidx() == 0 else middlers
        prefix = text.upper()
        matches = [word for word in candidates if word.startswith(prefix)]
        if state < len(matches):
            return matches[state] + " "
        # readline asks for state 0, 1, 2, ... until it gets None back.
        # Return None explicitly rather than failing on `None + " "`.
        return None

    print("Connected to [" + args.host + "].")
    print("")

    return readline_completer
def main():
    """Entry point of the dsql command-line client.

    Parses command-line options, builds the SQL endpoint URL and the query
    context, then either runs the single query given with --execute or starts
    an interactive prompt. In interactive mode a statement is read until a
    line ends with ';', and backslash commands (\\d, \\dS, \\h, \\q) are
    available at the start of a statement.

    Raises:
      ValueError: when a --context-option value is not of the form key=value.
      DruidSqlException: when the --execute query fails (interactive query
        failures are printed and the prompt continues).
    """
    parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
    parser_cnn = parser.add_argument_group('Connection options')
    parser_fmt = parser.add_argument_group('Formatting options')
    parser_oth = parser.add_argument_group('Other options')
    parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Druid query host or url, like https://localhost:8282/')
    parser_cnn.add_argument('--user', '-u', type=str, help='HTTP basic authentication credentials, like user:password')
    parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds')
    parser_cnn.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
    parser_fmt.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
    parser_fmt.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
    parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
    parser_oth.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection, see https://docs.imply.io/on-prem/query-data/sql for options')
    parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
    args = parser.parse_args()

    # Build broker URL
    url = args.host.rstrip('/') + '/druid/v2/sql/'
    if not url.startswith(('http:', 'https:')):
        url = 'http://' + url

    # Build context; all-digit values are stored as integers.
    context = {}
    for opt in args.context_option or []:
        kv = opt.split("=", 1)
        if len(kv) != 2:
            raise ValueError('Invalid context option, should be key=value: ' + opt)
        if re.match(r"^\d+$", kv[1]):
            context[kv[0]] = long(kv[1])
        else:
            context[kv[0]] = kv[1]

    if args.execute:
        # Drop a trailing ';' the same way interactive mode does, so a
        # statement copied from the prompt can be passed to --execute as-is.
        display_query(url, args.execute.strip().rstrip(';'), context, args)
        return

    # interactive mode
    print("Welcome to dsql, the command-line client for Druid SQL.")

    readline_history_file = os.path.expanduser("~/.dsql_history")
    readline.parse_and_bind('tab: complete')
    readline.set_history_length(500)
    readline.set_completer(make_readline_completer(url, context, args))

    try:
        readline.read_history_file(readline_history_file)
    except IOError:
        # IOError can happen if the file doesn't exist.
        pass

    print("Type \"\\h\" for help.")

    while True:
        sql = ''
        while not sql.endswith(';'):
            prompt = "dsql> " if sql == '' else 'more> '
            try:
                more_sql = raw_input(prompt)
            except EOFError:
                sys.stdout.write('\n')
                sys.exit(1)
            if sql == '' and more_sql.startswith('\\'):
                # backslash command
                dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
                if dmatch:
                    include_system = dmatch.group(1)
                    arg = dmatch.group(3).strip()
                    if arg:
                        # \d table_name: describe the columns of one table.
                        sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg)
                        if not include_system:
                            sql = sql + " AND TABLE_SCHEMA = 'druid'"
                    else:
                        # \d alone: list tables.
                        sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
                        if not include_system:
                            sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
                    # break to execute sql
                    break

                hmatch = re.match(r'^\\h\s*$', more_sql)
                if hmatch:
                    print("Commands:")
                    print(" \\d show tables")
                    print(" \\dS show tables, including system tables")
                    print(" \\d table_name describe table")
                    print(" \\h show this help")
                    print(" \\q exit this program")
                    print("Or enter a SQL query ending with a semicolon (;).")
                    continue

                qmatch = re.match(r'^\\q\s*$', more_sql)
                if qmatch:
                    sys.exit(0)

                print("No such command: " + more_sql)
            else:
                sql = (sql + ' ' + more_sql).strip()

        try:
            try:
                readline.write_history_file(readline_history_file)
            except IOError:
                # History is best-effort, like the read above: an unwritable
                # history file must not stop the query from running.
                pass
            display_query(url, sql.rstrip(';'), context, args)
        except DruidSqlException as e:
            e.write_to(sys.stdout)
        except KeyboardInterrupt:
            sys.stdout.write("Query interrupted\n")
            sys.stdout.flush()
# Run the client. Query errors, Ctrl-C, and a closed output pipe all end the
# program with exit status 1; any other I/O error is re-raised.
try:
    main()
except DruidSqlException as e:
    e.write_to(sys.stderr)
    sys.exit(1)
except KeyboardInterrupt:
    sys.exit(1)
except IOError as e:
    if e.errno != errno.EPIPE:
        raise
    sys.exit(1)