SOLR-15036: auto-select / rollup / sort / plist over facet expression when using a collection alias with multiple collections (#2132)

Timothy Potter 2021-01-11 10:34:28 -07:00 committed by GitHub
parent f0d6fd84bb
commit 6711eb7571
16 changed files with 803 additions and 31 deletions

View File

@@ -259,6 +259,9 @@ Optimizations
Partial updates to nested documents and Realtime Get of child documents is now more reliable.
(David Smiley, Thomas Wöckinger)
* SOLR-15036: Automatically wrap a facet expression with a select / rollup / sort / plist when using a
collection alias with multiple collections and count|sum|min|max|avg metrics. (Timothy Potter)
Bug Fixes
---------------------
* SOLR-14946: Fix responseHeader being returned in response when omitHeader=true and EmbeddedSolrServer is used

View File

@@ -18,6 +18,7 @@ package org.apache.solr.handler;
import org.apache.solr.client.solrj.io.Lang;
import org.apache.solr.client.solrj.io.stream.expr.DefaultStreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.WeightedSumMetric;
import org.apache.solr.core.SolrResourceLoader;
/**
@@ -37,6 +38,7 @@ public class SolrDefaultStreamFactory extends DefaultStreamFactory {
this.withFunctionName("cat", CatStream.class);
this.withFunctionName("classify", ClassifyStream.class);
this.withFunctionName("haversineMeters", HaversineMetersEvaluator.class);
this.withFunctionName("wsum", WeightedSumMetric.class);
}
public SolrDefaultStreamFactory withSolrResourceLoader(SolrResourceLoader solrResourceLoader) {

View File

@@ -181,6 +181,7 @@ The `facet` function provides aggregations that are rolled up over buckets. Unde
* `bucketSizeLimit`: Sets the absolute number of rows to fetch. This is incompatible with rows, offset and overfetch. This value is applied to each dimension. '-1' will fetch all the buckets.
* `metrics`: List of metrics to compute for the buckets. Currently supported metrics are `sum(col)`, `avg(col)`, `min(col)`, `max(col)`, `count(*)`, `per(col, 50)`. The `per` metric calculates a percentile for a numeric column and can be specified multiple times in the same facet function.
* `tiered`: (Default true) Flag governing whether the `facet` stream should parallelize JSON Facet requests to multiple Solr collections using a `plist` expression; this option only applies if the `collection` is an alias backed by multiple collections. If `tiered` is enabled, then a `rollup` expression is used internally to aggregate the metrics from multiple `facet` expressions into a single result; only `count`, `min`, `max`, `sum`, and `avg` metrics are supported. Client applications can disable this globally by setting the `solr.facet.stream.tiered=false` system property.
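
For example, assuming a hypothetical alias named `logs` backed by collections `logs1` and `logs2`, the expression below is transparently parallelized: a `facet` runs against each collection concurrently via `plist`, and the per-collection buckets are merged with a `rollup` before the final sort is applied. Pass `tiered="false"` to send a single JSON Facet request to the alias instead.

facet(logs,
      q="*:*",
      buckets="a_i",
      bucketSorts="a_i asc",
      rows=100,
      sum(a_d), avg(a_d), min(a_d), max(a_d), count(*))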
=== facet Syntax

View File

@@ -18,6 +18,7 @@ package org.apache.solr.client.solrj.io.comp;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import org.apache.solr.client.solrj.io.Tuple;
@@ -41,16 +42,13 @@ public class FieldComparator implements StreamComparator {
private ComparatorLambda comparator;
public FieldComparator(String fieldName, ComparatorOrder order){
this(fieldName, fieldName, order);
}
public FieldComparator(String leftFieldName, String rightFieldName, ComparatorOrder order) {
this.leftFieldName = leftFieldName;
this.rightFieldName = rightFieldName;
this.order = order != null ? order : ComparatorOrder.ASCENDING;
assignComparator();
}
@@ -176,4 +174,19 @@ public class FieldComparator implements StreamComparator {
public StreamComparator append(StreamComparator other){
return new MultipleFieldComparator(this).append(other);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FieldComparator that = (FieldComparator) o;
return leftFieldName.equals(that.leftFieldName) &&
rightFieldName.equals(that.rightFieldName) &&
order == that.order; // comparator is based on the other fields so is not needed in this compare
}
@Override
public int hashCode() {
return Objects.hash(leftFieldName, rightFieldName, order);
}
}

View File

@@ -18,6 +18,7 @@ package org.apache.solr.client.solrj.io.comp;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@@ -61,6 +62,19 @@ public class MultipleFieldComparator implements StreamComparator {
return 0;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MultipleFieldComparator that = (MultipleFieldComparator) o;
return Arrays.equals(comps, that.comps);
}
@Override
public int hashCode() {
return Arrays.hashCode(comps);
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
StringBuilder sb = new StringBuilder();

View File

@@ -25,7 +25,7 @@ import org.apache.solr.client.solrj.io.stream.expr.Expressible;
/** Defines a comparator we can use with TupleStreams */
public interface StreamComparator extends Comparator<Tuple>, Expressible, Serializable {
boolean isDerivedFrom(StreamComparator base);
StreamComparator copyAliased(Map<String,String> aliases);
StreamComparator append(StreamComparator other);
}

View File

@@ -21,6 +21,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.FieldComparator;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
@@ -252,29 +253,27 @@ public class DrillStream extends CloudSolrStream implements Expressible {
}
protected void constructStreams() throws IOException {
try {
Object pushStream = ((Expressible) tupleStream).toExpression(streamFactory);
final ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set(DISTRIB,"false"); // We are the aggregator.
paramsLoc.set("expr", pushStream.toString());
paramsLoc.set("qt","/export");
paramsLoc.set("fl", fl);
paramsLoc.set("sort", sort);
paramsLoc.set("q", q);
getReplicas(this.zkHost, this.collection, this.streamContext, paramsLoc).forEach(r -> {
SolrStream solrStream = new SolrStream(r.getBaseUrl(), paramsLoc, r.getCoreName());
solrStream.setStreamContext(streamContext);
solrStreams.add(solrStream);
});
} catch (Exception e) {
Throwable rootCause = ExceptionUtils.getRootCause(e);
if (rootCause instanceof IOException) {
throw (IOException)rootCause;
} else {
throw new IOException(e);
}
}
}
}

View File

@@ -20,8 +20,10 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.Collectors;
@@ -29,6 +31,7 @@ import java.util.stream.Collectors;
import org.apache.solr.client.solrj.SolrRequest;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.impl.CloudSolrClient.Builder;
import org.apache.solr.client.solrj.impl.ClusterStateProvider;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
@@ -58,10 +61,16 @@ import org.apache.solr.common.util.NamedList;
 * @since 6.0.0
 **/
public class FacetStream extends TupleStream implements Expressible, ParallelMetricsRollup {
private static final long serialVersionUID = 1;
// allow client apps to disable the auto-plist via system property if they want to turn it off globally
private static final boolean defaultTieredEnabled =
Boolean.parseBoolean(System.getProperty("solr.facet.stream.tiered", "true"));
static final String TIERED_PARAM = "tiered";
private Bucket[] buckets;
private Metric[] metrics;
private int rows;
@@ -81,6 +90,8 @@ public class FacetStream extends TupleStream implements Expressible {
protected transient SolrClientCache cache;
protected transient CloudSolrClient cloudSolrClient;
protected transient TupleStream parallelizedStream;
protected transient StreamContext context;
public FacetStream(String zkHost,
String collection,
@@ -321,6 +332,10 @@ public class FacetStream extends TupleStream implements Expressible {
zkHost);
}
// see usage in parallelize method
private FacetStream() {
}
public int getBucketSizeLimit() {
return this.bucketSizeLimit;
}
@@ -529,6 +544,7 @@
}
public void setStreamContext(StreamContext context) {
this.context = context;
cache = context.getSolrClientCache();
}
@@ -545,6 +561,19 @@
cloudSolrClient = new Builder(hosts, Optional.empty()).withSocketTimeout(30000).withConnectionTimeout(15000).build();
}
// Parallelize the facet expression across multiple collections for an alias using plist if possible
if (params.getBool(TIERED_PARAM, defaultTieredEnabled)) {
ClusterStateProvider clusterStateProvider = cloudSolrClient.getClusterStateProvider();
final List<String> resolved = clusterStateProvider != null ? clusterStateProvider.resolveAlias(collection) : null;
if (resolved != null && resolved.size() > 1) {
Optional<TupleStream> maybeParallelize = openParallelStream(context, resolved, metrics);
if (maybeParallelize.isPresent()) {
this.parallelizedStream = maybeParallelize.get();
return; // we're using a plist to parallelize the facet operation
} // else, there's a metric that we can't rollup over the plist results safely ... no plist for you!
}
}
FieldComparator[] adjustedSorts = adjustSorts(buckets, bucketSorts);
this.resortNeeded = resortNeeded(adjustedSorts);
@@ -555,6 +584,10 @@
paramsLoc.set("rows", "0");
QueryRequest request = new QueryRequest(paramsLoc, SolrRequest.METHOD.POST);
if (paramsLoc.get("lb.proxy") != null) {
request.setPath("/"+collection+"/select");
}
try {
@SuppressWarnings({"rawtypes"})
NamedList response = cloudSolrClient.request(request, collection);
@@ -620,6 +653,12 @@
}
public Tuple read() throws IOException {
// if we're parallelizing the facet expression over multiple collections with plist,
// then delegate the read operation to that stream instead
if (parallelizedStream != null) {
return parallelizedStream.read();
}
if(index < tuples.size() && index < (offset+rows)) {
Tuple tuple = tuples.get(index);
++index;
@@ -848,4 +887,64 @@
return bucketSorts[0];
}
}
List<Metric> getMetrics() {
return Arrays.asList(metrics);
}
@Override
public TupleStream[] parallelize(List<String> partitions) throws IOException {
final ModifiableSolrParams withoutTieredParam = new ModifiableSolrParams(params);
withoutTieredParam.remove(TIERED_PARAM); // each individual facet request is not tiered
TupleStream[] streams = new TupleStream[partitions.size()];
for (int p = 0; p < streams.length; p++) {
FacetStream cloned = new FacetStream();
cloned.init(partitions.get(p), /* each collection */
withoutTieredParam, /* removes the tiered param */
buckets,
bucketSorts,
metrics,
rows,
offset,
bucketSizeLimit,
refine,
method,
serializeBucketSizeLimit,
overfetch,
zkHost);
streams[p] = cloned;
}
return streams;
}
@Override
public TupleStream getSortedRollupStream(ParallelListStream plist, Metric[] rollupMetrics) throws IOException {
// using a hashRollup removes the need to sort the streams from the plist
HashRollupStream rollup = new HashRollupStream(plist, buckets, rollupMetrics);
SelectStream select = new SelectStream(rollup, getRollupSelectFields(rollupMetrics));
// the final stream must be sorted based on the original stream sort
return new SortStream(select, getStreamSort());
}
/**
* The projection of dimensions and metrics from the rollup stream.
*
* @param rollupMetrics The metrics being rolled up.
* @return A mapping of fields produced by the rollup stream to their output name.
*/
protected Map<String, String> getRollupSelectFields(Metric[] rollupMetrics) {
Map<String, String> map = new HashMap<>();
for (Bucket b : buckets) {
String key = b.toString();
map.put(key, key);
}
for (Metric m : rollupMetrics) {
String[] cols = m.getColumns();
String col = cols != null && cols.length > 0 ? cols[0] : "*";
map.put(m.getIdentifier(), col);
}
return map;
}
}
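
Putting the pieces together: for an alias resolving to two collections (illustrative names coll1 and coll2), buckets="a_i", and metrics sum(a_d), avg(a_d), count(*), the stream composition built by getSortedRollupStream is roughly equivalent to the following expression (a sketch of the plan, not the literal serialized form):

sort(
  select(
    hashRollup(
      plist(facet(coll1, ...), facet(coll2, ...)),
      over="a_i",
      sum(sum(a_d)), wsum(avg(a_d), count(*), false), sum(count(*))),
    "a_i",
    "sum(sum(a_d)) as sum(a_d)",
    "wsum(avg(a_d), count(*), false) as avg(a_d)",
    "sum(count(*)) as count(*)"),
  by="a_i asc")

The hashRollup avoids having to sort each per-collection stream before merging; the final sort restores the order requested by the original facet expression.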

View File

@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.stream;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric;
import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.io.stream.metrics.MinMetric;
import org.apache.solr.client.solrj.io.stream.metrics.SumMetric;
import org.apache.solr.client.solrj.io.stream.metrics.WeightedSumMetric;
/**
* Indicates the underlying stream source supports parallelizing metrics computation across collections
* using a rollup of metrics from each collection.
*/
public interface ParallelMetricsRollup {
/**
* Given a list of collections, return an array of TupleStream for each partition.
*
* @param partitions A list of collections to parallelize metrics computation across.
* @return An array of TupleStream for each partition requested.
* @throws IOException if an error occurs while constructing the underlying TupleStream for a partition.
*/
TupleStream[] parallelize(List<String> partitions) throws IOException;
/**
* Get the rollup for the parallelized streams that is sorted based on the original (non-parallel) sort order.
*
* @param plistStream A parallel list stream to fetch metrics from each partition concurrently
* @param rollupMetrics An array of metrics to rollup
* @return A rollup over parallelized streams that provide metrics; this is typically a SortStream.
* @throws IOException if an error occurs while reading from the sorted stream
*/
TupleStream getSortedRollupStream(ParallelListStream plistStream, Metric[] rollupMetrics) throws IOException;
/**
* Given a list of partitions (collections), open a select stream that projects the dimensions and
* metrics produced by rolling up over a parallelized group of streams. If it's not possible to rollup
* the metrics produced by the underlying metrics stream, this method returns Optional.empty.
*
* @param context The current streaming expression context
* @param partitions A list of collections to parallelize metrics computation across.
* @param metrics A list of metrics to rollup.
* @return Either a TupleStream that performs a rollup over parallelized streams or empty if parallelization is not possible.
* @throws IOException if an error occurs reading tuples from the parallelized streams
*/
default Optional<TupleStream> openParallelStream(StreamContext context, List<String> partitions, Metric[] metrics) throws IOException {
Optional<Metric[]> maybeRollupMetrics = getRollupMetrics(metrics);
if (maybeRollupMetrics.isEmpty())
return Optional.empty(); // some metric is incompatible with doing a rollup over the plist results
TupleStream parallelStream = getSortedRollupStream(new ParallelListStream(parallelize(partitions)), maybeRollupMetrics.get());
parallelStream.setStreamContext(context);
parallelStream.open();
return Optional.of(parallelStream);
}
/**
* Either an array of metrics that can be parallelized and rolled up or empty.
*
* @param metrics The list of metrics that we want to parallelize.
* @return Either an array of metrics that can be parallelized and rolled up or empty.
*/
default Optional<Metric[]> getRollupMetrics(Metric[] metrics) {
Metric[] rollup = new Metric[metrics.length];
CountMetric count = null;
for (int m = 0; m < rollup.length; m++) {
Metric nextRollup;
Metric next = metrics[m];
if (next instanceof SumMetric) {
// sum of sums
nextRollup = new SumMetric(next.getIdentifier());
} else if (next instanceof MinMetric) {
// min of mins
nextRollup = new MinMetric(next.getIdentifier());
} else if (next instanceof MaxMetric) {
// max of max
nextRollup = new MaxMetric(next.getIdentifier());
} else if (next instanceof CountMetric) {
// sum of counts
nextRollup = new SumMetric(next.getIdentifier());
count = (CountMetric) next;
} else if (next instanceof MeanMetric) {
// WeightedSumMetric must have a count to compute the weighted avg. rollup from ...
// if the user is not requesting count, then we can't parallelize
if (count == null) {
// just look past the current position
for (int n = m + 1; n < metrics.length; n++) {
if (metrics[n] instanceof CountMetric) {
count = (CountMetric) metrics[n];
break;
}
}
}
if (count != null) {
nextRollup = new WeightedSumMetric(next.getIdentifier(), count.getIdentifier());
} else {
return Optional.empty(); // can't properly rollup mean metrics w/o a count (reqd by WeightedSumMetric)
}
} else {
return Optional.empty(); // can't parallelize this expr!
}
rollup[m] = nextRollup;
}
return Optional.of(rollup);
}
}
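
As a concrete illustration of the mapping implemented above, for an assumed numeric field a_d the requested metrics translate to rollup metrics as follows:

requested metric   rollup metric over the plist results
sum(a_d)           sum(sum(a_d))
min(a_d)           min(min(a_d))
max(a_d)           max(max(a_d))
count(*)           sum(count(*))
avg(a_d)           wsum(avg(a_d), count(*))   (only possible when count(*) is also requested)

Any other metric (for example per) causes getRollupMetrics to return Optional.empty, and the facet expression falls back to the non-parallel code path.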

View File

@@ -86,7 +86,7 @@ public class MaxMetric extends Metric {
if(l > longMax) {
longMax = l;
}
} else if(o instanceof Long) {
long l = (long)o;
if(l > longMax) {
longMax = l;

View File

@@ -83,7 +83,7 @@ public class MeanMetric extends Metric {
} else if(o instanceof Integer) {
Integer i = (Integer)o;
longSum += i.longValue();
} else if (o instanceof Long) {
Long l = (Long)o;
longSum += l;
}

View File

@@ -87,7 +87,7 @@ public class MinMetric extends Metric {
if(l < longMin) {
longMin = l;
}
} else if(o instanceof Long) {
long l = (long)o;
if(l < longMin) {
longMin = l;

View File

@@ -70,7 +70,7 @@ public class SumMetric extends Metric {
} else if(o instanceof Integer) {
Integer i = (Integer)o;
longSum += i.longValue();
} else if (o instanceof Long) {
Long l = (Long)o;
longSum += l;
}

View File

@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.stream.metrics;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class WeightedSumMetric extends Metric {
public static final String FUNC = "wsum";
private String valueCol;
private String countCol;
private boolean outputLong;
private List<Part> parts;
public WeightedSumMetric(String valueCol, String countCol) {
init(valueCol, countCol, false);
}
public WeightedSumMetric(String valueCol, String countCol, boolean outputLong) {
init(valueCol, countCol, outputLong);
}
public WeightedSumMetric(StreamExpression expression, StreamFactory factory) throws IOException {
// grab all parameters out
String functionName = expression.getFunctionName();
if (!FUNC.equals(functionName)) {
throw new IOException("Expected '" + FUNC + "' function but found " + functionName);
}
String valueCol = factory.getValueOperand(expression, 0);
String countCol = factory.getValueOperand(expression, 1);
String outputLong = factory.getValueOperand(expression, 2);
// validate expression contains only what we want.
if (null == valueCol) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - expected %s(valueCol,countCol)", expression, FUNC));
}
boolean ol = false;
if (outputLong != null) {
ol = Boolean.parseBoolean(outputLong);
}
init(valueCol, countCol, ol);
}
private void init(String valueCol, String countCol, boolean outputLong) {
this.valueCol = valueCol;
this.countCol = countCol != null ? countCol : "count(*)";
this.outputLong = outputLong;
setFunctionName(FUNC);
setIdentifier(FUNC, "(", valueCol, ", " + countCol + ", " + outputLong + ")");
}
public void update(Tuple tuple) {
Object c = tuple.get(countCol);
Object o = tuple.get(valueCol);
if (c instanceof Number && o instanceof Number) {
if (parts == null) {
parts = new LinkedList<>();
}
Number count = (Number) c;
Number value = (Number) o;
parts.add(new Part(count.longValue(), value.doubleValue()));
}
}
public Metric newInstance() {
return new WeightedSumMetric(valueCol, countCol, outputLong);
}
public String[] getColumns() {
return new String[]{valueCol, countCol};
}
public Number getValue() {
long total = sumCounts();
double wavg = 0d;
if (total > 0L) {
for (Part next : parts) {
wavg += next.weighted(total);
}
}
return outputLong ? Math.round(wavg) : wavg;
}
private long sumCounts() {
long total = 0L;
if (parts != null) {
for (Part next : parts) {
total += next.count;
}
}
return total;
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
return new StreamExpression(getFunctionName()).withParameter(valueCol).withParameter(countCol).withParameter(Boolean.toString(outputLong));
}
private static final class Part {
private final double value;
private final long count;
Part(long count, double value) {
this.count = count;
this.value = value;
}
private double weighted(final long total) {
return ((double) count / total) * value;
}
}
}
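
A worked example of the weighted average computed by getValue, using the same values as WeightedSumMetricTest below: with per-partition counts {10, 20, 30, 40} and per-partition averages {2, 4, 6, 8}, the total count is 100, so

wavg = (10/100)*2 + (20/100)*4 + (30/100)*6 + (40/100)*8 = 0.2 + 0.8 + 1.8 + 3.2 = 6.0

which equals the mean over the combined data, because each partition's average is weighted by that partition's share of the total count.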

View File

@@ -0,0 +1,301 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.stream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.Precision;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.request.CollectionAdminRequest;
import org.apache.solr.client.solrj.request.UpdateRequest;
import org.apache.solr.cloud.SolrCloudTestCase;
import org.apache.solr.handler.SolrDefaultStreamFactory;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
/**
* Verify auto-plist with rollup over a facet expression when using collection alias over multiple collections.
*/
@SolrTestCaseJ4.SuppressSSL
@LuceneTestCase.SuppressCodecs({"Lucene3x", "Lucene40", "Lucene41", "Lucene42", "Lucene45"})
public class ParallelFacetStreamOverAliasTest extends SolrCloudTestCase {
private static final String ALIAS_NAME = "SOME_ALIAS_WITH_MANY_COLLS";
private static final String id = "id";
private static final int NUM_COLLECTIONS = 2; // this test requires at least 2 collections, each with multiple shards
private static final int NUM_DOCS_PER_COLLECTION = 40;
private static final int NUM_SHARDS_PER_COLLECTION = 4;
private static final int CARDINALITY = 10;
private static final RandomGenerator rand = new JDKRandomGenerator(5150);
private static List<String> listOfCollections;
private static SolrClientCache solrClientCache;
@BeforeClass
public static void setupCluster() throws Exception {
System.setProperty("solr.tests.numeric.dv", "true");
configureCluster(NUM_COLLECTIONS).withMetrics(false)
.addConfig("conf", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("streaming").resolve("conf"))
.configure();
cleanup();
setupCollectionsAndAlias();
solrClientCache = new SolrClientCache();
}
/**
* setup the testbed with necessary collections, documents, and alias
*/
public static void setupCollectionsAndAlias() throws Exception {
final NormalDistribution[] dists = new NormalDistribution[CARDINALITY];
for (int i = 0; i < dists.length; i++) {
dists[i] = new NormalDistribution(rand, i + 1, 1d);
}
List<String> collections = new ArrayList<>(NUM_COLLECTIONS);
final List<Exception> errors = new LinkedList<>();
Stream.iterate(1, n -> n + 1).limit(NUM_COLLECTIONS).forEach(colIdx -> {
final String collectionName = "coll" + colIdx;
collections.add(collectionName);
try {
CollectionAdminRequest.createCollection(collectionName, "conf", NUM_SHARDS_PER_COLLECTION, 1).process(cluster.getSolrClient());
cluster.waitForActiveCollection(collectionName, NUM_SHARDS_PER_COLLECTION, NUM_SHARDS_PER_COLLECTION);
// want a variable num of docs per collection so that avg of avg does not work ;-)
final int numDocsInColl = colIdx % 2 == 0 ? NUM_DOCS_PER_COLLECTION / 2 : NUM_DOCS_PER_COLLECTION;
final int limit = NUM_COLLECTIONS == 1 ? NUM_DOCS_PER_COLLECTION * 2 : numDocsInColl;
UpdateRequest ur = new UpdateRequest();
Stream.iterate(0, n -> n + 1).limit(limit)
.forEach(docId -> ur.add(id, UUID.randomUUID().toString(),
"a_s", "hello" + docId,
"a_i", String.valueOf(docId % CARDINALITY),
"b_i", rand.nextBoolean() ? "1" : "0",
"a_d", String.valueOf(dists[docId % dists.length].sample())));
ur.commit(cluster.getSolrClient(), collectionName);
} catch (SolrServerException | IOException e) {
errors.add(e);
}
});
if (!errors.isEmpty()) {
throw errors.get(0);
}
listOfCollections = collections;
String aliasedCollectionString = String.join(",", collections);
CollectionAdminRequest.createAlias(ALIAS_NAME, aliasedCollectionString).process(cluster.getSolrClient());
}
public static void cleanup() throws Exception {
if (cluster != null && cluster.getSolrClient() != null) {
// cleanup the alias and the collections behind it
CollectionAdminRequest.deleteAlias(ALIAS_NAME).process(cluster.getSolrClient());
if (listOfCollections != null) {
final List<Exception> errors = new LinkedList<>();
listOfCollections.stream().map(CollectionAdminRequest::deleteCollection).forEach(c -> {
try {
c.process(cluster.getSolrClient());
} catch (SolrServerException | IOException e) {
errors.add(e);
}
});
if (!errors.isEmpty()) {
throw errors.get(0);
}
}
}
}
@AfterClass
public static void after() throws Exception {
cleanup();
if (solrClientCache != null) {
solrClientCache.close();
}
}
/**
* Test parallelized calls to facet expression, one for each collection in the alias
*/
@Test
public void testParallelFacetOverAlias() throws Exception {
String facetExprTmpl = "" +
"facet(\n" +
" %s,\n" +
" tiered=%s,\n" +
" q=\"*:*\", \n" +
" buckets=\"a_i\", \n" +
" bucketSorts=\"a_i asc\", \n" +
" bucketSizeLimit=100, \n" +
" sum(a_d), avg(a_d), min(a_d), max(a_d), count(*)\n" +
")\n";
compareTieredStreamWithNonTiered(facetExprTmpl, 1);
}
/**
* Test parallelized calls to facet expression with multiple dimensions, one for each collection in the alias
*/
@Test
public void testParallelFacetMultipleDimensionsOverAlias() throws Exception {
// notice we're sorting the stream by a metric, but internally, that doesn't work for parallelization
// so the rollup has to sort by dimensions and then apply a final re-sort once the parallel streams are merged
String facetExprTmpl = "" +
"facet(\n" +
" %s,\n" +
" tiered=%s,\n" +
" q=\"*:*\", \n" +
" buckets=\"a_i,b_i\", \n" + /* two dimensions here ~ doubles the number of tuples */
" bucketSorts=\"sum(a_d) desc\", \n" +
" bucketSizeLimit=100, \n" +
" sum(a_d), avg(a_d), min(a_d), max(a_d), count(*)\n" +
")\n";
compareTieredStreamWithNonTiered(facetExprTmpl, 2);
}
@Test
public void testParallelFacetSortByDimensions() throws Exception {
// this time the stream is sorted by its dimensions rather than by a metric; a final
// re-sort is still applied once the parallel streams are merged
String facetExprTmpl = "" +
"facet(\n" +
" %s,\n" +
" tiered=%s,\n" +
" q=\"*:*\", \n" +
" buckets=\"a_i,b_i\", \n" +
" bucketSorts=\"a_i asc, b_i asc\", \n" +
" bucketSizeLimit=100, \n" +
" sum(a_d), avg(a_d), min(a_d), max(a_d), count(*)\n" +
")\n";
compareTieredStreamWithNonTiered(facetExprTmpl, 2);
}
// execute the provided expression with tiered=true and compare to results of tiered=false
private void compareTieredStreamWithNonTiered(String facetExprTmpl, int dims) throws IOException {
String facetExpr = String.format(Locale.US, facetExprTmpl, ALIAS_NAME, "true");
StreamContext streamContext = new StreamContext();
streamContext.setSolrClientCache(solrClientCache);
StreamFactory factory = new SolrDefaultStreamFactory().withDefaultZkHost(cluster.getZkServer().getZkAddress());
TupleStream stream = factory.constructStream(facetExpr);
stream.setStreamContext(streamContext);
// check the parallel setup logic
assertParallelFacetStreamConfig(stream, dims);
List<Tuple> plistTuples = getTuples(stream);
assertEquals(CARDINALITY * dims, plistTuples.size());
// now re-execute the same expression w/o plist
facetExpr = String.format(Locale.US, facetExprTmpl, ALIAS_NAME, "false");
stream = factory.constructStream(facetExpr);
stream.setStreamContext(streamContext);
List<Tuple> tuples = getTuples(stream);
assertEquals(CARDINALITY * dims, tuples.size());
// results should be identical regardless of tiered=true|false
assertListOfTuplesEquals(plistTuples, tuples);
}
private void assertParallelFacetStreamConfig(TupleStream stream, int dims) throws IOException {
assertTrue(stream instanceof FacetStream);
FacetStream facetStream = (FacetStream) stream;
TupleStream[] parallelStreams = facetStream.parallelize(listOfCollections);
assertEquals(2, parallelStreams.length);
assertTrue(parallelStreams[0] instanceof FacetStream);
Optional<Metric[]> rollupMetrics = facetStream.getRollupMetrics(facetStream.getMetrics().toArray(new Metric[0]));
assertTrue(rollupMetrics.isPresent());
assertEquals(5, rollupMetrics.get().length);
Map<String, String> selectFields = facetStream.getRollupSelectFields(rollupMetrics.get());
assertNotNull(selectFields);
assertEquals(5 + dims /* num metrics + num dims */, selectFields.size());
assertEquals("a_i", selectFields.get("a_i"));
assertEquals("max(a_d)", selectFields.get("max(max(a_d))"));
assertEquals("min(a_d)", selectFields.get("min(min(a_d))"));
assertEquals("sum(a_d)", selectFields.get("sum(sum(a_d))"));
assertEquals("avg(a_d)", selectFields.get("wsum(avg(a_d), count(*), false)"));
assertEquals("count(*)", selectFields.get("sum(count(*))"));
if (dims > 1) {
assertEquals("b_i", selectFields.get("b_i"));
}
}
// assert results are the same, with some sorting and rounding of floating point values
private void assertListOfTuplesEquals(List<Tuple> exp, List<Tuple> act) {
List<SortedMap<Object, Object>> expList = exp.stream().map(this::toComparableMap).collect(Collectors.toList());
List<SortedMap<Object, Object>> actList = act.stream().map(this::toComparableMap).collect(Collectors.toList());
assertEquals(expList, actList);
}
private SortedMap<Object, Object> toComparableMap(Tuple t) {
SortedMap<Object, Object> cmap = new TreeMap<>();
for (Map.Entry<Object, Object> e : t.getFields().entrySet()) {
Object value = e.getValue();
if (value instanceof Double) {
cmap.put(e.getKey(), Precision.round((Double) value, 5));
} else if (value instanceof Float) {
cmap.put(e.getKey(), Precision.round((Float) value, 3));
} else {
cmap.put(e.getKey(), e.getValue());
}
}
return cmap;
}
List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
List<Tuple> tuples = new ArrayList<>();
try (tupleStream) {
tupleStream.open();
for (Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) {
tuples.add(t);
}
}
return tuples;
}
}

View File

@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.stream.metrics;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.handler.SolrDefaultStreamFactory;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
public class WeightedSumMetricTest {
final long[] counts = new long[]{10, 20, 30, 40};
final double[] avg = new double[]{2, 4, 6, 8};
@Test
public void testWsumPojo() throws Exception {
WeightedSumMetric wsum = new WeightedSumMetric("avg", "count");
assertEquals("wsum(avg, count, false)", wsum.getIdentifier());
assertArrayEquals(new String[]{"avg", "count"}, wsum.getColumns());
StreamFactory factory = new SolrDefaultStreamFactory();
StreamExpressionParameter expr = wsum.toExpression(factory);
assertTrue(expr instanceof StreamExpression);
wsum = new WeightedSumMetric((StreamExpression) expr, factory);
double expectedSum = 0d;
for (int i = 0; i < counts.length; i++) {
expectedSum += ((double) counts[i] / 100) * avg[i];
}
long expectedSumLong = Math.round(expectedSum);
Number weightedSum = updateMetric(wsum);
assertNotNull(weightedSum);
assertTrue(weightedSum instanceof Double);
assertTrue(weightedSum.doubleValue() == expectedSum);
wsum = new WeightedSumMetric("avg", "count", true);
assertEquals("wsum(avg, count, true)", wsum.getIdentifier());
weightedSum = updateMetric(wsum);
assertNotNull(weightedSum);
assertTrue(weightedSum.longValue() == expectedSumLong);
}
private Number updateMetric(WeightedSumMetric wsum) {
for (int i = 0; i < counts.length; i++) {
Tuple t = new Tuple();
t.put("avg", avg[i]);
t.put("count", counts[i]);
wsum.update(t);
}
return wsum.getValue();
}
}