#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function

import argparse
import base64
import collections
import csv
import errno
import json
import numbers
import re
import ssl
import sys
import time
import unicodedata
import urllib2


class DruidSqlException(Exception):
  def write_to(self, f):
    f.write('\x1b[31m')
    f.write(self.message if self.message else "Query failed")
    f.write('\x1b[0m')
    f.write('\n')
    f.flush()


def do_query(url, sql, context, timeout, user, password, ignore_ssl_verification, ca_file, ca_path):
  json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)

  try:
    sql_json = json.dumps({'query' : sql, 'context' : context})

    # SSL stuff
    ssl_context = None
    if (ignore_ssl_verification or ca_file != None or ca_path != None):
      ssl_context = ssl.create_default_context()
      if (ignore_ssl_verification):
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
      else:
        ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)

    req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'})

    if timeout <= 0:
      timeout = None

    if (user and password):
      basicAuthEncoding = base64.b64encode('%s:%s' % (user, password))
      req.add_header("Authorization", "Basic %s" % basicAuthEncoding)

    response = urllib2.urlopen(req, None, timeout, context=ssl_context)

    first_chunk = True
    eof = False
    buf = ''

    while not eof or len(buf) > 0:
      while True:
        try:
          # Remove starting ','
          buf = buf.lstrip(',')
          obj, sz = json_decoder.raw_decode(buf)
          yield obj
          buf = buf[sz:]
        except ValueError as e:
          # Maybe invalid JSON, maybe partial object; it's hard to tell with this library.
          if eof and buf.rstrip() == ']':
            # Stream done and all objects read.
            buf = ''
            break
          elif eof or len(buf) > 256 * 1024:
            # If we read more than 256KB or if it's eof then report the parse error.
            raise
          else:
            # Stop reading objects, get more from the stream instead.
            break

      # Read more from the http stream
      if not eof:
        chunk = response.read(8192)
        if chunk:
          buf = buf + chunk
          if first_chunk:
            # Remove starting '['
            buf = buf.lstrip('[')
        else:
          # Stream done. Keep reading objects out of buf though.
          eof = True
  except urllib2.URLError as e:
    raise_friendly_error(e)


def raise_friendly_error(e):
  if isinstance(e, urllib2.HTTPError):
    text = e.read().strip()
    error_obj = {}
    try:
      error_obj = dict(json.loads(text))
    except:
      pass
    if e.code == 500 and 'errorMessage' in error_obj:
      error_text = ''
      if error_obj['error'] != 'Unknown exception':
        error_text = error_text + error_obj['error'] + ': '
      if error_obj['errorClass']:
        error_text = error_text + str(error_obj['errorClass']) + ': '
      error_text = error_text + str(error_obj['errorMessage'])
      if error_obj['host']:
        error_text = error_text + ' (' + str(error_obj['host']) + ')'
      raise DruidSqlException(error_text)
    else:
      raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
  else:
    raise DruidSqlException(str(e))


def to_utf8(value):
  if value is None:
    return ""
  elif isinstance(value, unicode):
    return value.encode("utf-8")
  else:
    return str(value)


def to_tsv(values, delimiter):
  return delimiter.join(to_utf8(v).replace(delimiter, '') for v in values)


def print_csv(rows, header):
  csv_writer = csv.writer(sys.stdout)
  first = True
  for row in rows:
    if first and header:
      csv_writer.writerow(list(to_utf8(k) for k in row.keys()))
      first = False

    values = []
    for key, value in row.iteritems():
      values.append(to_utf8(value))
    csv_writer.writerow(values)


def print_tsv(rows, header, tsv_delimiter):
  first = True
  for row in rows:
    if first and header:
      print(to_tsv(row.keys(), tsv_delimiter))
      first = False

    values = []
    for key, value in row.iteritems():
      values.append(value)
    print(to_tsv(values, tsv_delimiter))


def print_json(rows):
  for row in rows:
    print(json.dumps(row))


def table_to_printable_value(value):
  # Unicode string, trimmed with control characters removed
  if value is None:
    return u"NULL"
  else:
    return to_utf8(value).strip().decode('utf-8').translate(dict.fromkeys(range(32)))


def table_compute_string_width(v):
  normalized = unicodedata.normalize('NFC', v)
  width = 0
  for c in normalized:
    ccategory = unicodedata.category(c)
    cwidth = unicodedata.east_asian_width(c)
    if ccategory == 'Cf':
      # Formatting control, zero width
      pass
    elif cwidth == 'F' or cwidth == 'W':
      # Double-wide character, prints in two columns
      width = width + 2
    else:
      # All other characters
      width = width + 1
  return width


def table_compute_column_widths(row_buffer):
  widths = None
  for values in row_buffer:
    values_widths = [table_compute_string_width(v) for v in values]
    if not widths:
      widths = values_widths
    else:
      i = 0
      for v in values:
        widths[i] = max(widths[i], values_widths[i])
        i = i + 1
  return widths


def table_print_row(values, column_widths, column_types):
  vertical_line = u'\u2502'.encode('utf-8')
  for i in xrange(0, len(values)):
    padding = ' ' * max(0, column_widths[i] - table_compute_string_width(values[i]))
    if column_types and column_types[i] == 'n':
      print(vertical_line + ' ' + padding + values[i].encode('utf-8') + ' ', end="")
    else:
      print(vertical_line + ' ' + values[i].encode('utf-8') + padding + ' ', end="")
  print(vertical_line)


def table_print_header(values, column_widths):
  # Line 1
  left_corner = u'\u250C'.encode('utf-8')
  horizontal_line = u'\u2500'.encode('utf-8')
  top_tee = u'\u252C'.encode('utf-8')
  right_corner = u'\u2510'.encode('utf-8')
  print(left_corner, end="")
  for i in xrange(0, len(column_widths)):
    print(horizontal_line * max(0, column_widths[i] + 2), end="")
    if i + 1 < len(column_widths):
      print(top_tee, end="")
  print(right_corner)

  # Line 2
  table_print_row(values, column_widths, None)

  # Line 3
  left_tee = u'\u251C'.encode('utf-8')
  cross = u'\u253C'.encode('utf-8')
  right_tee = u'\u2524'.encode('utf-8')
  print(left_tee, end="")
  for i in xrange(0, len(column_widths)):
    print(horizontal_line * max(0, column_widths[i] + 2), end="")
    if i + 1 < len(column_widths):
      print(cross, end="")
  print(right_tee)


def table_print_bottom(column_widths):
  left_corner = u'\u2514'.encode('utf-8')
  right_corner = u'\u2518'.encode('utf-8')
  bottom_tee = u'\u2534'.encode('utf-8')
  horizontal_line = u'\u2500'.encode('utf-8')
  print(left_corner, end="")
  for i in xrange(0, len(column_widths)):
    print(horizontal_line * max(0, column_widths[i] + 2), end="")
    if i + 1 < len(column_widths):
      print(bottom_tee, end="")
  print(right_corner)


def table_print_row_buffer(row_buffer, column_widths, column_types):
  first = True
  for values in row_buffer:
    if first:
      table_print_header(values, column_widths)
      first = False
    else:
      table_print_row(values, column_widths, column_types)


def print_table(rows):
  start = time.time()
  nrows = 0
  first = True

  # Buffer some rows before printing.
  rows_to_buffer = 500
  row_buffer = []
  column_types = []
  column_widths = None

  for row in rows:
    nrows = nrows + 1

    if first:
      row_buffer.append([table_to_printable_value(k) for k in row.keys()])
      for k in row.keys():
        if isinstance(row[k], numbers.Number):
          column_types.append('n')
        else:
          column_types.append('s')
      first = False

    values = [table_to_printable_value(v) for k, v in row.iteritems()]
    if rows_to_buffer > 0:
      row_buffer.append(values)
      rows_to_buffer = rows_to_buffer - 1
    else:
      if row_buffer:
        column_widths = table_compute_column_widths(row_buffer)
        table_print_row_buffer(row_buffer, column_widths, column_types)
        del row_buffer[:]
      table_print_row(values, column_widths, column_types)

  if row_buffer:
    column_widths = table_compute_column_widths(row_buffer)
    table_print_row_buffer(row_buffer, column_widths, column_types)

  if column_widths:
    table_print_bottom(column_widths)

  print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(nrows, 's' if nrows != 1 else '', time.time() - start))
  print("")


def display_query(url, sql, context, args):
  rows = do_query(url, sql, context, args.timeout, args.user, args.password, args.ignore_ssl_verification, args.cafile, args.capath)

  if args.format == 'csv':
    print_csv(rows, args.header)
  elif args.format == 'tsv':
    print_tsv(rows, args.header, args.tsv_delimiter)
  elif args.format == 'json':
    print_json(rows)
  elif args.format == 'table':
    print_table(rows)


def sql_escape(s):
  if s is None:
    return "''"
  elif isinstance(s, unicode):
    ustr = s
  else:
    ustr = str(s).decode('utf-8')

  escaped = [u"U&'"]
  for c in ustr:
    ccategory = unicodedata.category(c)
    if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
      escaped.append(c)
    else:
      escaped.append(u'\\')
      escaped.append('%04x' % ord(c))
  escaped.append("'")
  return ''.join(escaped)


def main():
  parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
  parser.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Broker host or url')
  parser.add_argument('--timeout', type=int, default=0, help='Timeout in seconds, 0 for no timeout')
  parser.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
  parser.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
  parser.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
  parser.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection')
  parser.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
  parser.add_argument('--user', '-u', type=str, help='Username for HTTP basic auth')
  parser.add_argument('--password', '-p', type=str, help='Password for HTTP basic auth')
  parser.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
  parser.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  parser.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
  args = parser.parse_args()

  # Build broker URL
  url = args.host.rstrip('/') + '/druid/v2/sql/'
  if not url.startswith('http:') and not url.startswith('https:'):
    url = 'http://' + url

  # Build context
  context = {}
  if args.context_option:
    for opt in args.context_option:
      kv = opt.split("=", 1)
      if len(kv) != 2:
        raise ValueError('Invalid context option, should be key=value: ' + opt)
      if re.match(r"^\d+$", kv[1]):
        context[kv[0]] = long(kv[1])
      else:
        context[kv[0]] = kv[1]

  if args.execute:
    display_query(url, args.execute, context, args)
  else:
    # interactive mode
    print("Welcome to dsql, the command-line client for Druid SQL.")
    print("Type \"\h\" for help.")
    while True:
      sql = ''
      while not sql.endswith(';'):
        prompt = "dsql> " if sql == '' else 'more> '
        try:
          more_sql = raw_input(prompt)
        except EOFError:
          sys.stdout.write('\n')
          sys.exit(1)

        if sql == '' and more_sql.startswith('\\'):
          # backslash command
          dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
          if dmatch:
            include_system = dmatch.group(1)
            extra_info = dmatch.group(2)
            arg = dmatch.group(3).strip()
            if arg:
              sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_escape(arg)
              if not include_system:
                sql = sql + " AND TABLE_SCHEMA = 'druid'"
              # break to execute sql
              break
            else:
              sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
              if not include_system:
                sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
              # break to execute sql
              break

          hmatch = re.match(r'^\\h\s*$', more_sql)
          if hmatch:
            print("Commands:")
            print("  \d            show tables")
            print("  \dS           show tables, including system tables")
            print("  \d table_name describe table")
            print("  \h            show this help")
            print("  \q            exit this program")
            print("Or enter a SQL query ending with a semicolon (;).")
            continue

          qmatch = re.match(r'^\\q\s*$', more_sql)
          if qmatch:
            sys.exit(0)

          print("No such command: " + more_sql)
        else:
          sql = (sql + ' ' + more_sql).strip()

      try:
        display_query(url, sql.rstrip(';'), context, args)
      except DruidSqlException as e:
        e.write_to(sys.stdout)
      except KeyboardInterrupt:
        sys.stdout.write("Query interrupted\n")
        sys.stdout.flush()


try:
  main()
except DruidSqlException as e:
  e.write_to(sys.stderr)
  sys.exit(1)
except KeyboardInterrupt:
  sys.exit(1)
except IOError as e:
  if e.errno == errno.EPIPE:
    sys.exit(1)
  else:
    raise