#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# NOTE:
# Any feature updates to this script must also be reflected in
# `dsql-main-py3` so that intended changes work for users using
# Python 2 or 3.

from __future__ import print_function

import argparse
import base64
import collections
import csv
import errno
import json
import numbers
import os
import re
import readline
import ssl
import sys
import time
import unicodedata
import urllib2

class DruidSqlException(Exception):
  def friendly_message(self):
    return self.message if self.message else "Query failed"

  def write_to(self, f):
    f.write('\x1b[31m')
    f.write(self.friendly_message())
    f.write('\x1b[0m')
    f.write('\n')
    f.flush()

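# Thin wrapper that unpacks the argparse namespace into do_query()'s
# explicit parameter list.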
def do_query_with_args(url, sql, context, args):
  return do_query(url, sql, context, args.timeout, args.user, args.ignore_ssl_verification, args.cafile, args.capath, args.certchain, args.keyfile, args.keypass)

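# POSTs the SQL statement to the broker and yields result rows one at a
# time. The response is a single JSON array streamed over HTTP; rather than
# buffering the whole body, the loop below strips the array punctuation
# ('[', ',', ']') and uses JSONDecoder.raw_decode() to peel complete objects
# off the front of the buffer, refilling it in 8 KB chunks.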
def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path, cert_chain, key_file, key_pass):
  json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)

  try:
    if timeout <= 0:
      timeout = None
      query_context = context
    elif int(context.get('timeout', 0)) / 1000. < timeout:
      query_context = context.copy()
      query_context['timeout'] = timeout * 1000
    else:
      # Context already specifies a timeout at least as long as the CLI one.
      query_context = context

    sql_json = json.dumps({'query' : sql, 'context' : query_context})

    # SSL stuff
    ssl_context = None
    if ignore_ssl_verification or ca_file is not None or ca_path is not None or cert_chain is not None:
      ssl_context = ssl.create_default_context()

      if ignore_ssl_verification:
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
      elif ca_file is not None or ca_path is not None:
        ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)
      else:
        ssl_context.load_cert_chain(certfile=cert_chain, keyfile=key_file, password=key_pass)

    req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'})

    if user:
      req.add_header("Authorization", "Basic %s" % base64.b64encode(user))

    response = urllib2.urlopen(req, None, timeout, context=ssl_context)

    first_chunk = True
    eof = False
    buf = ''

    while not eof or len(buf) > 0:
      while True:
        try:
          # Remove starting ','
          buf = buf.lstrip(',')
          obj, sz = json_decoder.raw_decode(buf)
          yield obj
          buf = buf[sz:]
        except ValueError as e:
          # Maybe invalid JSON, maybe partial object; it's hard to tell with this library.
          if eof and buf.rstrip() == ']':
            # Stream done and all objects read.
            buf = ''
            break
          elif eof or len(buf) > 256 * 1024:
            # If we read more than 256KB or if it's eof then report the parse error.
            raise
          else:
            # Stop reading objects, get more from the stream instead.
            break

      # Read more from the http stream
      if not eof:
        chunk = response.read(8192)
        if chunk:
          buf = buf + chunk
          if first_chunk:
            # Remove starting '['
            buf = buf.lstrip('[')
            first_chunk = False
        else:
          # Stream done. Keep reading objects out of buf though.
          eof = True

  except urllib2.URLError as e:
    raise_friendly_error(e)

def raise_friendly_error(e):
  if isinstance(e, urllib2.HTTPError):
    text = e.read().strip()
    error_obj = {}
    try:
      error_obj = dict(json.loads(text))
    except:
      pass
    if e.code == 500 and 'errorMessage' in error_obj:
      error_text = ''
      if error_obj['error'] != 'Unknown exception':
        error_text = error_text + error_obj['error'] + ': '
      if error_obj['errorClass']:
        error_text = error_text + str(error_obj['errorClass']) + ': '
      error_text = error_text + str(error_obj['errorMessage'])
      if error_obj['host']:
        error_text = error_text + ' (' + str(error_obj['host']) + ')'
      raise DruidSqlException(error_text)
    elif e.code == 405:
      error_text = 'HTTP Error {0}: {1}\n{2}'.format(
        e.code,
        e.reason + " - Are you using the correct broker URL and is druid.sql.enabled set to true on your broker?",
        text
      )
      raise DruidSqlException(error_text)
    else:
      raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
  else:
    raise DruidSqlException(str(e))

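# Coerces any value (None, unicode, or anything else) to a UTF-8 byte string
# suitable for writing to stdout under Python 2.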
def to_utf8(value):
  if value is None:
    return ""
  elif isinstance(value, unicode):
    return value.encode("utf-8")
  else:
    return str(value)

def to_tsv(values, delimiter):
  return delimiter.join(to_utf8(v).replace(delimiter, '') for v in values)

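# The row generator yields OrderedDicts (see do_query), so iterating keys and
# values below preserves the column order returned by Druid.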
def print_csv(rows, header):
  csv_writer = csv.writer(sys.stdout)
  first = True
  for row in rows:
    if first and header:
      csv_writer.writerow(list(to_utf8(k) for k in row.keys()))
    first = False

    values = []
    for key, value in row.iteritems():
      values.append(to_utf8(value))

    csv_writer.writerow(values)

def print_tsv(rows, header, tsv_delimiter):
  first = True
  for row in rows:
    if first and header:
      print(to_tsv(row.keys(), tsv_delimiter))
    first = False

    values = []
    for key, value in row.iteritems():
      values.append(value)

    print(to_tsv(values, tsv_delimiter))

def print_json(rows):
  for row in rows:
    print(json.dumps(row))

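# Helpers for the 'table' output format. Tables are drawn with Unicode
# box-drawing characters; column widths are computed from display width
# (East Asian wide characters count as two columns), not string length.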
def table_to_printable_value(value):
  # Unicode string, trimmed, with control characters removed
  if value is None:
    return u"NULL"
  else:
    return to_utf8(value).strip().decode('utf-8').translate(dict.fromkeys(range(32)))

def table_compute_string_width(v):
  normalized = unicodedata.normalize('NFC', v)
  width = 0
  for c in normalized:
    ccategory = unicodedata.category(c)
    cwidth = unicodedata.east_asian_width(c)
    if ccategory == 'Cf':
      # Formatting control, zero width
      pass
    elif cwidth == 'F' or cwidth == 'W':
      # Double-wide character, prints in two columns
      width = width + 2
    else:
      # All other characters
      width = width + 1
  return width

def table_compute_column_widths(row_buffer):
  widths = None
  for values in row_buffer:
    values_widths = [table_compute_string_width(v) for v in values]
    if not widths:
      widths = values_widths
    else:
      for i in xrange(0, len(values)):
        widths[i] = max(widths[i], values_widths[i])
  return widths

def table_print_row(values, column_widths, column_types):
  vertical_line = u'\u2502'.encode('utf-8')
  for i in xrange(0, len(values)):
    padding = ' ' * max(0, column_widths[i] - table_compute_string_width(values[i]))
    if column_types and column_types[i] == 'n':
      # Right-align numeric columns
      print(vertical_line + ' ' + padding + values[i].encode('utf-8') + ' ', end="")
    else:
      print(vertical_line + ' ' + values[i].encode('utf-8') + padding + ' ', end="")
  print(vertical_line)

def table_print_header(values, column_widths):
  # Line 1
  left_corner = u'\u250C'.encode('utf-8')
  horizontal_line = u'\u2500'.encode('utf-8')
  top_tee = u'\u252C'.encode('utf-8')
  right_corner = u'\u2510'.encode('utf-8')
  print(left_corner, end="")
  for i in xrange(0, len(column_widths)):
    print(horizontal_line * max(0, column_widths[i] + 2), end="")
    if i + 1 < len(column_widths):
      print(top_tee, end="")
  print(right_corner)

  # Line 2
  table_print_row(values, column_widths, None)

  # Line 3
  left_tee = u'\u251C'.encode('utf-8')
  cross = u'\u253C'.encode('utf-8')
  right_tee = u'\u2524'.encode('utf-8')
  print(left_tee, end="")
  for i in xrange(0, len(column_widths)):
    print(horizontal_line * max(0, column_widths[i] + 2), end="")
    if i + 1 < len(column_widths):
      print(cross, end="")
  print(right_tee)

def table_print_bottom(column_widths):
  left_corner = u'\u2514'.encode('utf-8')
  right_corner = u'\u2518'.encode('utf-8')
  bottom_tee = u'\u2534'.encode('utf-8')
  horizontal_line = u'\u2500'.encode('utf-8')
  print(left_corner, end="")
  for i in xrange(0, len(column_widths)):
    print(horizontal_line * max(0, column_widths[i] + 2), end="")
    if i + 1 < len(column_widths):
      print(bottom_tee, end="")
  print(right_corner)

def table_print_row_buffer(row_buffer, column_widths, column_types):
  first = True
  for values in row_buffer:
    if first:
      table_print_header(values, column_widths)
      first = False
    else:
      table_print_row(values, column_widths, column_types)

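# Buffers the first 500 rows to compute column widths, prints them as a
# table, then streams any remaining rows using those widths. Rows beyond the
# buffered sample may therefore be wider than their columns.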
def print_table(rows):
  start = time.time()
  nrows = 0
  first = True

  # Buffer some rows before printing.
  rows_to_buffer = 500
  row_buffer = []
  column_types = []
  column_widths = None

  for row in rows:
    nrows = nrows + 1

    if first:
      row_buffer.append([table_to_printable_value(k) for k in row.keys()])
      for k in row.keys():
        if isinstance(row[k], numbers.Number):
          column_types.append('n')
        else:
          column_types.append('s')
      first = False

    values = [table_to_printable_value(v) for k, v in row.iteritems()]
    if rows_to_buffer > 0:
      row_buffer.append(values)
      rows_to_buffer = rows_to_buffer - 1
    else:
      if row_buffer:
        column_widths = table_compute_column_widths(row_buffer)
        table_print_row_buffer(row_buffer, column_widths, column_types)
        del row_buffer[:]
      table_print_row(values, column_widths, column_types)

  if row_buffer:
    column_widths = table_compute_column_widths(row_buffer)
    table_print_row_buffer(row_buffer, column_widths, column_types)

  if column_widths:
    table_print_bottom(column_widths)

  print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(nrows, 's' if nrows != 1 else '', time.time() - start))
  print("")

def display_query(url, sql, context, args):
  rows = do_query_with_args(url, sql, context, args)

  if args.format == 'csv':
    print_csv(rows, args.header)
  elif args.format == 'tsv':
    print_tsv(rows, args.header, args.tsv_delimiter)
  elif args.format == 'json':
    print_json(rows)
  elif args.format == 'table':
    print_table(rows)

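# Escapes an arbitrary value as a SQL standard Unicode string literal, e.g.
# sql_literal_escape(u'foo-bar') returns U&'foo\002dbar'. Letters, digits and
# spaces pass through; every other character becomes a \XXXX escape.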
def sql_literal_escape(s):
  if s is None:
    return "''"
  elif isinstance(s, unicode):
    ustr = s
  else:
    ustr = str(s).decode('utf-8')

  escaped = [u"U&'"]

  for c in ustr:
    ccategory = unicodedata.category(c)
    if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
      escaped.append(c)
    else:
      escaped.append(u'\\')
      escaped.append('%04x' % ord(c))

  escaped.append("'")
  return ''.join(escaped)

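# Builds a readline tab-completion callback: completes leading keywords
# (SELECT, EXPLAIN PLAN FOR) at the start of a line, and clause keywords
# (FROM, WHERE, ...) elsewhere.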
def make_readline_completer(url, context, args):
  starters = [
    'EXPLAIN PLAN FOR',
    'SELECT'
  ]

  middlers = [
    'FROM',
    'WHERE',
    'GROUP BY',
    'ORDER BY',
    'LIMIT'
  ]

  def readline_completer(text, state):
    if readline.get_begidx() == 0:
      results = [x for x in starters if x.startswith(text.upper())] + [None]
    else:
      results = [x for x in middlers if x.startswith(text.upper())] + [None]

    # readline expects None (not an exception) once candidates are exhausted.
    if results[state] is None:
      return None
    return results[state] + " "

  print("Connected to [" + args.host + "].")
  print("")

  return readline_completer

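# Typical invocations (host, query, and context value below are illustrative):
#   dsql --host http://localhost:8082/ -e "SELECT 1" --format csv --header
#   dsql -c sqlTimeZone=America/Los_Angeles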
def main():
  parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
  parser_cnn = parser.add_argument_group('Connection options')
  parser_fmt = parser.add_argument_group('Formatting options')
  parser_oth = parser.add_argument_group('Other options')

  parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Druid query host or url, like https://localhost:8282/')
  parser_cnn.add_argument('--user', '-u', type=str, help='HTTP basic authentication credentials, like user:password')
  parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds')
  parser_cnn.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser_cnn.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser_cnn.add_argument('--certchain', type=str, help='Path to SSL certificate used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser_cnn.add_argument('--keyfile', type=str, help='Path to private SSL key used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser_cnn.add_argument('--keypass', type=str, help='Password to private SSL key file used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
  parser_fmt.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
  parser_fmt.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
  parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
  parser_oth.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection, see https://druid.apache.org/docs/latest/querying/sql.html#connection-context for options')
  parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query')

  args = parser.parse_args()

  # Build broker URL
  url = args.host.rstrip('/') + '/druid/v2/sql/'
  if not url.startswith('http:') and not url.startswith('https:'):
    url = 'http://' + url

  # Build query context from -c key=value options; all-digit values are
  # coerced to numbers.
  context = {}
  if args.context_option:
    for opt in args.context_option:
      kv = opt.split("=", 1)
      if len(kv) != 2:
        raise ValueError('Invalid context option, should be key=value: ' + opt)
      if re.match(r"^\d+$", kv[1]):
        context[kv[0]] = long(kv[1])
      else:
        context[kv[0]] = kv[1]

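  # With --execute, run a single query and exit; otherwise start the
  # interactive shell below.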
  if args.execute:
    display_query(url, args.execute, context, args)
  else:
    # interactive mode
    print("Welcome to dsql, the command-line client for Druid SQL.")

    readline_history_file = os.path.expanduser("~/.dsql_history")
    readline.parse_and_bind('tab: complete')
    readline.set_history_length(500)
    readline.set_completer(make_readline_completer(url, context, args))

    try:
      readline.read_history_file(readline_history_file)
    except IOError:
      # IOError can happen if the file doesn't exist.
      pass

    print("Type \"\\h\" for help.")

    while True:
      sql = ''
      while not sql.endswith(';'):
        prompt = "dsql> " if sql == '' else "more> "
        try:
          more_sql = raw_input(prompt)
        except EOFError:
          sys.stdout.write('\n')
          sys.exit(1)
        if sql == '' and more_sql.startswith('\\'):
          # backslash command
          dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
          if dmatch:
            include_system = dmatch.group(1)
            extra_info = dmatch.group(2)
            arg = dmatch.group(3).strip()
            if arg:
              sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg)
              if not include_system:
                sql = sql + " AND TABLE_SCHEMA = 'druid'"
              # break to execute sql
              break
            else:
              sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
              if not include_system:
                sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
              # break to execute sql
              break

          hmatch = re.match(r'^\\h\s*$', more_sql)
          if hmatch:
            print("Commands:")
            print("  \\d             show tables")
            print("  \\dS            show tables, including system tables")
            print("  \\d table_name  describe table")
            print("  \\h             show this help")
            print("  \\q             exit this program")
            print("Or enter a SQL query ending with a semicolon (;).")
            continue

          qmatch = re.match(r'^\\q\s*$', more_sql)
          if qmatch:
            sys.exit(0)

          print("No such command: " + more_sql)
        else:
          sql = (sql + ' ' + more_sql).strip()

      try:
        readline.write_history_file(readline_history_file)
        display_query(url, sql.rstrip(';'), context, args)
      except DruidSqlException as e:
        e.write_to(sys.stdout)
      except KeyboardInterrupt:
        sys.stdout.write("Query interrupted\n")
        sys.stdout.flush()

# Entry point: run main() and convert expected failures into exit codes.
try:
  main()
except DruidSqlException as e:
  e.write_to(sys.stderr)
  sys.exit(1)
except KeyboardInterrupt:
  sys.exit(1)
except IOError as e:
  if e.errno == errno.EPIPE:
    sys.exit(1)
  else:
    raise