2018-08-09 16:37:52 -04:00
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2024-07-03 03:52:57 -04:00
# NOTE:
# Any feature update to this script must also be made in
# `dsql-main-py3`, so that the change works for users on both
# Python 2 and Python 3.
2018-08-09 16:37:52 -04:00
from __future__ import print_function
import argparse
import base64
import collections
import csv
import errno
import json
import numbers
2019-01-28 20:02:43 -05:00
import os
2018-08-09 16:37:52 -04:00
import re
2019-01-28 20:02:43 -05:00
import readline
2018-08-09 16:37:52 -04:00
import ssl
import sys
import time
import unicodedata
import urllib2
class DruidSqlException(Exception):
    """Error raised when a Druid SQL query or connection fails; carries a user-facing message."""

    def friendly_message(self):
        """Return the exception's message, or "Query failed" when it has none.

        The message is read from ``self.args`` instead of the ``message``
        attribute, which is deprecated since Python 2.6 and absent in
        Python 3. As with ``BaseException.message``, only a single positional
        argument counts as the message.
        """
        message = self.args[0] if len(self.args) == 1 else None
        return message if message else "Query failed"

    def write_to(self, f):
        """Write the friendly message to file-like *f* in red (ANSI escapes), add a newline, and flush."""
        f.write('\x1b[31m')
        f.write(self.friendly_message())
        f.write('\x1b[0m')
        f.write('\n')
        f.flush()
2019-01-28 20:02:43 -05:00
def do_query_with_args(url, sql, context, args):
    """Run *sql* against *url*, taking timeout, auth and TLS settings from the parsed CLI *args*.

    Returns whatever do_query() returns (a lazy row generator).
    """
    return do_query(
        url,
        sql,
        context,
        timeout=args.timeout,
        user=args.user,
        ignore_ssl_verification=args.ignore_ssl_verification,
        ca_file=args.cafile,
        ca_path=args.capath,
        cert_chain=args.certchain,
        key_file=args.keyfile,
        key_pass=args.keypass,
    )
2019-01-28 20:02:43 -05:00
2021-02-11 23:16:47 -05:00
def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path, cert_chain, key_file, key_pass):
    """Send *sql* to the Druid SQL endpoint at *url* and lazily yield the result rows.

    This is a generator: the request is only sent when the first row is
    requested. The response body is a JSON array which is decoded
    incrementally, one element at a time, so results are streamed rather
    than loaded whole. Rows are OrderedDicts, keeping the server's column
    order.

    Args:
      url: full URL of the SQL endpoint.
      sql: query text.
      context: dict of query context parameters; it is never mutated.
      timeout: client timeout in seconds; <= 0 means no client timeout. When
        positive and longer than the context's 'timeout' (milliseconds), the
        context timeout is raised to match.
      user: optional "user:password" string for HTTP basic authentication.
      ignore_ssl_verification: disable server certificate and hostname checks.
      ca_file, ca_path: CA bundle file and/or directory used to verify the server.
      cert_chain, key_file, key_pass: client certificate, private key and key
        password presented to the server.

    Raises:
      DruidSqlException: on HTTP or connection errors (via raise_friendly_error).
      ValueError: if the response body cannot be parsed as a JSON array.
    """
    json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
    try:
        # Default: send the caller's context as-is.
        query_context = context
        if timeout <= 0:
            timeout = None
        elif int(context.get('timeout', 0)) / 1000. < timeout:
            # The context timeout is in milliseconds; extend it to at least the
            # client timeout, on a copy so the caller's dict is left untouched.
            query_context = context.copy()
            query_context['timeout'] = timeout * 1000
        sql_json = json.dumps({'query' : sql, 'context' : query_context})

        ssl_context = _make_ssl_context(ignore_ssl_verification, ca_file, ca_path, cert_chain, key_file, key_pass)

        req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'})
        if user:
            req.add_header("Authorization", "Basic %s" % base64.b64encode(user))
        response = urllib2.urlopen(req, None, timeout, context=ssl_context)
        try:
            for row in _iter_json_array(response, json_decoder):
                yield row
        finally:
            # Release the connection even if the consumer stops iterating early.
            response.close()
    except urllib2.URLError as e:
        raise_friendly_error(e)


def _make_ssl_context(ignore_ssl_verification, ca_file, ca_path, cert_chain, key_file, key_pass):
    """Build the SSLContext for a query, or return None to keep urllib2's default TLS handling."""
    if not (ignore_ssl_verification or ca_file is not None or ca_path is not None or cert_chain is not None):
        return None
    ssl_context = ssl.create_default_context()
    if ignore_ssl_verification:
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
    elif ca_file is not None or ca_path is not None:
        # Either a CA file or a CA directory (or both) is enough on its own.
        ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)
    if cert_chain is not None:
        # The client certificate is independent of how the server is verified.
        ssl_context.load_cert_chain(certfile=cert_chain, keyfile=key_file, password=key_pass)
    return ssl_context


def _iter_json_array(response, json_decoder, chunk_size=8192, max_buffer=256 * 1024):
    """Incrementally decode the top-level JSON array read from *response*, yielding each element."""
    first_chunk = True
    eof = False
    buf = ''
    while not eof or buf:
        while True:
            # Skip element separators and surrounding whitespace.
            buf = buf.lstrip(', \t\r\n')
            try:
                obj, sz = json_decoder.raw_decode(buf)
            except ValueError:
                # Maybe invalid JSON, maybe a partial element; raw_decode cannot tell which.
                if eof and buf.rstrip() == ']':
                    # Stream done and all elements read.
                    buf = ''
                    break
                elif eof or len(buf) > max_buffer:
                    # Nothing more is coming, or the element is implausibly large: report it.
                    raise
                else:
                    # Stop decoding; read more from the stream instead.
                    break
            yield obj
            buf = buf[sz:]
        # Read more from the http stream.
        if not eof:
            chunk = response.read(chunk_size)
            if chunk:
                buf = buf + chunk
                if first_chunk:
                    # Drop the array's opening '[' exactly once, on the first
                    # non-blank data, so later '[' characters are never touched.
                    buf = buf.lstrip()
                    if buf:
                        if buf.startswith('['):
                            buf = buf[1:]
                        first_chunk = False
            else:
                # Stream done; keep decoding whatever is left in buf.
                eof = True
def raise_friendly_error(e):
    """Convert a urllib2 error *e* into a DruidSqlException with a readable message.

    Always raises; never returns.

    - HTTP 500 with a JSON error body containing 'errorMessage': the message is
      built from the 'error', 'errorClass', 'errorMessage' and 'host' fields;
      fields that are missing or empty are skipped, and an 'error' of
      'Unknown exception' is omitted.
    - HTTP 405: adds a hint about the broker URL / druid.sql.enabled.
    - Any other HTTP error: status code, reason and raw response body.
    - Non-HTTP errors: str(e).
    """
    if isinstance(e, urllib2.HTTPError):
        text = e.read().strip()
        error_obj = {}
        try:
            error_obj = dict(json.loads(text))
        except (ValueError, TypeError):
            # Body is not a JSON object; the generic message below uses the raw text.
            pass
        if e.code == 500 and 'errorMessage' in error_obj:
            # json.loads yields unicode in Python 2; to_utf8 avoids the
            # UnicodeEncodeError that str() raises on non-ASCII messages.
            error_text = ''
            error = error_obj.get('error')
            if error and error != 'Unknown exception':
                error_text = error_text + to_utf8(error) + ': '
            if error_obj.get('errorClass'):
                error_text = error_text + to_utf8(error_obj['errorClass']) + ': '
            error_text = error_text + to_utf8(error_obj['errorMessage'])
            if error_obj.get('host'):
                error_text = error_text + ' (' + to_utf8(error_obj['host']) + ')'
            raise DruidSqlException(error_text)
        elif e.code == 405:
            error_text = 'HTTP Error {0}: {1}\n{2}'.format(e.code, e.reason + " - Are you using the correct broker URL and " +\
                "is druid.sql.enabled set to true on your broker?", text)
            raise DruidSqlException(error_text)
        else:
            raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
    else:
        raise DruidSqlException(str(e))
def to_utf8(value):
    """Render *value* as a byte string: unicode is UTF-8 encoded, None becomes "", anything else uses str()."""
    if value is None:
        return ""
    # Encode text explicitly so non-ASCII characters survive; other objects use their str() form.
    return value.encode("utf-8") if isinstance(value, unicode) else str(value)
def to_tsv(values, delimiter):
    """Join *values* into a single delimiter-separated line.

    Each value is converted with to_utf8(), and any occurrence of *delimiter*
    inside a value is deleted so it cannot be mistaken for a field separator.
    """
    fields = []
    for value in values:
        fields.append(to_utf8(value).replace(delimiter, ''))
    return delimiter.join(fields)
def print_csv(rows, header):
    """Write *rows* (mappings) to stdout as CSV.

    When *header* is truthy, the first row's keys are written as a header
    line first. All cells go through to_utf8(), so None becomes an empty field.
    """
    writer = csv.writer(sys.stdout)
    need_header = header
    for row in rows:
        if need_header:
            writer.writerow([to_utf8(key) for key in row.keys()])
            need_header = False
        writer.writerow([to_utf8(cell) for cell in row.itervalues()])
def print_tsv(rows, header, tsv_delimiter):
    """Print *rows* (mappings) to stdout as *tsv_delimiter*-separated lines.

    When *header* is truthy, the first row's keys are printed first as a header line.
    """
    header_pending = header
    for row in rows:
        if header_pending:
            print(to_tsv(row.keys(), tsv_delimiter))
            header_pending = False
        print(to_tsv(list(row.itervalues()), tsv_delimiter))
def print_json(rows):
    """Print every row as a JSON document on its own line."""
    for line in (json.dumps(row) for row in rows):
        print(line)
def table_to_printable_value(value):
    """Return *value* as unicode text for a table cell.

    None becomes u"NULL". Otherwise the value is converted with to_utf8(),
    trimmed of surrounding whitespace, decoded from UTF-8, and stripped of
    code points 0-31 (the ASCII control characters).
    """
    if value is not None:
        text = to_utf8(value).strip().decode('utf-8')
        return text.translate(dict.fromkeys(range(32)))
    return u"NULL"
def table_compute_string_width(v):
    """Return the number of terminal columns needed to display unicode string *v*.

    The string is NFC-normalized first, so decomposed sequences that have a
    precomposed form count once. Formatting characters (category Cf) and
    combining marks (categories Mn, Me) that remain take no column. East
    Asian Fullwidth and Wide characters take two columns, and every other
    character takes one.
    """
    width = 0
    for c in unicodedata.normalize('NFC', v):
        if unicodedata.category(c) in ('Cf', 'Mn', 'Me'):
            # Zero width: formatting controls and marks that combine with the
            # preceding character instead of advancing the cursor.
            continue
        if unicodedata.east_asian_width(c) in ('F', 'W'):
            # Double-wide character, prints in two columns.
            width += 2
        else:
            # All other characters.
            width += 1
    return width
def table_compute_column_widths(row_buffer):
    """Return, per column, the largest display width of any cell in *row_buffer*.

    Each row is a list of unicode strings, measured with
    table_compute_string_width(). Returns None when *row_buffer* is empty.
    Later rows are assumed to have no more cells than the first row.
    """
    widths = None
    for row in row_buffer:
        row_widths = [table_compute_string_width(cell) for cell in row]
        if widths:
            for idx in range(len(row)):
                widths[idx] = max(widths[idx], row_widths[idx])
        else:
            widths = row_widths
    return widths
def table_print_row(values, column_widths, column_types):
    """Print one table row of unicode cells separated by vertical box-drawing bars.

    Each cell is padded to its entry in *column_widths*. Cells in columns whose
    *column_types* entry is 'n' are right-aligned; all others are left-aligned,
    as is every cell when *column_types* is falsy. Cells are UTF-8 encoded
    for output.
    """
    bar = u'\u2502'.encode('utf-8')
    for idx, cell in enumerate(values):
        pad = ' ' * max(0, column_widths[idx] - table_compute_string_width(cell))
        encoded = cell.encode('utf-8')
        right_align = column_types and column_types[idx] == 'n'
        content = pad + encoded if right_align else encoded + pad
        print(bar + ' ' + content + ' ', end="")
    print(bar)
def table_print_header(values, column_widths):
    """Print the table's top border, the header row *values*, and the rule under the header."""
    rule = [u'\u2500'.encode('utf-8') * max(0, width + 2) for width in column_widths]
    # Top border: top-left corner, column runs joined by down-tees, top-right corner.
    print(u'\u250C'.encode('utf-8') + u'\u252C'.encode('utf-8').join(rule) + u'\u2510'.encode('utf-8'))
    # Header cells are never right-aligned, hence no column types.
    table_print_row(values, column_widths, None)
    # Separator: left tee, column runs joined by crosses, right tee.
    print(u'\u251C'.encode('utf-8') + u'\u253C'.encode('utf-8').join(rule) + u'\u2524'.encode('utf-8'))
def table_print_bottom(column_widths):
    """Print the table's closing border, with an up-tee at each column boundary."""
    runs = [u'\u2500'.encode('utf-8') * max(0, width + 2) for width in column_widths]
    line = u'\u2514'.encode('utf-8') + u'\u2534'.encode('utf-8').join(runs) + u'\u2518'.encode('utf-8')
    print(line)
def table_print_row_buffer(row_buffer, column_widths, column_types):
    """Print buffered rows: the first one as the framed header, the rest as data rows."""
    for index, values in enumerate(row_buffer):
        if index == 0:
            table_print_header(values, column_widths)
        else:
            table_print_row(values, column_widths, column_types)
def print_table(rows):
    """Print *rows* (ordered mappings) as a box-drawn table, then a row-count and timing summary.

    The header plus the first 500 data rows are buffered to size the columns
    before anything is printed. Later rows are printed immediately using those
    widths. Column alignment is decided from the first row: columns whose
    value is a numbers.Number are marked 'n' (right-aligned by
    table_print_row) and all others 's'.
    """
    started_at = time.time()
    row_count = 0
    buffer_budget = 500  # data rows to buffer before column widths are fixed
    pending = []
    column_types = []
    column_widths = None
    for row in rows:
        row_count += 1
        if row_count == 1:
            # The header line and the per-column types come from the first row.
            pending.append([table_to_printable_value(name) for name in row.keys()])
            column_types.extend('n' if isinstance(row[name], numbers.Number) else 's' for name in row.keys())
        values = [table_to_printable_value(cell) for cell in row.itervalues()]
        if buffer_budget > 0:
            pending.append(values)
            buffer_budget -= 1
            continue
        if pending:
            # Buffer just filled up: fix the widths and flush it.
            column_widths = table_compute_column_widths(pending)
            table_print_row_buffer(pending, column_widths, column_types)
            del pending[:]
        table_print_row(values, column_widths, column_types)
    if pending:
        column_widths = table_compute_column_widths(pending)
        table_print_row_buffer(pending, column_widths, column_types)
    if column_widths:
        table_print_bottom(column_widths)
    print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(row_count, 's' if row_count != 1 else '', time.time() - started_at))
    print("")
def display_query(url, sql, context, args):
    """Execute *sql* and print its rows in the output format named by args.format.

    Recognized formats are 'csv', 'tsv', 'json' and 'table'; any other value prints nothing.
    """
    rows = do_query_with_args(url, sql, context, args)
    renderers = {
        'csv': lambda: print_csv(rows, args.header),
        'tsv': lambda: print_tsv(rows, args.header, args.tsv_delimiter),
        'json': lambda: print_json(rows),
        'table': lambda: print_table(rows),
    }
    render = renderers.get(args.format)
    if render is not None:
        render()
2019-01-28 20:02:43 -05:00
def sql_literal_escape(s):
    """Return *s* as a Unicode-escaped SQL string literal of the form U&'...'.

    Letters and digits (Unicode categories L* and N*) and spaces are kept
    as-is. Every other character, including quotes and backslashes, is written
    as a backslash followed by four hex digits, so no raw quote can end the
    literal early. A four-digit escape cannot name a character outside the
    Basic Multilingual Plane, so such characters are written as their UTF-16
    surrogate pair (two escapes). None yields the empty literal ''.
    Byte strings are decoded as UTF-8; other objects are converted with str() first.
    """
    if s is None:
        return "''"
    elif isinstance(s, unicode):
        ustr = s
    else:
        ustr = str(s).decode('utf-8')
    escaped = [u"U&'"]
    for c in ustr:
        ccategory = unicodedata.category(c)
        if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
            escaped.append(c)
        else:
            code_point = ord(c)
            if code_point > 0xFFFF:
                # Wide Python build: split into a UTF-16 surrogate pair, which is
                # what a narrow build already iterates over.
                code_point -= 0x10000
                units = (0xD800 + (code_point >> 10), 0xDC00 + (code_point & 0x3FF))
            else:
                units = (code_point,)
            for unit in units:
                escaped.append(u'\\%04x' % unit)
    escaped.append("'")
    return ''.join(escaped)
2019-01-28 20:02:43 -05:00
def make_readline_completer(url, context, args):
    """Build a readline completion function for SQL keywords, and print a connection banner.

    At the start of the line (readline begidx 0), statement starters are
    offered; elsewhere, clause keywords are offered. Matching is done against
    the upper-cased *text*, and each completion gets a trailing space. *url*
    and *context* are accepted but not used. As a side effect, prints
    "Connected to [<args.host>]." followed by a blank line.

    Returns:
      A function(text, state) that returns the state-th matching completion,
      or None when there are no more matches.
    """
    starters = [
        'EXPLAIN PLAN FOR',
        'SELECT'
    ]
    middlers = [
        'FROM',
        'WHERE',
        'GROUP BY',
        'ORDER BY',
        'LIMIT'
    ]

    def readline_completer(text, state):
        candidates = starters if readline.get_begidx() == 0 else middlers
        prefix = text.upper()
        matches = [x for x in candidates if x.startswith(prefix)]
        if 0 <= state < len(matches):
            return matches[state] + " "
        # End of matches: return None explicitly rather than failing on None + " ".
        return None

    print("Connected to [" + args.host + "].")
    print("")
    return readline_completer
2018-08-09 16:37:52 -04:00
def main():
    """Run the dsql command-line client.

    Parses the command-line options and builds the SQL endpoint URL and the
    query context from them. It then either runs the single query given with
    --execute, or starts an interactive prompt with readline history
    (~/.dsql_history) and keyword completion.

    Raises:
      ValueError: if a --context-option value is not of the form key=value.
      DruidSqlException: in --execute mode, if the query fails.
      SystemExit: from the interactive prompt; status 0 for the "\\q" command
        and status 1 on end of input (EOF).
    """
    parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
    parser_cnn = parser.add_argument_group('Connection options')
    parser_fmt = parser.add_argument_group('Formatting options')
    parser_oth = parser.add_argument_group('Other options')
    parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Druid query host or url, like https://localhost:8282/')
    parser_cnn.add_argument('--user', '-u', type=str, help='HTTP basic authentication credentials, like user:password')
    parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds')
    parser_cnn.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
    parser_fmt.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
    parser_fmt.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
    parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
    parser_oth.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection, see https://druid.apache.org/docs/latest/querying/sql.html#connection-context for options')
    parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
    parser_cnn.add_argument('--certchain', type=str, help='Path to SSL certificate used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--keyfile', type=str, help='Path to private SSL key used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--keypass', type=str, help='Password to private SSL key file used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    args = parser.parse_args()

    # Build broker URL
    # The SQL endpoint path is always appended; a host given without a scheme
    # defaults to plain http.
    url = args.host.rstrip('/') + '/druid/v2/sql/'
    if not url.startswith('http:') and not url.startswith('https:'):
        url = 'http://' + url

    # Build context
    # Each -c option is key=value (split on the first '='). Values made only
    # of digits are sent as integers; everything else is sent as a string.
    context = {}
    if args.context_option:
        for opt in args.context_option:
            kv = opt.split("=", 1)
            if len(kv) != 2:
                raise ValueError('Invalid context option, should be key=value: ' + opt)
            if re.match(r"^\d+$", kv[1]):
                context[kv[0]] = long(kv[1])
            else:
                context[kv[0]] = kv[1]

    if args.execute:
        # One-shot mode: errors propagate to the caller.
        display_query(url, args.execute, context, args)
    else:
        # interactive mode
        print("Welcome to dsql, the command-line client for Druid SQL.")
        readline_history_file = os.path.expanduser("~/.dsql_history")
        readline.parse_and_bind('tab: complete')
        readline.set_history_length(500)
        readline.set_completer(make_readline_completer(url, context, args))
        try:
            readline.read_history_file(readline_history_file)
        except IOError:
            # IOError can happen if the file doesn't exist.
            pass
        print("Type \"\\h\" for help.")

        while True:
            # Read one statement: keep prompting (with 'more> ') until the
            # accumulated text ends with ';'. Backslash commands are only
            # recognized as the first line of a statement.
            sql = ''
            while not sql.endswith(';'):
                prompt = "dsql> " if sql == '' else 'more> '
                try:
                    more_sql = raw_input(prompt)
                except EOFError:
                    # End of input (e.g. Ctrl-D): finish the line and exit with status 1.
                    sys.stdout.write('\n')
                    sys.exit(1)
                if sql == '' and more_sql.startswith('\\'):
                    # backslash command
                    # \d[S][+] [table]: group 1 ('S') includes schemas other
                    # than 'druid', group 2 ('+') is captured but not used,
                    # group 3 is an optional table name.
                    dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
                    if dmatch:
                        include_system = dmatch.group(1)
                        extra_info = dmatch.group(2)
                        arg = dmatch.group(3).strip()
                        if arg:
                            # Describe one table; the name is passed through
                            # sql_literal_escape before being put into the query.
                            sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg)
                            if not include_system:
                                sql = sql + " AND TABLE_SCHEMA = 'druid'"
                            # break to execute sql
                            break
                        else:
                            # List tables.
                            sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
                            if not include_system:
                                sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
                            # break to execute sql
                            break
                    hmatch = re.match(r'^\\h\s*$', more_sql)
                    if hmatch:
                        print("Commands:")
                        print(" \\d show tables")
                        print(" \\dS show tables, including system tables")
                        print(" \\d table_name describe table")
                        print(" \\h show this help")
                        print(" \\q exit this program")
                        print("Or enter a SQL query ending with a semicolon (;).")
                        continue
                    qmatch = re.match(r'^\\q\s*$', more_sql)
                    if qmatch:
                        sys.exit(0)
                    # Unknown command: sql is still empty, so the prompt starts over.
                    print("No such command: " + more_sql)
                else:
                    sql = (sql + ' ' + more_sql).strip()
            try:
                # History is written before the query runs, so it is saved
                # even if the query fails or is interrupted.
                readline.write_history_file(readline_history_file)
                # Trailing semicolons are stripped before the query is sent.
                display_query(url, sql.rstrip(';'), context, args)
            except DruidSqlException as e:
                # Report the failure and return to the prompt.
                e.write_to(sys.stdout)
            except KeyboardInterrupt:
                # Ctrl-C while a query is running returns to the prompt instead of exiting.
                sys.stdout.write("Query interrupted\n")
                sys.stdout.flush()
if __name__ == '__main__':
    # Script entry point. Expected failures exit with status 1 instead of a
    # traceback. The guard lets the module be imported (e.g. by tests)
    # without starting the client.
    try:
        main()
    except DruidSqlException as e:
        # Query/connection errors: print the friendly (red) message to stderr.
        e.write_to(sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        sys.exit(1)
    except IOError as e:
        # A closed output pipe (e.g. output piped into `head`) is not worth a traceback.
        if e.errno == errno.EPIPE:
            sys.exit(1)
        else:
            raise