SOLR-8038: Add StatsStream to the Streaming API and wire it into the SQLHandler

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1704973 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Joel Bernstein 2015-09-24 01:49:09 +00:00
parent f959fa2e4b
commit fb881a4f0f
7 changed files with 463 additions and 5 deletions

View File

@ -69,6 +69,8 @@ New Features
* SOLR-7986: JDBC Driver for SQL Interface (Uwe Schindler, Joel Bernstein)
* SOLR-8038: Add the StatsStream to the Streaming API and wire it into the SQLHandler (Joel Bernstein)
Optimizations
----------------------
* SOLR-7876: Speed up queries and operations that use many terms when timeAllowed has not been

View File

@ -22,7 +22,6 @@ import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map.Entry;
import java.util.Set;
import com.facebook.presto.sql.ExpressionFormatter;
@ -39,15 +38,16 @@ import org.apache.solr.client.solrj.io.stream.FacetStream;
import org.apache.solr.client.solrj.io.stream.ParallelStream;
import org.apache.solr.client.solrj.io.stream.RankStream;
import org.apache.solr.client.solrj.io.stream.RollupStream;
import org.apache.solr.client.solrj.io.stream.StatsStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.ExceptionStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.*;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.CoreContainer;
import org.apache.solr.core.SolrCore;
import org.apache.solr.request.SolrQueryRequest;
@ -56,7 +56,6 @@ import org.apache.solr.util.plugin.SolrCoreAware;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import org.apache.solr.client.solrj.io.stream.metrics.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -291,6 +290,12 @@ public class SQLHandler extends RequestHandlerBase implements SolrCoreAware {
private static TupleStream doSelect(SQLVisitor sqlVisitor) throws IOException {
List<String> fields = sqlVisitor.fields;
Set<String> fieldSet = new HashSet();
Metric[] metrics = getMetrics(fields, fieldSet);
if(metrics.length > 0) {
return doAggregates(sqlVisitor, metrics);
}
StringBuilder flbuf = new StringBuilder();
boolean comma = false;
@ -396,6 +401,28 @@ public class SQLHandler extends RequestHandlerBase implements SolrCoreAware {
return true;
}
private static TupleStream doAggregates(SQLVisitor sqlVisitor, Metric[] metrics) throws IOException {
if(metrics.length != sqlVisitor.fields.size()) {
throw new IOException("Only aggregate functions are allowed when group by is not specified.");
}
TableSpec tableSpec = new TableSpec(sqlVisitor.table, defaultZkhost);
String zkHost = tableSpec.zkHost;
String collection = tableSpec.collection;
Map<String, String> params = new HashMap();
params.put(CommonParams.Q, sqlVisitor.query);
TupleStream tupleStream = new StatsStream(zkHost,
collection,
params,
metrics);
return tupleStream;
}
private static String bucketSort(Bucket[] buckets, String dir) {
StringBuilder buf = new StringBuilder();
boolean comma = false;

View File

@ -28,7 +28,14 @@ import com.facebook.presto.sql.tree.Statement;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.ExceptionStream;
import org.apache.solr.client.solrj.io.stream.SolrStream;
import org.apache.solr.client.solrj.io.stream.StatsStream;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric;
import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.io.stream.metrics.MinMetric;
import org.apache.solr.client.solrj.io.stream.metrics.SumMetric;
import org.apache.solr.cloud.AbstractFullDistribZkTestBase;
import org.apache.solr.common.params.CommonParams;
import org.junit.After;
@ -89,6 +96,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
testBasicSelect();
testBasicGrouping();
testBasicGroupingFacets();
testAggregatesWithoutGrouping();
testSQLException();
testTimeSeriesGrouping();
testTimeSeriesGroupingFacet();
@ -809,6 +817,138 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
}
private void testAggregatesWithoutGrouping() throws Exception {
CloudJettyRunner jetty = this.cloudJettys.get(0);
del("*:*");
commit();
indexr(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "1");
indexr(id, "2", "a_s", "hello0", "a_i", "2", "a_f", "2");
indexr(id, "3", "a_s", "hello3", "a_i", "3", "a_f", "3");
indexr(id, "4", "a_s", "hello4", "a_i", "4", "a_f", "4");
indexr(id, "1", "a_s", "hello0", "a_i", "1", "a_f", "5");
indexr(id, "5", "a_s", "hello3", "a_i", "10", "a_f", "6");
indexr(id, "6", "a_s", "hello4", "a_i", "11", "a_f", "7");
indexr(id, "7", "a_s", "hello3", "a_i", "12", "a_f", "8");
indexr(id, "8", "a_s", "hello3", "a_i", "13", "a_f", "9");
indexr(id, "9", "a_s", "hello0", "a_i", "14", "a_f", "10");
commit();
Map params = new HashMap();
params.put(CommonParams.QT, "/sql");
params.put("sql", "select count(*), sum(a_i), min(a_i), max(a_i), avg(a_i), sum(a_f), min(a_f), max(a_f), avg(a_f) from collection1");
SolrStream solrStream = new SolrStream(jetty.url, params);
List<Tuple> tuples = getTuples(solrStream);
assert(tuples.size() == 1);
//Test Long and Double Sums
Tuple tuple = tuples.get(0);
Double sumi = tuple.getDouble("sum(a_i)");
Double sumf = tuple.getDouble("sum(a_f)");
Double mini = tuple.getDouble("min(a_i)");
Double minf = tuple.getDouble("min(a_f)");
Double maxi = tuple.getDouble("max(a_i)");
Double maxf = tuple.getDouble("max(a_f)");
Double avgi = tuple.getDouble("avg(a_i)");
Double avgf = tuple.getDouble("avg(a_f)");
Double count = tuple.getDouble("count(*)");
assertTrue(sumi.longValue() == 70);
assertTrue(sumf.doubleValue() == 55.0D);
assertTrue(mini.doubleValue() == 0.0D);
assertTrue(minf.doubleValue() == 1.0D);
assertTrue(maxi.doubleValue() == 14.0D);
assertTrue(maxf.doubleValue() == 10.0D);
assertTrue(avgi.doubleValue() == 7.0D);
assertTrue(avgf.doubleValue() == 5.5D);
assertTrue(count.doubleValue() == 10);
// Test where clause hits
params = new HashMap();
params.put(CommonParams.QT, "/sql");
params.put("sql", "select count(*), sum(a_i), min(a_i), max(a_i), avg(a_i), sum(a_f), min(a_f), max(a_f), avg(a_f) from collection1 where id = 2");
solrStream = new SolrStream(jetty.url, params);
tuples = getTuples(solrStream);
assert(tuples.size() == 1);
tuple = tuples.get(0);
sumi = tuple.getDouble("sum(a_i)");
sumf = tuple.getDouble("sum(a_f)");
mini = tuple.getDouble("min(a_i)");
minf = tuple.getDouble("min(a_f)");
maxi = tuple.getDouble("max(a_i)");
maxf = tuple.getDouble("max(a_f)");
avgi = tuple.getDouble("avg(a_i)");
avgf = tuple.getDouble("avg(a_f)");
count = tuple.getDouble("count(*)");
assertTrue(sumi.longValue() == 2);
assertTrue(sumf.doubleValue() == 2.0D);
assertTrue(mini == 2);
assertTrue(minf == 2);
assertTrue(maxi == 2);
assertTrue(maxf == 2);
assertTrue(avgi.doubleValue() == 2.0D);
assertTrue(avgf.doubleValue() == 2.0);
assertTrue(count.doubleValue() == 1);
// Test zero hits
params = new HashMap();
params.put(CommonParams.QT, "/sql");
params.put("sql", "select count(*), sum(a_i), min(a_i), max(a_i), avg(a_i), sum(a_f), min(a_f), max(a_f), avg(a_f) from collection1 where a_s = 'blah'");
solrStream = new SolrStream(jetty.url, params);
tuples = getTuples(solrStream);
assert(tuples.size() == 1);
tuple = tuples.get(0);
sumi = tuple.getDouble("sum(a_i)");
sumf = tuple.getDouble("sum(a_f)");
mini = tuple.getDouble("min(a_i)");
minf = tuple.getDouble("min(a_f)");
maxi = tuple.getDouble("max(a_i)");
maxf = tuple.getDouble("max(a_f)");
avgi = tuple.getDouble("avg(a_i)");
avgf = tuple.getDouble("avg(a_f)");
count = tuple.getDouble("count(*)");
assertTrue(sumi.longValue() == 0);
assertTrue(sumf.doubleValue() == 0.0D);
assertTrue(mini == null);
assertTrue(minf == null);
assertTrue(maxi == null);
assertTrue(maxf == null);
assertTrue(Double.isNaN(avgi));
assertTrue(Double.isNaN(avgf));
assertTrue(count.doubleValue() == 0);
del("*:*");
commit();
}
private void testTimeSeriesGrouping() throws Exception {
try {

View File

@ -71,6 +71,11 @@ public class Tuple implements Cloneable {
public Long getLong(Object key) {
Object o = this.fields.get(key);
if(o == null) {
return null;
}
if(o instanceof Long) {
return (Long)o;
} else {
@ -81,6 +86,11 @@ public class Tuple implements Cloneable {
public Double getDouble(Object key) {
Object o = this.fields.get(key);
if(o == null) {
return null;
}
if(o instanceof Double) {
return (Double)o;
} else {

View File

@ -190,7 +190,12 @@ class ResultSetImpl implements ResultSet {
@Override
public long getLong(String columnLabel) throws SQLException {
return tuple.getLong(columnLabel);
Long l = tuple.getLong(columnLabel);
if(l == null) {
return 0;
} else {
return l.longValue();
}
}
@Override
@ -200,7 +205,12 @@ class ResultSetImpl implements ResultSet {
@Override
public double getDouble(String columnLabel) throws SQLException {
return tuple.getDouble(columnLabel);
Double d = tuple.getDouble(columnLabel);
if(d == null) {
return 0.0D;
} else {
return d.doubleValue();
}
}
@Override

View File

@ -0,0 +1,201 @@
package org.apache.solr.client.solrj.io.stream;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.common.SolrDocumentList;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.NamedList;
public class StatsStream extends TupleStream {
private static final long serialVersionUID = 1;
private Metric[] metrics;
private String zkHost;
private Tuple tuple;
private Map<String, String> props;
private String collection;
private boolean done;
private long count;
private boolean doCount;
protected transient SolrClientCache cache;
protected transient CloudSolrClient cloudSolrClient;
public StatsStream(String zkHost,
String collection,
Map<String, String> props,
Metric[] metrics) {
this.zkHost = zkHost;
this.props = props;
this.metrics = metrics;
this.collection = collection;
}
public void setStreamContext(StreamContext context) {
cache = context.getSolrClientCache();
}
public List<TupleStream> children() {
List<TupleStream> l = new ArrayList();
return l;
}
public void open() throws IOException {
if(cache != null) {
cloudSolrClient = cache.getCloudSolrClient(zkHost);
} else {
cloudSolrClient = new CloudSolrClient(zkHost);
}
ModifiableSolrParams params = getParams(this.props);
addStats(params, metrics);
params.add("stats", "true");
params.add("rows", "0");
QueryRequest request = new QueryRequest(params);
try {
NamedList response = cloudSolrClient.request(request, collection);
this.tuple = getTuple(response);
} catch (Exception e) {
throw new IOException(e);
}
}
public void close() throws IOException {
if(cache == null) {
cloudSolrClient.close();
}
}
public Tuple read() throws IOException {
if(!done) {
done = true;
return tuple;
} else {
Map fields = new HashMap();
fields.put("EOF", true);
Tuple tuple = new Tuple(fields);
return tuple;
}
}
public StreamComparator getStreamSort() {
return null;
}
private void addStats(ModifiableSolrParams params, Metric[] _metrics) {
Map<String, List<String>> m = new HashMap();
for(Metric metric : _metrics) {
String metricId = metric.getIdentifier();
if(metricId.contains("(")) {
metricId = metricId.substring(0, metricId.length()-1);
String[] parts = metricId.split("\\(");
String function = parts[0];
String column = parts[1];
List<String> stats = m.get(column);
if(stats == null && !column.equals("*")) {
stats = new ArrayList();
m.put(column, stats);
}
if(function.equals("min")) {
stats.add("min");
} else if(function.equals("max")) {
stats.add("max");
} else if(function.equals("sum")) {
stats.add("sum");
} else if(function.equals("avg")) {
stats.add("mean");
} else if(function.equals("count")) {
this.doCount = true;
}
}
}
for(String field : m.keySet()) {
StringBuilder buf = new StringBuilder();
List<String> stats = m.get(field);
buf.append("{!");
for(String stat : stats) {
buf.append(stat).append("=").append("true ");
}
buf.append("}").append(field);
params.add("stats.field", buf.toString());
}
}
private ModifiableSolrParams getParams(Map<String, String> props) {
ModifiableSolrParams params = new ModifiableSolrParams();
for(String key : props.keySet()) {
String value = props.get(key);
params.add(key, value);
}
return params;
}
private Tuple getTuple(NamedList response) {
Map map = new HashMap();
if(doCount) {
SolrDocumentList solrDocumentList = (SolrDocumentList) response.get("response");
this.count = solrDocumentList.getNumFound();
map.put("count(*)", this.count);
}
NamedList stats = (NamedList)response.get("stats");
NamedList statsFields = (NamedList)stats.get("stats_fields");
for(int i=0; i<statsFields.size(); i++) {
String field = statsFields.getName(i);
NamedList theStats = (NamedList)statsFields.getVal(i);
for(int s=0; s<theStats.size(); s++) {
addStat(map, field, theStats.getName(s), theStats.getVal(s));
}
}
Tuple tuple = new Tuple(map);
return tuple;
}
public int getCost() {
return 0;
}
private void addStat(Map map, String field, String stat, Object val) {
if(stat.equals("mean")) {
map.put("avg("+field+")", val);
} else {
map.put(stat+"("+field+")", val);
}
}
}

View File

@ -575,6 +575,73 @@ public class StreamingTest extends AbstractFullDistribZkTestBase {
assert(t.getException().contains("undefined field:"));
}
private void testStatsStream() throws Exception {
indexr(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "1");
indexr(id, "2", "a_s", "hello0", "a_i", "2", "a_f", "2");
indexr(id, "3", "a_s", "hello3", "a_i", "3", "a_f", "3");
indexr(id, "4", "a_s", "hello4", "a_i", "4", "a_f", "4");
indexr(id, "1", "a_s", "hello0", "a_i", "1", "a_f", "5");
indexr(id, "5", "a_s", "hello3", "a_i", "10", "a_f", "6");
indexr(id, "6", "a_s", "hello4", "a_i", "11", "a_f", "7");
indexr(id, "7", "a_s", "hello3", "a_i", "12", "a_f", "8");
indexr(id, "8", "a_s", "hello3", "a_i", "13", "a_f", "9");
indexr(id, "9", "a_s", "hello0", "a_i", "14", "a_f", "10");
commit();
String zkHost = zkServer.getZkAddress();
Map paramsA = mapParams("q", "*:*");
Metric[] metrics = {new SumMetric("a_i"),
new SumMetric("a_f"),
new MinMetric("a_i"),
new MinMetric("a_f"),
new MaxMetric("a_i"),
new MaxMetric("a_f"),
new MeanMetric("a_i"),
new MeanMetric("a_f"),
new CountMetric()};
StatsStream statsStream = new StatsStream(zkHost,
"collection1",
paramsA,
metrics);
List<Tuple> tuples = getTuples(statsStream);
assert(tuples.size() == 1);
//Test Long and Double Sums
Tuple tuple = tuples.get(0);
Double sumi = tuple.getDouble("sum(a_i)");
Double sumf = tuple.getDouble("sum(a_f)");
Double mini = tuple.getDouble("min(a_i)");
Double minf = tuple.getDouble("min(a_f)");
Double maxi = tuple.getDouble("max(a_i)");
Double maxf = tuple.getDouble("max(a_f)");
Double avgi = tuple.getDouble("avg(a_i)");
Double avgf = tuple.getDouble("avg(a_f)");
Double count = tuple.getDouble("count(*)");
assertTrue(sumi.longValue() == 70);
assertTrue(sumf.doubleValue() == 55.0D);
assertTrue(mini.doubleValue() == 0.0D);
assertTrue(minf.doubleValue() == 1.0D);
assertTrue(maxi.doubleValue() == 14.0D);
assertTrue(maxf.doubleValue() == 10.0D);
assertTrue(avgi.doubleValue() == 7.0D);
assertTrue(avgf.doubleValue() == 5.5D);
assertTrue(count.doubleValue() == 10);
del("*:*");
commit();
}
private void testFacetStream() throws Exception {
indexr(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "1");
@ -1701,6 +1768,7 @@ public class StreamingTest extends AbstractFullDistribZkTestBase {
testZeroReducerStream();
testFacetStream();
testSubFacetStream();
testStatsStream();
//testExceptionStream();
testParallelEOF();
testParallelUniqueStream();