SOLR-10436: Add hashRollup Streaming Expression

This commit is contained in:
Joel Bernstein 2019-02-27 21:53:43 -05:00
parent 881c9c66e2
commit b516d67c93
8 changed files with 575 additions and 8 deletions

View File

@ -93,6 +93,7 @@ public class Lang {
.withFunctionName("sql", SqlStream.class)
.withFunctionName("plist", ParallelListStream.class)
.withFunctionName("zplot", ZplotStream.class)
.withFunctionName("hashRollup", HashRollupStream.class)
// metrics

View File

@ -0,0 +1,256 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.stream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Iterator;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.HashKey;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.eq.FieldEqualitor;
import org.apache.solr.client.solrj.io.eq.MultipleFieldEqualitor;
import org.apache.solr.client.solrj.io.eq.StreamEqualitor;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.Bucket;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
/**
 * Stream decorator that rolls up the tuples of an underlying stream by hashing on one or more
 * bucket fields and computing a set of metrics for every distinct bucket combination.
 *
 * <p>The whole underlying stream is consumed on the first call to {@link #read()}; one
 * {@link Metric} array is kept in memory per distinct bucket key. After the underlying stream
 * reports EOF, one tuple per bucket key is emitted (in the iteration order of the internal
 * {@link HashMap}), followed by the EOF tuple that was read from the underlying stream.
 *
 * <p>Expression form, as parsed by {@link #HashRollupStream(StreamExpression, StreamFactory)}:
 * exactly one stream operand, a named {@code over} parameter listing the bucket fields, and any
 * number of metric operands.
 */
public class HashRollupStream extends TupleStream implements Expressible {

  private static final long serialVersionUID = 1;

  private PushBackStream tupleStream;
  private Bucket[] buckets;
  private Metric[] metrics;

  /** Rolled-up output; {@code null} until the first {@link #read()} and again after {@link #close()}. */
  private Iterator<Tuple> tupleIterator;

  /**
   * Creates a roll-up over an already constructed stream.
   *
   * @param tupleStream the stream whose tuples are aggregated
   * @param buckets the fields whose combined values identify a group
   * @param metrics prototype metrics; a fresh instance of each is created per group
   */
  public HashRollupStream(TupleStream tupleStream,
                          Bucket[] buckets,
                          Metric[] metrics) {
    init(tupleStream, buckets, metrics);
  }

  /**
   * Creates a roll-up from a parsed streaming expression.
   *
   * @param expression the {@code hashRollup(...)} expression
   * @param factory factory used to build the nested stream, metrics and equalitors
   * @throws IOException if the expression has unknown operands, does not contain exactly one
   *     stream, or lacks a plain-valued {@code over} parameter
   */
  public HashRollupStream(StreamExpression expression, StreamFactory factory) throws IOException {
    // grab all parameters out
    List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
    List<StreamExpression> metricExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, Metric.class);
    StreamExpressionNamedParameter overExpression = factory.getNamedOperand(expression, "over");

    // validate expression contains only what we want: streams + metrics + the single 'over' parameter
    if (expression.getParameters().size() != streamExpressions.size() + metricExpressions.size() + 1) {
      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression));
    }

    if (1 != streamExpressions.size()) {
      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting a single stream but found %d",expression, streamExpressions.size()));
    }

    if (null == overExpression || !(overExpression.getParameter() instanceof StreamExpressionValue)) {
      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting single 'over' parameter listing fields to rollup by but didn't find one",expression));
    }

    // Construct the metrics
    Metric[] metrics = new Metric[metricExpressions.size()];
    for (int idx = 0; idx < metricExpressions.size(); ++idx) {
      metrics[idx] = factory.constructMetric(metricExpressions.get(idx));
    }

    // Construct the buckets.
    // The 'over' value is parsed with the equalitor machinery purely as a helper for splitting it
    // into field names; each resulting field equalitor contributes one bucket.
    StreamEqualitor streamEqualitor = factory.constructEqualitor(((StreamExpressionValue)overExpression.getParameter()).getValue(), FieldEqualitor.class);
    List<FieldEqualitor> flattenedEqualitors = flattenEqualitor(streamEqualitor);
    Bucket[] buckets = new Bucket[flattenedEqualitors.size()];
    for (int idx = 0; idx < flattenedEqualitors.size(); ++idx) {
      // Equalitors of the form a=b are not supported here, only single field names,
      // so only the left field name is used.
      buckets[idx] = new Bucket(flattenedEqualitors.get(idx).getLeftFieldName());
    }

    init(factory.constructStream(streamExpressions.get(0)), buckets, metrics);
  }

  /**
   * Flattens a possibly nested equalitor into the list of field equalitors it contains.
   * Equalitors that are neither a {@link FieldEqualitor} nor a {@link MultipleFieldEqualitor}
   * contribute nothing.
   */
  private static List<FieldEqualitor> flattenEqualitor(StreamEqualitor equalitor) {
    List<FieldEqualitor> flattenedList = new ArrayList<>();

    if (equalitor instanceof FieldEqualitor) {
      flattenedList.add((FieldEqualitor) equalitor);
    } else if (equalitor instanceof MultipleFieldEqualitor) {
      MultipleFieldEqualitor mEqualitor = (MultipleFieldEqualitor) equalitor;
      for (StreamEqualitor subEqualitor : mEqualitor.getEqs()) {
        flattenedList.addAll(flattenEqualitor(subEqualitor));
      }
    }

    return flattenedList;
  }

  /** Shared constructor tail: wraps the source in a {@link PushBackStream} and stores the configuration. */
  private void init(TupleStream tupleStream, Bucket[] buckets, Metric[] metrics) {
    this.tupleStream = new PushBackStream(tupleStream);
    this.buckets = buckets;
    this.metrics = metrics;
  }

  @Override
  public StreamExpression toExpression(StreamFactory factory) throws IOException {
    return toExpression(factory, true);
  }

  /**
   * Builds the expression form of this stream.
   *
   * @param includeStreams when {@code false} the nested stream is replaced by the literal
   *     placeholder {@code <stream>} (used for explanations)
   */
  private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
    // function name
    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));

    // stream
    if (includeStreams) {
      expression.addParameter(tupleStream.toExpression(factory));
    } else {
      expression.addParameter("<stream>");
    }

    // over: comma separated list of the bucket fields
    StringBuilder overBuilder = new StringBuilder();
    for (Bucket bucket : buckets) {
      if (overBuilder.length() > 0) {
        overBuilder.append(",");
      }
      overBuilder.append(bucket.toString());
    }
    expression.addParameter(new StreamExpressionNamedParameter("over", overBuilder.toString()));

    // metrics
    for (Metric metric : metrics) {
      expression.addParameter(metric.toExpression(factory));
    }

    return expression;
  }

  @Override
  public Explanation toExplanation(StreamFactory factory) throws IOException {
    Explanation explanation = new StreamExplanation(getStreamNodeId().toString())
        .withChildren(new Explanation[]{
            tupleStream.toExplanation(factory)
        })
        .withFunctionName(factory.getFunctionName(this.getClass()))
        .withImplementingClass(this.getClass().getName())
        .withExpressionType(ExpressionType.STREAM_DECORATOR)
        .withExpression(toExpression(factory, false).toString());

    for (Metric metric : metrics) {
      explanation.withHelper(metric.toExplanation(factory));
    }

    return explanation;
  }

  @Override
  public void setStreamContext(StreamContext context) {
    this.tupleStream.setStreamContext(context);
  }

  @Override
  public List<TupleStream> children() {
    List<TupleStream> l = new ArrayList<>();
    l.add(tupleStream);
    return l;
  }

  @Override
  public void open() throws IOException {
    tupleStream.open();
  }

  /** Closes the underlying stream and discards any rolled-up results. */
  @Override
  public void close() throws IOException {
    tupleStream.close();
    tupleIterator = null;
  }

  /**
   * Returns the next rolled-up tuple. The first call drains the underlying stream and builds all
   * result tuples; the last tuple returned is the EOF tuple of the underlying stream.
   *
   * <p>Calling this again after the EOF tuple has been returned (without an intervening
   * {@link #close()}) fails with the {@link java.util.NoSuchElementException} thrown by the
   * exhausted iterator.
   */
  @Override
  public Tuple read() throws IOException {
    // On the first call to read build the tupleIterator.
    if (tupleIterator == null) {
      tupleIterator = rollup().iterator();
    }
    return tupleIterator.next();
  }

  /**
   * Drains the underlying stream, aggregating every tuple into its bucket, and returns the
   * finished tuples followed by the EOF tuple.
   */
  private List<Tuple> rollup() throws IOException {
    Map<HashKey, Metric[]> metricMap = new HashMap<>();

    Tuple tuple = tupleStream.read();
    while (!tuple.EOF) {
      accumulate(metricMap, tuple);
      tuple = tupleStream.read();
    }

    List<Tuple> tuples = new ArrayList<>(metricMap.size() + 1);
    for (Map.Entry<HashKey, Metric[]> entry : metricMap.entrySet()) {
      tuples.add(toTuple(entry.getKey(), entry.getValue()));
    }
    // Pass the source's EOF tuple through as the terminator of this stream.
    tuples.add(tuple);
    return tuples;
  }

  /** Updates the metrics of the bucket the tuple belongs to, creating the bucket on first sight. */
  private void accumulate(Map<HashKey, Metric[]> metricMap, Tuple tuple) {
    Object[] bucketValues = new Object[buckets.length];
    for (int i = 0; i < buckets.length; i++) {
      bucketValues[i] = buckets[i].getBucketValue(tuple);
    }

    HashKey hashKey = new HashKey(bucketValues);
    Metric[] currentMetrics = metricMap.get(hashKey);
    if (currentMetrics == null) {
      // First tuple for this key: each bucket gets its own metric instances.
      currentMetrics = new Metric[metrics.length];
      for (int i = 0; i < metrics.length; i++) {
        currentMetrics[i] = metrics[i].newInstance();
      }
      metricMap.put(hashKey, currentMetrics);
    }

    for (Metric bucketMetric : currentMetrics) {
      bucketMetric.update(tuple);
    }
  }

  /**
   * Builds the output tuple for one bucket: metric values keyed by metric identifier, then the
   * bucket values keyed by bucket name (bucket entries are written last, so they win on a name clash).
   */
  private Tuple toTuple(HashKey hashKey, Metric[] finishedMetrics) {
    Map<String, Object> map = new HashMap<>();
    for (Metric metric : finishedMetrics) {
      map.put(metric.getIdentifier(), metric.getValue());
    }

    Object[] parts = hashKey.getParts();
    for (int i = 0; i < buckets.length; i++) {
      map.put(buckets[i].toString(), parts[i]);
    }

    return new Tuple(map);
  }

  @Override
  public int getCost() {
    return 0;
  }

  // NOTE(review): result tuples are emitted in HashMap iteration order, so the sort of the
  // underlying stream presumably does not describe the order of this stream's output - confirm
  // whether callers rely on this value.
  @Override
  public StreamComparator getStreamSort() {
    return tupleStream.getStreamSort();
  }
}

View File

@ -70,10 +70,22 @@ public class MaxMetric extends Metric {
public void update(Tuple tuple) {
Object o = tuple.get(columnName);
if(o instanceof Double) {
double d = (double)o;
if(d > doubleMax) {
double d = (double) o;
if (d > doubleMax) {
doubleMax = d;
}
}else if(o instanceof Float) {
Float f = (Float) o;
double d = f.doubleValue();
if (d > doubleMax) {
doubleMax = d;
}
} else if(o instanceof Integer) {
Integer i = (Integer)o;
long l = i.longValue();
if(l > longMax) {
longMax = l;
}
} else {
long l = (long)o;
if(l > longMax) {

View File

@ -75,10 +75,16 @@ public class MeanMetric extends Metric {
++count;
Object o = tuple.get(columnName);
if(o instanceof Double) {
Double d = (Double)tuple.get(columnName);
Double d = (Double) o;
doubleSum += d;
} else if(o instanceof Float) {
Float f = (Float) o;
doubleSum += f.doubleValue();
} else if(o instanceof Integer) {
Integer i = (Integer)o;
longSum += i.longValue();
} else {
Long l = (Long)tuple.get(columnName);
Long l = (Long)o;
longSum += l;
}
}

View File

@ -71,10 +71,22 @@ public class MinMetric extends Metric {
public void update(Tuple tuple) {
Object o = tuple.get(columnName);
if(o instanceof Double) {
double d = (double)o;
if(d < doubleMin) {
double d = (double) o;
if (d < doubleMin) {
doubleMin = d;
}
} else if(o instanceof Float) {
Float f = (Float) o;
double d = f.doubleValue();
if (d < doubleMin) {
doubleMin = d;
}
} else if(o instanceof Integer) {
Integer i = (Integer)o;
long l = i.longValue();
if(l < longMin) {
longMin = l;
}
} else {
long l = (long)o;
if(l < longMin) {

View File

@ -62,8 +62,14 @@ public class SumMetric extends Metric {
public void update(Tuple tuple) {
Object o = tuple.get(columnName);
if(o instanceof Double) {
Double d = (Double)o;
Double d = (Double) o;
doubleSum += d;
} else if(o instanceof Float) {
Float f = (Float) o;
doubleSum += f.doubleValue();
} else if(o instanceof Integer) {
Integer i = (Integer)o;
longSum += i.longValue();
} else {
Long l = (Long)o;
longSum += l;

View File

@ -74,7 +74,7 @@ public class TestLang extends LuceneTestCase {
"convexHull", "getVertices", "getBaryCenter", "getArea", "getBoundarySize","oscillate",
"getAmplitude", "getPhase", "getAngularFrequency", "enclosingDisk", "getCenter", "getRadius",
"getSupportPoints", "pairSort", "log10", "plist", "recip", "pivot", "ltrim", "rtrim", "export",
"zplot", "natural", "repeat", "movingMAD"};
"zplot", "natural", "repeat", "movingMAD", "hashRollup"};
@Test
public void testLang() {

View File

@ -1247,6 +1247,141 @@ public class StreamDecoratorTest extends SolrCloudTestCase {
}
}
/**
 * Indexes ten documents spread over three {@code a_s} values, rolls them up with
 * {@code hashRollup} over {@code a_s}, sorts the result by {@code avg(a_f)} and verifies every
 * metric of every bucket.
 */
@Test
public void testHashRollupStream() throws Exception {

  new UpdateRequest()
      .add(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "1")
      .add(id, "2", "a_s", "hello0", "a_i", "2", "a_f", "2")
      .add(id, "3", "a_s", "hello3", "a_i", "3", "a_f", "3")
      .add(id, "4", "a_s", "hello4", "a_i", "4", "a_f", "4")
      .add(id, "1", "a_s", "hello0", "a_i", "1", "a_f", "5")
      .add(id, "5", "a_s", "hello3", "a_i", "10", "a_f", "6")
      .add(id, "6", "a_s", "hello4", "a_i", "11", "a_f", "7")
      .add(id, "7", "a_s", "hello3", "a_i", "12", "a_f", "8")
      .add(id, "8", "a_s", "hello3", "a_i", "13", "a_f", "9")
      .add(id, "9", "a_s", "hello0", "a_i", "14", "a_f", "10")
      .commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  StreamFactory factory = new StreamFactory()
      .withCollectionZkHost(COLLECTIONORALIAS, cluster.getZkServer().getZkAddress())
      .withFunctionName("search", CloudSolrStream.class)
      .withFunctionName("hashRollup", HashRollupStream.class)
      .withFunctionName("sum", SumMetric.class)
      .withFunctionName("min", MinMetric.class)
      .withFunctionName("max", MaxMetric.class)
      .withFunctionName("avg", MeanMetric.class)
      .withFunctionName("count", CountMetric.class)
      .withFunctionName("sort", SortStream.class);

  StreamContext streamContext = new StreamContext();
  SolrClientCache solrClientCache = new SolrClientCache();
  streamContext.setSolrClientCache(solrClientCache);

  try {
    // The outer sort makes the order deterministic; hashRollup itself gives no ordering.
    StreamExpression expression = StreamExpressionParser.parse("sort(hashRollup("
        + "search(" + COLLECTIONORALIAS + ", q=*:*, fl=\"a_s,a_i,a_f\", sort=\"a_s asc\"),"
        + "over=\"a_s\","
        + "sum(a_i),"
        + "sum(a_f),"
        + "min(a_i),"
        + "min(a_f),"
        + "max(a_i),"
        + "max(a_f),"
        + "avg(a_i),"
        + "avg(a_f),"
        + "count(*)"
        + "), by=\"avg(a_f) asc\")");
    TupleStream stream = factory.constructStream(expression);
    stream.setStreamContext(streamContext);
    List<Tuple> tuples = getTuples(stream);

    // Expected rows, ordered by avg(a_f) ascending; columns follow metricNames.
    String[] metricNames = {
        "sum(a_i)", "sum(a_f)", "min(a_i)", "min(a_f)", "max(a_i)", "max(a_f)",
        "avg(a_i)", "avg(a_f)", "count(*)"};
    String[] expectedBuckets = {"hello0", "hello4", "hello3"};
    double[][] expectedValues = {
        {17.0D, 18.0D, 0.0D, 1.0D, 14.0D, 10.0D, 4.25D, 4.5D, 4.0D},
        {15.0D, 11.0D, 4.0D, 4.0D, 11.0D, 7.0D, 7.5D, 5.5D, 2.0D},
        {38.0D, 26.0D, 3.0D, 3.0D, 13.0D, 9.0D, 9.5D, 6.5D, 4.0D}};

    assertEquals(expectedBuckets.length, tuples.size());

    for (int row = 0; row < expectedBuckets.length; row++) {
      Tuple tuple = tuples.get(row);
      assertEquals("bucket of tuple " + row, expectedBuckets[row], tuple.getString("a_s"));
      for (int col = 0; col < metricNames.length; col++) {
        String label = metricNames[col] + " of bucket " + expectedBuckets[row];
        Double actual = tuple.getDouble(metricNames[col]);
        assertNotNull(label, actual);
        // Delta 0.0 keeps the exact comparison the test has always made.
        assertEquals(label, expectedValues[row][col], actual.doubleValue(), 0.0D);
      }
    }
  } finally {
    solrClientCache.close();
  }
}
@Test
public void testParallelUniqueStream() throws Exception {
@ -1706,6 +1841,145 @@ public class StreamDecoratorTest extends SolrCloudTestCase {
}
}
/**
 * Same scenario as the non-parallel hashRollup test, but the roll-up runs inside
 * {@code parallel(...)} with two workers, partitioned on {@code a_s}, and the merged result is
 * sorted by {@code avg(a_f)} before being verified.
 */
@Test
public void testParallelHashRollupStream() throws Exception {

  new UpdateRequest()
      .add(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "1")
      .add(id, "2", "a_s", "hello0", "a_i", "2", "a_f", "2")
      .add(id, "3", "a_s", "hello3", "a_i", "3", "a_f", "3")
      .add(id, "4", "a_s", "hello4", "a_i", "4", "a_f", "4")
      .add(id, "1", "a_s", "hello0", "a_i", "1", "a_f", "5")
      .add(id, "5", "a_s", "hello3", "a_i", "10", "a_f", "6")
      .add(id, "6", "a_s", "hello4", "a_i", "11", "a_f", "7")
      .add(id, "7", "a_s", "hello3", "a_i", "12", "a_f", "8")
      .add(id, "8", "a_s", "hello3", "a_i", "13", "a_f", "9")
      .add(id, "9", "a_s", "hello0", "a_i", "14", "a_f", "10")
      .commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  StreamFactory factory = new StreamFactory()
      .withCollectionZkHost(COLLECTIONORALIAS, cluster.getZkServer().getZkAddress())
      .withFunctionName("search", CloudSolrStream.class)
      .withFunctionName("parallel", ParallelStream.class)
      .withFunctionName("hashRollup", HashRollupStream.class)
      .withFunctionName("sum", SumMetric.class)
      .withFunctionName("min", MinMetric.class)
      .withFunctionName("max", MaxMetric.class)
      .withFunctionName("avg", MeanMetric.class)
      .withFunctionName("count", CountMetric.class)
      .withFunctionName("sort", SortStream.class);

  StreamContext streamContext = new StreamContext();
  SolrClientCache solrClientCache = new SolrClientCache();
  streamContext.setSolrClientCache(solrClientCache);

  try {
    // The outer sort makes the order deterministic regardless of how workers interleave.
    StreamExpression expression = StreamExpressionParser.parse("sort(parallel(" + COLLECTIONORALIAS + ","
        + "hashRollup("
        + "search(" + COLLECTIONORALIAS + ", q=*:*, fl=\"a_s,a_i,a_f\", sort=\"a_s asc\", partitionKeys=\"a_s\", qt=\"/export\"),"
        + "over=\"a_s\","
        + "sum(a_i),"
        + "sum(a_f),"
        + "min(a_i),"
        + "min(a_f),"
        + "max(a_i),"
        + "max(a_f),"
        + "avg(a_i),"
        + "avg(a_f),"
        + "count(*)"
        + "),"
        + "workers=\"2\", zkHost=\"" + cluster.getZkServer().getZkAddress() + "\", sort=\"a_s asc\"), by=\"avg(a_f) asc\")"
    );

    TupleStream stream = factory.constructStream(expression);
    stream.setStreamContext(streamContext);
    List<Tuple> tuples = getTuples(stream);

    // Expected rows, ordered by avg(a_f) ascending; columns follow metricNames.
    String[] metricNames = {
        "sum(a_i)", "sum(a_f)", "min(a_i)", "min(a_f)", "max(a_i)", "max(a_f)",
        "avg(a_i)", "avg(a_f)", "count(*)"};
    String[] expectedBuckets = {"hello0", "hello4", "hello3"};
    double[][] expectedValues = {
        {17.0D, 18.0D, 0.0D, 1.0D, 14.0D, 10.0D, 4.25D, 4.5D, 4.0D},
        {15.0D, 11.0D, 4.0D, 4.0D, 11.0D, 7.0D, 7.5D, 5.5D, 2.0D},
        {38.0D, 26.0D, 3.0D, 3.0D, 13.0D, 9.0D, 9.5D, 6.5D, 4.0D}};

    assertEquals(expectedBuckets.length, tuples.size());

    for (int row = 0; row < expectedBuckets.length; row++) {
      Tuple tuple = tuples.get(row);
      assertEquals("bucket of tuple " + row, expectedBuckets[row], tuple.getString("a_s"));
      for (int col = 0; col < metricNames.length; col++) {
        String label = metricNames[col] + " of bucket " + expectedBuckets[row];
        Double actual = tuple.getDouble(metricNames[col]);
        assertNotNull(label, actual);
        // Delta 0.0 keeps the exact comparison the test has always made.
        assertEquals(label, expectedValues[row][col], actual.doubleValue(), 0.0D);
      }
    }
  } finally {
    solrClientCache.close();
  }
}
@Test
public void testInnerJoinStream() throws Exception {