Enhancements to dsql. (#6929)

- CLI history, basic autocomplete through deadline.
- Include timeout in query context.
- Group CLI options into... groups.
This commit is contained in:
Gian Merlino 2019-01-28 20:02:43 -05:00 committed by Fangjin Yang
parent 8d70ba69cf
commit ac4c7e21a2
2 changed files with 85 additions and 38 deletions

View File

@ -21,16 +21,9 @@ PWD="$(pwd)"
WHEREAMI="$(dirname "$0")" WHEREAMI="$(dirname "$0")"
WHEREAMI="$(cd "$WHEREAMI" && pwd)" WHEREAMI="$(cd "$WHEREAMI" && pwd)"
RLWRAP=""
if [ -x "$(command -v rlwrap)" ]
then
RLWRAP="rlwrap -C dsql"
fi
if [ -x "$(command -v python2)" ] if [ -x "$(command -v python2)" ]
then then
exec $RLWRAP python2 "$WHEREAMI/dsql-main" "$@" exec python2 "$WHEREAMI/dsql-main" "$@"
else else
exec $RLWRAP "$WHEREAMI/dsql-main" "$@" exec "$WHEREAMI/dsql-main" "$@"
fi fi

View File

@ -26,7 +26,9 @@ import csv
import errno import errno
import json import json
import numbers import numbers
import os
import re import re
import readline
import ssl import ssl
import sys import sys
import time import time
@ -34,17 +36,30 @@ import unicodedata
import urllib2 import urllib2
class DruidSqlException(Exception): class DruidSqlException(Exception):
def friendly_message(self):
return self.message if self.message else "Query failed"
def write_to(self, f): def write_to(self, f):
f.write('\x1b[31m') f.write('\x1b[31m')
f.write(self.message if self.message else "Query failed") f.write(self.friendly_message())
f.write('\x1b[0m') f.write('\x1b[0m')
f.write('\n') f.write('\n')
f.flush() f.flush()
def do_query(url, sql, context, timeout, user, password, ignore_ssl_verification, ca_file, ca_path): def do_query_with_args(url, sql, context, args):
return do_query(url, sql, context, args.timeout, args.user, args.ignore_ssl_verification, args.cafile, args.capath)
def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path):
json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict) json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
try: try:
sql_json = json.dumps({'query' : sql, 'context' : context}) if timeout <= 0:
timeout = None
query_context = context
elif int(context.get('timeout', 0)) / 1000 < timeout:
query_context = context.copy()
query_context['timeout'] = timeout * 1000;
sql_json = json.dumps({'query' : sql, 'context' : query_context})
# SSL stuff # SSL stuff
ssl_context = None; ssl_context = None;
@ -57,12 +72,9 @@ def do_query(url, sql, context, timeout, user, password, ignore_ssl_verification
ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path) ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)
req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'}) req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'})
if timeout <= 0:
timeout = None
if (user and password): if user:
basicAuthEncoding = base64.b64encode('%s:%s' % (user, password)) req.add_header("Authorization", "Basic %s" % base64.b64encode(user))
req.add_header("Authorization", "Basic %s" % basicAuthEncoding)
response = urllib2.urlopen(req, None, timeout, context=ssl_context) response = urllib2.urlopen(req, None, timeout, context=ssl_context)
@ -311,7 +323,7 @@ def print_table(rows):
print("") print("")
def display_query(url, sql, context, args): def display_query(url, sql, context, args):
rows = do_query(url, sql, context, args.timeout, args.user, args.password, args.ignore_ssl_verification, args.cafile, args.capath) rows = do_query_with_args(url, sql, context, args)
if args.format == 'csv': if args.format == 'csv':
print_csv(rows, args.header) print_csv(rows, args.header)
@ -322,7 +334,7 @@ def display_query(url, sql, context, args):
elif args.format == 'table': elif args.format == 'table':
print_table(rows) print_table(rows)
def sql_escape(s): def sql_literal_escape(s):
if s is None: if s is None:
return "''" return "''"
elif isinstance(s, unicode): elif isinstance(s, unicode):
@ -343,20 +355,49 @@ def sql_escape(s):
escaped.append("'") escaped.append("'")
return ''.join(escaped) return ''.join(escaped)
def make_readline_completer(url, context, args):
starters = [
'EXPLAIN PLAN FOR',
'SELECT'
]
middlers = [
'FROM',
'WHERE',
'GROUP BY',
'ORDER BY',
'LIMIT'
]
def readline_completer(text, state):
if readline.get_begidx() == 0:
results = [x for x in starters if x.startswith(text.upper())] + [None]
else:
results = ([x for x in middlers if x.startswith(text.upper())] + [None])
return results[state] + " "
print("Connected to [" + args.host + "].")
print("")
return readline_completer
def main(): def main():
parser = argparse.ArgumentParser(description='Druid SQL command-line client.') parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
parser.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Broker host or url') parser_cnn = parser.add_argument_group('Connection options')
parser.add_argument('--timeout', type=int, default=0, help='Timeout in seconds, 0 for no timeout') parser_fmt = parser.add_argument_group('Formatting options')
parser.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format') parser_oth = parser.add_argument_group('Other options')
parser.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"') parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Druid query host or url, like https://localhost:8282/')
parser.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"') parser_cnn.add_argument('--user', '-u', type=str, help='HTTP basic authentication credentials, like user:password')
parser.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection') parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds')
parser.add_argument('--execute', '-e', type=str, help='Execute single SQL query') parser_cnn.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
parser.add_argument('--user', '-u', type=str, help='Username for HTTP basic auth') parser_cnn.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
parser.add_argument('--password', '-p', type=str, help='Password for HTTP basic auth') parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
parser.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.') parser_fmt.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
parser.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.') parser_fmt.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
parser.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.') parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
parser_oth.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection, see https://docs.imply.io/on-prem/query-data/sql for options')
parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
args = parser.parse_args() args = parser.parse_args()
# Build broker URL # Build broker URL
@ -381,7 +422,19 @@ def main():
else: else:
# interactive mode # interactive mode
print("Welcome to dsql, the command-line client for Druid SQL.") print("Welcome to dsql, the command-line client for Druid SQL.")
print("Type \"\h\" for help.")
readline_history_file = os.path.expanduser("~/.dsql_history")
readline.parse_and_bind('tab: complete')
readline.set_history_length(500)
readline.set_completer(make_readline_completer(url, context, args))
try:
readline.read_history_file(readline_history_file)
except IOError:
# IOError can happen if the file doesn't exist.
pass
print("Type \"\\h\" for help.")
while True: while True:
sql = '' sql = ''
@ -400,7 +453,7 @@ def main():
extra_info = dmatch.group(2) extra_info = dmatch.group(2)
arg = dmatch.group(3).strip() arg = dmatch.group(3).strip()
if arg: if arg:
sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_escape(arg) sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg)
if not include_system: if not include_system:
sql = sql + " AND TABLE_SCHEMA = 'druid'" sql = sql + " AND TABLE_SCHEMA = 'druid'"
# break to execute sql # break to execute sql
@ -415,11 +468,11 @@ def main():
hmatch = re.match(r'^\\h\s*$', more_sql) hmatch = re.match(r'^\\h\s*$', more_sql)
if hmatch: if hmatch:
print("Commands:") print("Commands:")
print(" \d show tables") print(" \\d show tables")
print(" \dS show tables, including system tables") print(" \\dS show tables, including system tables")
print(" \d table_name describe table") print(" \\d table_name describe table")
print(" \h show this help") print(" \\h show this help")
print(" \q exit this program") print(" \\q exit this program")
print("Or enter a SQL query ending with a semicolon (;).") print("Or enter a SQL query ending with a semicolon (;).")
continue continue
@ -432,6 +485,7 @@ def main():
sql = (sql + ' ' + more_sql).strip() sql = (sql + ' ' + more_sql).strip()
try: try:
readline.write_history_file(readline_history_file)
display_query(url, sql.rstrip(';'), context, args) display_query(url, sql.rstrip(';'), context, args)
except DruidSqlException as e: except DruidSqlException as e:
e.write_to(sys.stdout) e.write_to(sys.stdout)