#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Command-line client for Druid SQL (Python 2).

Posts SQL to the `/druid/v2/sql/` HTTP endpoint, streams the JSON rows back
and prints them as a table, CSV, TSV or JSON lines. Runs either a single
query (`--execute`) or an interactive readline prompt.
"""

from __future__ import print_function

import argparse
import base64
import collections
import csv
import errno
import json
import numbers
import os
import re
import readline
import ssl
import sys
import time
import unicodedata
import urllib2


class DruidSqlException(Exception):
    """Query failure carrying a message that can be shown directly to the user."""

    def friendly_message(self):
        """Return the exception message, or "Query failed" if there is none."""
        return self.message if self.message else "Query failed"

    def write_to(self, f):
        """Write the friendly message to file object `f` in red, then flush."""
        f.write('\x1b[31m')  # ANSI: red foreground
        f.write(self.friendly_message())
        f.write('\x1b[0m')   # ANSI: reset attributes
        f.write('\n')
        f.flush()


def do_query_with_args(url, sql, context, args):
    """Call `do_query` using the connection options stored on `args`."""
    return do_query(url, sql, context, args.timeout, args.user,
                    args.ignore_ssl_verification, args.cafile, args.capath)


def _iter_json_objects(response, json_decoder):
    """Yield the elements of a streamed JSON array read from `response`.

    The body is read in 8 KB chunks and decoded incrementally so rows can be
    emitted before the whole response has arrived.

    Raises:
        ValueError: if the stream ends with unparseable data, or more than
            256 KB is buffered without forming a complete JSON value.
    """
    first_chunk = True
    eof = False
    buf = ''
    while not eof or len(buf) > 0:
        # Drain every complete object currently sitting in the buffer.
        while True:
            # Remove the ',' separating this element from the previous one.
            buf = buf.lstrip(',')
            try:
                obj, sz = json_decoder.raw_decode(buf)
            except ValueError:
                # Maybe invalid JSON, maybe a partial object; the decoder
                # cannot tell us which.
                if eof and buf.rstrip() == ']':
                    # Stream done and all objects read.
                    buf = ''
                    break
                elif eof or len(buf) > 256 * 1024:
                    # Nothing more will arrive, or the buffer is too large to
                    # plausibly be one partial object: report the parse error.
                    raise
                else:
                    # Stop decoding and fetch more data from the stream.
                    break
            yield obj
            buf = buf[sz:]

        if not eof:
            chunk = response.read(8192)
            if chunk:
                buf = buf + chunk
                if first_chunk:
                    # Remove the '[' that opens the result array. Only the
                    # first chunk can contain it.
                    buf = buf.lstrip('[')
                    first_chunk = False
            else:
                # Stream done. Keep decoding what is left in buf though.
                eof = True


def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path):
    """POST `sql` to `url` and lazily yield each result row.

    This is a generator: the HTTP request is only issued once iteration
    starts. Rows are decoded into `collections.OrderedDict` so column order
    is preserved.

    Args:
        url: full URL of the SQL endpoint.
        sql: the SQL statement to run.
        context: dict of query context options sent with the statement.
        timeout: socket timeout in seconds; values <= 0 mean no timeout.
        user: "user:password" for HTTP basic auth, or a false value for none.
        ignore_ssl_verification: if true, skip certificate and hostname checks.
        ca_file: CA file passed to `load_verify_locations`, or None.
        ca_path: CA directory passed to `load_verify_locations`, or None.

    Raises:
        DruidSqlException: for any `urllib2.URLError` (including HTTP errors).
        ValueError: if the response body is not a parseable JSON array.
    """
    json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
    try:
        # Send the caller's context unchanged unless the branch below has to
        # raise its timeout. (Without this default the name was left unbound
        # when the context timeout was already large enough.)
        query_context = context
        if timeout <= 0:
            timeout = None
        elif int(context.get('timeout', 0)) // 1000 < timeout:
            # The context timeout is in milliseconds; bump it (on a copy, so
            # the caller's dict is not mutated) to match the socket timeout.
            query_context = context.copy()
            query_context['timeout'] = timeout * 1000

        sql_json = json.dumps({'query': sql, 'context': query_context})

        # Only build an SSL context when the user asked for non-default
        # behaviour; otherwise urllib2 uses its own defaults.
        ssl_context = None
        if ignore_ssl_verification or ca_file is not None or ca_path is not None:
            ssl_context = ssl.create_default_context()
            if ignore_ssl_verification:
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            else:
                ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)

        req = urllib2.Request(url, sql_json, {'Content-Type': 'application/json'})
        if user:
            req.add_header("Authorization", "Basic %s" % base64.b64encode(user))

        response = urllib2.urlopen(req, None, timeout, context=ssl_context)
        try:
            for obj in _iter_json_objects(response, json_decoder):
                yield obj
        finally:
            # Release the connection whether the rows were fully consumed,
            # parsing failed, or the generator was closed early.
            response.close()
    except urllib2.URLError as e:
        raise_friendly_error(e)


def raise_friendly_error(e):
    """Convert exception `e` into a `DruidSqlException` and raise it.

    For an HTTP 500 whose body is a JSON object containing `errorMessage`,
    the message is assembled from the `error`, `errorClass`, `errorMessage`
    and `host` fields (missing or empty fields are skipped). Other HTTP
    errors report the status code, reason and raw body. Anything else is
    reported as `str(e)`.
    """
    if isinstance(e, urllib2.HTTPError):
        text = e.read().strip()
        error_obj = {}
        try:
            error_obj = dict(json.loads(text))
        except (TypeError, ValueError):
            # Body is not a JSON object; fall back to the raw text below.
            pass

        if e.code == 500 and 'errorMessage' in error_obj:
            # to_utf8 keeps every piece a byte string so that non-ASCII
            # server messages cannot trigger an implicit ASCII conversion.
            parts = []
            error = error_obj.get('error')
            if error and error != 'Unknown exception':
                parts.append(to_utf8(error) + ': ')
            if error_obj.get('errorClass'):
                parts.append(to_utf8(error_obj['errorClass']) + ': ')
            parts.append(to_utf8(error_obj['errorMessage']))
            if error_obj.get('host'):
                parts.append(' (' + to_utf8(error_obj['host']) + ')')
            raise DruidSqlException(''.join(parts))
        else:
            raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
    else:
        raise DruidSqlException(str(e))


def to_utf8(value):
    """Return `value` as a byte string: "" for None, UTF-8 for unicode, else `str(value)`."""
    if value is None:
        return ""
    elif isinstance(value, unicode):
        return value.encode("utf-8")
    else:
        return str(value)


def to_tsv(values, delimiter):
    """Join `values` with `delimiter`, removing the delimiter from inside each value."""
    return delimiter.join(to_utf8(v).replace(delimiter, '') for v in values)


def print_csv(rows, header):
    """Write `rows` (dicts) to stdout as CSV, preceded by column names if `header`."""
    csv_writer = csv.writer(sys.stdout)
    first = True
    for row in rows:
        if first and header:
            csv_writer.writerow([to_utf8(k) for k in row.keys()])
        first = False
        csv_writer.writerow([to_utf8(v) for v in row.values()])


def print_tsv(rows, header, tsv_delimiter):
    """Print `rows` (dicts) delimited by `tsv_delimiter`, with column names first if `header`."""
    first = True
    for row in rows:
        if first and header:
            print(to_tsv(row.keys(), tsv_delimiter))
        first = False
        print(to_tsv(row.values(), tsv_delimiter))


def print_json(rows):
    """Print each row as one JSON document per line."""
    for row in rows:
        print(json.dumps(row))


def table_to_printable_value(value):
    """Return `value` as a trimmed unicode string with control characters removed.

    None is rendered as u"NULL".
    """
    if value is None:
        return u"NULL"
    else:
        # Code points 0-31 map to None in the translate table, i.e. are deleted.
        return to_utf8(value).strip().decode('utf-8').translate(dict.fromkeys(range(32)))


def table_compute_string_width(v):
    """Return the number of terminal columns unicode string `v` occupies.

    The string is NFC-normalized first. Format characters (category 'Cf')
    count as zero columns, East Asian full-width/wide characters as two, and
    everything else as one.
    """
    normalized = unicodedata.normalize('NFC', v)
    width = 0
    for c in normalized:
        ccategory = unicodedata.category(c)
        cwidth = unicodedata.east_asian_width(c)
        if ccategory == 'Cf':
            # Formatting control, zero width
            pass
        elif cwidth == 'F' or cwidth == 'W':
            # Double-wide character, prints in two columns
            width = width + 2
        else:
            # All other characters
            width = width + 1
    return width


def table_compute_column_widths(row_buffer):
    """Return the per-column maximum display width over `row_buffer`.

    Returns None when `row_buffer` is empty.
    """
    widths = None
    for values in row_buffer:
        values_widths = [table_compute_string_width(v) for v in values]
        if not widths:
            widths = values_widths
        else:
            for i, value_width in enumerate(values_widths):
                widths[i] = max(widths[i], value_width)
    return widths


def table_print_row(values, column_widths, column_types):
    """Print one table row, padding each cell to its column width.

    Columns whose entry in `column_types` is 'n' are right-aligned; all
    others (or every column when `column_types` is empty/None) are
    left-aligned.
    """
    vertical_line = u'\u2502'.encode('utf-8')
    for i, value in enumerate(values):
        padding = ' ' * max(0, column_widths[i] - table_compute_string_width(value))
        if column_types and column_types[i] == 'n':
            print(vertical_line + ' ' + padding + value.encode('utf-8') + ' ', end="")
        else:
            print(vertical_line + ' ' + value.encode('utf-8') + padding + ' ', end="")
    print(vertical_line)


def _table_print_rule(left, joint, right, column_widths):
    """Print a horizontal border: `left`, one segment per column separated by `joint`, `right`."""
    horizontal_line = u'\u2500'.encode('utf-8')
    # Each segment spans the column plus one space of padding on either side.
    segments = [horizontal_line * max(0, width + 2) for width in column_widths]
    print(left + joint.join(segments) + right)


def table_print_header(values, column_widths):
    """Print the top border, the header row `values`, and the separator below it."""
    # Line 1: top border
    _table_print_rule(u'\u250C'.encode('utf-8'), u'\u252C'.encode('utf-8'),
                      u'\u2510'.encode('utf-8'), column_widths)
    # Line 2: column names, always left-aligned
    table_print_row(values, column_widths, None)
    # Line 3: separator between header and data
    _table_print_rule(u'\u251C'.encode('utf-8'), u'\u253C'.encode('utf-8'),
                      u'\u2524'.encode('utf-8'), column_widths)


def table_print_bottom(column_widths):
    """Print the bottom border of the table."""
    _table_print_rule(u'\u2514'.encode('utf-8'), u'\u2534'.encode('utf-8'),
                      u'\u2518'.encode('utf-8'), column_widths)


def table_print_row_buffer(row_buffer, column_widths, column_types):
    """Print `row_buffer`, treating its first entry as the header row."""
    first = True
    for values in row_buffer:
        if first:
            table_print_header(values, column_widths)
            first = False
        else:
            table_print_row(values, column_widths, column_types)


def print_table(rows):
    """Print `rows` (dicts) as a box-drawn table followed by a row count and timing.

    Column widths are computed from the header plus the first 500 rows;
    later rows are streamed using those widths. A column is right-aligned
    when its value in the first row is a number.
    """
    start = time.time()
    nrows = 0
    first = True

    # Buffer some rows before printing.
    rows_to_buffer = 500
    row_buffer = []
    column_types = []
    column_widths = None

    for row in rows:
        nrows = nrows + 1

        if first:
            row_buffer.append([table_to_printable_value(k) for k in row.keys()])
            for k in row.keys():
                if isinstance(row[k], numbers.Number):
                    column_types.append('n')
                else:
                    column_types.append('s')
            first = False

        values = [table_to_printable_value(v) for v in row.values()]
        if rows_to_buffer > 0:
            row_buffer.append(values)
            rows_to_buffer = rows_to_buffer - 1
        else:
            if row_buffer:
                # Buffer is full: fix the widths, flush it, then stream.
                column_widths = table_compute_column_widths(row_buffer)
                table_print_row_buffer(row_buffer, column_widths, column_types)
                del row_buffer[:]
            table_print_row(values, column_widths, column_types)

    if row_buffer:
        # Fewer rows than the buffer size: nothing has been printed yet.
        column_widths = table_compute_column_widths(row_buffer)
        table_print_row_buffer(row_buffer, column_widths, column_types)

    if column_widths:
        table_print_bottom(column_widths)

    print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(
        nrows, 's' if nrows != 1 else '', time.time() - start))
    print("")


def display_query(url, sql, context, args):
    """Run `sql` and print the result in the format selected by `args.format`."""
    rows = do_query_with_args(url, sql, context, args)

    if args.format == 'csv':
        print_csv(rows, args.header)
    elif args.format == 'tsv':
        print_tsv(rows, args.header, args.tsv_delimiter)
    elif args.format == 'json':
        print_json(rows)
    elif args.format == 'table':
        print_table(rows)


def sql_literal_escape(s):
    """Return `s` as a SQL unicode string literal of the form U&'...'.

    Letters, numbers and spaces are kept as-is; every other character is
    written as a backslash followed by its code point in hex (at least four
    digits). None becomes the empty literal ''.
    """
    if s is None:
        return "''"
    elif isinstance(s, unicode):
        ustr = s
    else:
        ustr = str(s).decode('utf-8')

    escaped = [u"U&'"]
    for c in ustr:
        ccategory = unicodedata.category(c)
        if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
            escaped.append(c)
        else:
            escaped.append(u'\\')
            escaped.append('%04x' % ord(c))
    escaped.append("'")
    return ''.join(escaped)


def make_readline_completer(url, context, args):
    """Build the tab-completion callback for the interactive prompt.

    Also prints the "Connected to [...]" banner for `args.host`. `url` and
    `context` are accepted but not used.
    """
    starters = [
        'EXPLAIN PLAN FOR',
        'SELECT'
    ]

    middlers = [
        'FROM',
        'WHERE',
        'GROUP BY',
        'ORDER BY',
        'LIMIT'
    ]

    def readline_completer(text, state):
        """Return the `state`-th completion of `text`, or None when exhausted."""
        # Statement-opening keywords are offered only at the start of the line.
        candidates = starters if readline.get_begidx() == 0 else middlers
        prefix = text.upper()
        matches = [x for x in candidates if x.startswith(prefix)]
        if state < len(matches):
            return matches[state] + " "
        # readline expects a non-string value once the matches run out.
        return None

    print("Connected to [" + args.host + "].")
    print("")

    return readline_completer


def _build_arg_parser():
    """Create the argparse parser describing the command-line options."""
    parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
    parser_cnn = parser.add_argument_group('Connection options')
    parser_fmt = parser.add_argument_group('Formatting options')
    parser_oth = parser.add_argument_group('Other options')

    parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/',
                            help='Druid query host or url, like https://localhost:8282/')
    parser_cnn.add_argument('--user', '-u', type=str,
                            help='HTTP basic authentication credentials, like user:password')
    parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds')
    parser_cnn.add_argument('--cafile', type=str,
                            help='Path to SSL CA file for validating server certificates. '
                                 'See load_verify_locations() in '
                                 'https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--capath', type=str,
                            help='SSL CA path for validating server certificates. '
                                 'See load_verify_locations() in '
                                 'https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False,
                            help='Skip verification of SSL certificates.')

    parser_fmt.add_argument('--format', type=str, default='table',
                            choices=('csv', 'tsv', 'json', 'table'), help='Result format')
    parser_fmt.add_argument('--header', action='store_true',
                            help='Include header row for formats "csv" and "tsv"')
    parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t',
                            help='Delimiter for format "tsv"')

    parser_oth.add_argument('--context-option', '-c', type=str, action='append',
                            help='Set context option for this connection, see '
                                 'https://docs.imply.io/on-prem/query-data/sql for options')
    parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
    return parser


def _build_context(context_options):
    """Turn a list of "key=value" strings into a query context dict.

    All-digit values are stored as integers, everything else as strings.

    Raises:
        ValueError: if an option does not contain '='.
    """
    context = {}
    for opt in context_options or []:
        kv = opt.split("=", 1)
        if len(kv) != 2:
            raise ValueError('Invalid context option, should be key=value: ' + opt)
        if re.match(r"^\d+$", kv[1]):
            context[kv[0]] = long(kv[1])
        else:
            context[kv[0]] = kv[1]
    return context


def _describe_command_sql(dmatch):
    """Return the INFORMATION_SCHEMA query for a matched backslash-d command.

    Group 1 of `dmatch` is the optional 'S' flag (include non-'druid'
    schemas) and group 3 the optional table name. Group 2 (an optional '+')
    is accepted by the pattern but has no effect.
    """
    include_system = dmatch.group(1)
    arg = dmatch.group(3).strip()
    if arg:
        sql = ("SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE "
               "FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg))
        if not include_system:
            sql = sql + " AND TABLE_SCHEMA = 'druid'"
    else:
        sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
        if not include_system:
            sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
    return sql


def _print_help():
    """Print the list of backslash commands."""
    print("Commands:")
    print(" \\d show tables")
    print(" \\dS show tables, including system tables")
    print(" \\d table_name describe table")
    print(" \\h show this help")
    print(" \\q exit this program")
    print("Or enter a SQL query ending with a semicolon (;).")


def _read_statement():
    """Prompt until a statement is ready and return its SQL text.

    Input lines are accumulated until the text ends with ';'. A backslash
    command on the first line is handled immediately: the describe command
    returns a generated query, the help command prints help, and the quit
    command exits with status 0. End-of-input exits with status 1.
    """
    sql = ''
    while not sql.endswith(';'):
        prompt = "dsql> " if sql == '' else 'more> '
        try:
            more_sql = raw_input(prompt)
        except EOFError:
            sys.stdout.write('\n')
            sys.exit(1)

        if sql == '' and more_sql.startswith('\\'):
            # backslash command
            dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
            if dmatch:
                # Return right away so the generated query is executed.
                return _describe_command_sql(dmatch)

            hmatch = re.match(r'^\\h\s*$', more_sql)
            if hmatch:
                _print_help()
                continue

            qmatch = re.match(r'^\\q\s*$', more_sql)
            if qmatch:
                sys.exit(0)

            print("No such command: " + more_sql)
        else:
            sql = (sql + ' ' + more_sql).strip()
    return sql


def main():
    """Parse the command line, then run one query or the interactive prompt."""
    args = _build_arg_parser().parse_args()

    # Build broker URL
    url = args.host.rstrip('/') + '/druid/v2/sql/'
    if not url.startswith('http:') and not url.startswith('https:'):
        url = 'http://' + url

    # Build context
    context = _build_context(args.context_option)

    if args.execute:
        display_query(url, args.execute, context, args)
    else:
        # interactive mode
        print("Welcome to dsql, the command-line client for Druid SQL.")

        readline_history_file = os.path.expanduser("~/.dsql_history")
        readline.parse_and_bind('tab: complete')
        readline.set_history_length(500)
        readline.set_completer(make_readline_completer(url, context, args))

        try:
            readline.read_history_file(readline_history_file)
        except IOError:
            # IOError can happen if the file doesn't exist.
            pass

        print("Type \"\\h\" for help.")

        while True:
            sql = _read_statement()
            try:
                # Save history before running the query so the statement is
                # kept even if the query fails or is interrupted.
                readline.write_history_file(readline_history_file)
                display_query(url, sql.rstrip(';'), context, args)
            except DruidSqlException as e:
                e.write_to(sys.stdout)
            except KeyboardInterrupt:
                sys.stdout.write("Query interrupted\n")
                sys.stdout.flush()


# Script entry point. Deliberately left unguarded so the file behaves the
# same however it is launched.
try:
    main()
except DruidSqlException as e:
    e.write_to(sys.stderr)
    sys.exit(1)
except KeyboardInterrupt:
    sys.exit(1)
except IOError as e:
    # A closed output pipe (e.g. `dsql ... | head`) is not worth a traceback.
    if e.errno == errno.EPIPE:
        sys.exit(1)
    else:
        raise