From 6711eb7571727552aad3ace53c52c9a8fe07dc40 Mon Sep 17 00:00:00 2001 From: Timothy Potter Date: Mon, 11 Jan 2021 10:34:28 -0700 Subject: [PATCH] SOLR-15036: auto- select / rollup / sort / plist over facet expression when using a collection alias with multiple collections (#2132) --- solr/CHANGES.txt | 3 + .../handler/SolrDefaultStreamFactory.java | 2 + .../src/stream-source-reference.adoc | 1 + .../client/solrj/io/comp/FieldComparator.java | 23 +- .../io/comp/MultipleFieldComparator.java | 14 + .../solrj/io/comp/StreamComparator.java | 6 +- .../client/solrj/io/stream/DrillStream.java | 35 +- .../client/solrj/io/stream/FacetStream.java | 101 +++++- .../io/stream/ParallelMetricsRollup.java | 130 ++++++++ .../solrj/io/stream/metrics/MaxMetric.java | 2 +- .../solrj/io/stream/metrics/MeanMetric.java | 2 +- .../solrj/io/stream/metrics/MinMetric.java | 2 +- .../solrj/io/stream/metrics/SumMetric.java | 2 +- .../io/stream/metrics/WeightedSumMetric.java | 135 ++++++++ .../ParallelFacetStreamOverAliasTest.java | 301 ++++++++++++++++++ .../stream/metrics/WeightedSumMetricTest.java | 75 +++++ 16 files changed, 803 insertions(+), 31 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ParallelMetricsRollup.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetric.java create mode 100644 solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/ParallelFacetStreamOverAliasTest.java create mode 100644 solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetricTest.java diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index e2566d2128a..0f499e9c745 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -259,6 +259,9 @@ Optimizations Partial updates to nested documents and Realtime Get of child documents is now more reliable. (David Smiley, Thomas Wöckinger) +* SOLR-15036: Automatically wrap a facet expression with a select / rollup / sort / plist when using a + collection alias with multiple collections and count|sum|min|max|avg metrics. (Timothy Potter) + Bug Fixes --------------------- * SOLR-14946: Fix responseHeader being returned in response when omitHeader=true and EmbeddedSolrServer is used diff --git a/solr/core/src/java/org/apache/solr/handler/SolrDefaultStreamFactory.java b/solr/core/src/java/org/apache/solr/handler/SolrDefaultStreamFactory.java index ca75927e447..8df98fc0e97 100644 --- a/solr/core/src/java/org/apache/solr/handler/SolrDefaultStreamFactory.java +++ b/solr/core/src/java/org/apache/solr/handler/SolrDefaultStreamFactory.java @@ -18,6 +18,7 @@ package org.apache.solr.handler; import org.apache.solr.client.solrj.io.Lang; import org.apache.solr.client.solrj.io.stream.expr.DefaultStreamFactory; +import org.apache.solr.client.solrj.io.stream.metrics.WeightedSumMetric; import org.apache.solr.core.SolrResourceLoader; /** @@ -37,6 +38,7 @@ public class SolrDefaultStreamFactory extends DefaultStreamFactory { this.withFunctionName("cat", CatStream.class); this.withFunctionName("classify", ClassifyStream.class); this.withFunctionName("haversineMeters", HaversineMetersEvaluator.class); + this.withFunctionName("wsum", WeightedSumMetric.class); } public SolrDefaultStreamFactory withSolrResourceLoader(SolrResourceLoader solrResourceLoader) { diff --git a/solr/solr-ref-guide/src/stream-source-reference.adoc b/solr/solr-ref-guide/src/stream-source-reference.adoc index d31cc3c74cc..d591534b8dc 100644 --- a/solr/solr-ref-guide/src/stream-source-reference.adoc +++ b/solr/solr-ref-guide/src/stream-source-reference.adoc @@ -181,6 +181,7 @@ The `facet` function provides aggregations that are rolled up over buckets. Unde * `bucketSizeLimit`: Sets the absolute number of rows to fetch. This is incompatible with rows, offset and overfetch. This value is applied to each dimension. '-1' will fetch all the buckets. * `metrics`: List of metrics to compute for the buckets. Currently supported metrics are `sum(col)`, `avg(col)`, `min(col)`, `max(col)`, `count(*)`, `per(col, 50)`. The `per` metric calculates a percentile for a numeric column and can be specified multiple times in the same facet function. +* `tiered`: (Default true) Flag governing whether the `facet` stream should parallelize JSON Facet requests to multiple Solr collections using a `plist` expression; this option only applies if the `collection` is an alias backed by multiple collections. If `tiered` is enabled, then a `rollup` expression is used internally to aggregate the metrics from multiple `facet` expressions into a single result; only `count`, `min`, `max`, `sum`, and `avg` metrics are supported. Client applications can disable this globally by setting the `solr.facet.stream.tiered=false` system property. === facet Syntax diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/FieldComparator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/FieldComparator.java index 36ecc5e20c3..68f9454cbe6 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/FieldComparator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/FieldComparator.java @@ -18,6 +18,7 @@ package org.apache.solr.client.solrj.io.comp; import java.io.IOException; import java.util.Map; +import java.util.Objects; import java.util.UUID; import org.apache.solr.client.solrj.io.Tuple; @@ -41,16 +42,13 @@ public class FieldComparator implements StreamComparator { private ComparatorLambda comparator; public FieldComparator(String fieldName, ComparatorOrder order){ - leftFieldName = fieldName; - rightFieldName = fieldName; - this.order = order; - assignComparator(); + this(fieldName, fieldName, order); } public FieldComparator(String leftFieldName, String rightFieldName, ComparatorOrder order) { this.leftFieldName = leftFieldName; this.rightFieldName = rightFieldName; - this.order = order; + this.order = order != null ? order : ComparatorOrder.ASCENDING; assignComparator(); } @@ -176,4 +174,19 @@ public class FieldComparator implements StreamComparator { public StreamComparator append(StreamComparator other){ return new MultipleFieldComparator(this).append(other); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FieldComparator that = (FieldComparator) o; + return leftFieldName.equals(that.leftFieldName) && + rightFieldName.equals(that.rightFieldName) && + order == that.order; // comparator is based on the other fields so is not needed in this compare + } + + @Override + public int hashCode() { + return Objects.hash(leftFieldName, rightFieldName, order); + } } \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/MultipleFieldComparator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/MultipleFieldComparator.java index 09532e19f39..de8fa6dfbdf 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/MultipleFieldComparator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/MultipleFieldComparator.java @@ -18,6 +18,7 @@ package org.apache.solr.client.solrj.io.comp; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.UUID; @@ -61,6 +62,19 @@ public class MultipleFieldComparator implements StreamComparator { return 0; } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MultipleFieldComparator that = (MultipleFieldComparator) o; + return Arrays.equals(comps, that.comps); + } + + @Override + public int hashCode() { + return Arrays.hashCode(comps); + } + @Override public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { StringBuilder sb = new StringBuilder(); diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/StreamComparator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/StreamComparator.java index 70bd51d801b..ff91f31eb3a 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/StreamComparator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/comp/StreamComparator.java @@ -25,7 +25,7 @@ import org.apache.solr.client.solrj.io.stream.expr.Expressible; /** Defines a comparator we can use with TupleStreams */ public interface StreamComparator extends Comparator, Expressible, Serializable { - public boolean isDerivedFrom(StreamComparator base); - public StreamComparator copyAliased(Map aliases); - public StreamComparator append(StreamComparator other); + boolean isDerivedFrom(StreamComparator base); + StreamComparator copyAliased(Map aliases); + StreamComparator append(StreamComparator other); } \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DrillStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DrillStream.java index f1fe718711a..91e307e171f 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DrillStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DrillStream.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import org.apache.commons.lang.exception.ExceptionUtils; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.FieldComparator; import org.apache.solr.client.solrj.io.comp.StreamComparator; @@ -252,29 +253,27 @@ public class DrillStream extends CloudSolrStream implements Expressible { } protected void constructStreams() throws IOException { - try { - Object pushStream = ((Expressible) tupleStream).toExpression(streamFactory); - - List shardUrls = getShards(this.zkHost, this.collection, this.streamContext); - - for(int w=0; w { + SolrStream solrStream = new SolrStream(r.getBaseUrl(), paramsLoc, r.getCoreName()); solrStream.setStreamContext(streamContext); solrStreams.add(solrStream); - } - + }); } catch (Exception e) { - throw new IOException(e); + Throwable rootCause = ExceptionUtils.getRootCause(e); + if (rootCause instanceof IOException) { + throw (IOException)rootCause; + } else { + throw new IOException(e); + } } } } \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java index 638550f79de..66c9c56342c 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java @@ -20,8 +20,10 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.stream.Collectors; @@ -29,6 +31,7 @@ import java.util.stream.Collectors; import org.apache.solr.client.solrj.SolrRequest; import org.apache.solr.client.solrj.impl.CloudSolrClient; import org.apache.solr.client.solrj.impl.CloudSolrClient.Builder; +import org.apache.solr.client.solrj.impl.ClusterStateProvider; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.ComparatorOrder; @@ -58,10 +61,16 @@ import org.apache.solr.common.util.NamedList; * @since 6.0.0 **/ -public class FacetStream extends TupleStream implements Expressible { +public class FacetStream extends TupleStream implements Expressible, ParallelMetricsRollup { private static final long serialVersionUID = 1; + // allow client apps to disable the auto-plist via system property if they want to turn it off globally + private static final boolean defaultTieredEnabled = + Boolean.parseBoolean(System.getProperty("solr.facet.stream.tiered", "true")); + + static final String TIERED_PARAM = "tiered"; + private Bucket[] buckets; private Metric[] metrics; private int rows; @@ -81,6 +90,8 @@ public class FacetStream extends TupleStream implements Expressible { protected transient SolrClientCache cache; protected transient CloudSolrClient cloudSolrClient; + protected transient TupleStream parallelizedStream; + protected transient StreamContext context; public FacetStream(String zkHost, String collection, @@ -321,6 +332,10 @@ public class FacetStream extends TupleStream implements Expressible { zkHost); } + // see usage in parallelize method + private FacetStream() { + } + public int getBucketSizeLimit() { return this.bucketSizeLimit; } @@ -529,6 +544,7 @@ public class FacetStream extends TupleStream implements Expressible { } public void setStreamContext(StreamContext context) { + this.context = context; cache = context.getSolrClientCache(); } @@ -545,6 +561,19 @@ public class FacetStream extends TupleStream implements Expressible { cloudSolrClient = new Builder(hosts, Optional.empty()).withSocketTimeout(30000).withConnectionTimeout(15000).build(); } + // Parallelize the facet expression across multiple collections for an alias using plist if possible + if (params.getBool(TIERED_PARAM, defaultTieredEnabled)) { + ClusterStateProvider clusterStateProvider = cloudSolrClient.getClusterStateProvider(); + final List resolved = clusterStateProvider != null ? clusterStateProvider.resolveAlias(collection) : null; + if (resolved != null && resolved.size() > 1) { + Optional maybeParallelize = openParallelStream(context, resolved, metrics); + if (maybeParallelize.isPresent()) { + this.parallelizedStream = maybeParallelize.get(); + return; // we're using a plist to parallelize the facet operation + } // else, there's a metric that we can't rollup over the plist results safely ... no plist for you! + } + } + FieldComparator[] adjustedSorts = adjustSorts(buckets, bucketSorts); this.resortNeeded = resortNeeded(adjustedSorts); @@ -555,6 +584,10 @@ public class FacetStream extends TupleStream implements Expressible { paramsLoc.set("rows", "0"); QueryRequest request = new QueryRequest(paramsLoc, SolrRequest.METHOD.POST); + if (paramsLoc.get("lb.proxy") != null) { + request.setPath("/"+collection+"/select"); + } + try { @SuppressWarnings({"rawtypes"}) NamedList response = cloudSolrClient.request(request, collection); @@ -620,6 +653,12 @@ public class FacetStream extends TupleStream implements Expressible { } public Tuple read() throws IOException { + // if we're parallelizing the facet expression over multiple collections with plist, + // then delegate the read operation to that stream instead + if (parallelizedStream != null) { + return parallelizedStream.read(); + } + if(index < tuples.size() && index < (offset+rows)) { Tuple tuple = tuples.get(index); ++index; @@ -848,4 +887,64 @@ public class FacetStream extends TupleStream implements Expressible { return bucketSorts[0]; } } + + List getMetrics() { + return Arrays.asList(metrics); + } + + @Override + public TupleStream[] parallelize(List partitions) throws IOException { + + final ModifiableSolrParams withoutTieredParam = new ModifiableSolrParams(params); + withoutTieredParam.remove(TIERED_PARAM); // each individual facet request is not tiered + + TupleStream[] streams = new TupleStream[partitions.size()]; + for (int p = 0; p < streams.length; p++) { + FacetStream cloned = new FacetStream(); + cloned.init(partitions.get(p), /* each collection */ + withoutTieredParam, /* removes the tiered param */ + buckets, + bucketSorts, + metrics, + rows, + offset, + bucketSizeLimit, + refine, + method, + serializeBucketSizeLimit, + overfetch, + zkHost); + streams[p] = cloned; + } + return streams; + } + + @Override + public TupleStream getSortedRollupStream(ParallelListStream plist, Metric[] rollupMetrics) throws IOException { + // using a hashRollup removes the need to sort the streams from the plist + HashRollupStream rollup = new HashRollupStream(plist, buckets, rollupMetrics); + SelectStream select = new SelectStream(rollup, getRollupSelectFields(rollupMetrics)); + // the final stream must be sorted based on the original stream sort + return new SortStream(select, getStreamSort()); + } + + /** + * The projection of dimensions and metrics from the rollup stream. + * + * @param rollupMetrics The metrics being rolled up. + * @return A mapping of fields produced by the rollup stream to their output name. + */ + protected Map getRollupSelectFields(Metric[] rollupMetrics) { + Map map = new HashMap<>(); + for (Bucket b : buckets) { + String key = b.toString(); + map.put(key, key); + } + for (Metric m : rollupMetrics) { + String[] cols = m.getColumns(); + String col = cols != null && cols.length > 0 ? cols[0] : "*"; + map.put(m.getIdentifier(), col); + } + return map; + } } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ParallelMetricsRollup.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ParallelMetricsRollup.java new file mode 100644 index 00000000000..0d0f1c31aee --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ParallelMetricsRollup.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.stream; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; + +import org.apache.solr.client.solrj.io.stream.metrics.CountMetric; +import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric; +import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric; +import org.apache.solr.client.solrj.io.stream.metrics.Metric; +import org.apache.solr.client.solrj.io.stream.metrics.MinMetric; +import org.apache.solr.client.solrj.io.stream.metrics.SumMetric; +import org.apache.solr.client.solrj.io.stream.metrics.WeightedSumMetric; + +/** + * Indicates the underlying stream source supports parallelizing metrics computation across collections + * using a rollup of metrics from each collection. + */ +public interface ParallelMetricsRollup { + + /** + * Given a list of collections, return an array of TupleStream for each partition. + * + * @param partitions A list of collections to parallelize metrics computation across. + * @return An array of TupleStream for each partition requested. + * @throws IOException if an error occurs while constructing the underlying TupleStream for a partition. + */ + TupleStream[] parallelize(List partitions) throws IOException; + + /** + * Get the rollup for the parallelized streams that is sorted based on the original (non-parallel) sort order. + * + * @param plistStream A parallel list stream to fetch metrics from each partition concurrently + * @param rollupMetrics An array of metrics to rollup + * @return A rollup over parallelized streams that provide metrics; this is typically a SortStream. + * @throws IOException if an error occurs while reading from the sorted stream + */ + TupleStream getSortedRollupStream(ParallelListStream plistStream, Metric[] rollupMetrics) throws IOException; + + /** + * Given a list of partitions (collections), open a select stream that projects the dimensions and + * metrics produced by rolling up over a parallelized group of streams. If it's not possible to rollup + * the metrics produced by the underlying metrics stream, this method returns Optional.empty. + * + * @param context The current streaming expression context + * @param partitions A list of collections to parallelize metrics computation across. + * @param metrics A list of metrics to rollup. + * @return Either a TupleStream that performs a rollup over parallelized streams or empty if parallelization is not possible. + * @throws IOException if an error occurs reading tuples from the parallelized streams + */ + default Optional openParallelStream(StreamContext context, List partitions, Metric[] metrics) throws IOException { + Optional maybeRollupMetrics = getRollupMetrics(metrics); + if (maybeRollupMetrics.isEmpty()) + return Optional.empty(); // some metric is incompatible with doing a rollup over the plist results + + TupleStream parallelStream = getSortedRollupStream(new ParallelListStream(parallelize(partitions)), maybeRollupMetrics.get()); + parallelStream.setStreamContext(context); + parallelStream.open(); + return Optional.of(parallelStream); + } + + /** + * Either an array of metrics that can be parallelized and rolled up or empty. + * + * @param metrics The list of metrics that we want to parallelize. + * @return Either an array of metrics that can be parallelized and rolled up or empty. + */ + default Optional getRollupMetrics(Metric[] metrics) { + Metric[] rollup = new Metric[metrics.length]; + CountMetric count = null; + for (int m = 0; m < rollup.length; m++) { + Metric nextRollup; + Metric next = metrics[m]; + if (next instanceof SumMetric) { + // sum of sums + nextRollup = new SumMetric(next.getIdentifier()); + } else if (next instanceof MinMetric) { + // min of mins + nextRollup = new MinMetric(next.getIdentifier()); + } else if (next instanceof MaxMetric) { + // max of max + nextRollup = new MaxMetric(next.getIdentifier()); + } else if (next instanceof CountMetric) { + // sum of counts + nextRollup = new SumMetric(next.getIdentifier()); + count = (CountMetric) next; + } else if (next instanceof MeanMetric) { + // WeightedSumMetric must have a count to compute the weighted avg. rollup from ... + // if the user is not requesting count, then we can't parallelize + if (count == null) { + // just look past the current position + for (int n = m + 1; n < metrics.length; n++) { + if (metrics[n] instanceof CountMetric) { + count = (CountMetric) metrics[n]; + break; + } + } + } + if (count != null) { + nextRollup = new WeightedSumMetric(next.getIdentifier(), count.getIdentifier()); + } else { + return Optional.empty(); // can't properly rollup mean metrics w/o a count (reqd by WeightedSumMetric) + } + } else { + return Optional.empty(); // can't parallelize this expr! + } + + rollup[m] = nextRollup; + } + + return Optional.of(rollup); + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java index f41ca0fdeaa..165e446d2ab 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java @@ -86,7 +86,7 @@ public class MaxMetric extends Metric { if(l > longMax) { longMax = l; } - } else { + } else if(o instanceof Long) { long l = (long)o; if(l > longMax) { longMax = l; diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java index ca0d48e2f0a..a35a59412c9 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java @@ -83,7 +83,7 @@ public class MeanMetric extends Metric { } else if(o instanceof Integer) { Integer i = (Integer)o; longSum += i.longValue(); - } else { + } else if (o instanceof Long) { Long l = (Long)o; longSum += l; } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java index e23dcda6c5a..910cf4a591f 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java @@ -87,7 +87,7 @@ public class MinMetric extends Metric { if(l < longMin) { longMin = l; } - } else { + } else if(o instanceof Long) { long l = (long)o; if(l < longMin) { longMin = l; diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java index a82e77507b5..70676f269d2 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java @@ -70,7 +70,7 @@ public class SumMetric extends Metric { } else if(o instanceof Integer) { Integer i = (Integer)o; longSum += i.longValue(); - } else { + } else if (o instanceof Long) { Long l = (Long)o; longSum += l; } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetric.java new file mode 100644 index 00000000000..ba96a71e5c9 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetric.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.stream.metrics; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; + +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class WeightedSumMetric extends Metric { + + public static final String FUNC = "wsum"; + private String valueCol; + private String countCol; + private List parts; + + public WeightedSumMetric(String valueCol, String countCol) { + init(valueCol, countCol, false); + } + + public WeightedSumMetric(String valueCol, String countCol, boolean outputLong) { + init(valueCol, countCol, outputLong); + } + + public WeightedSumMetric(StreamExpression expression, StreamFactory factory) throws IOException { + // grab all parameters out + String functionName = expression.getFunctionName(); + if (!FUNC.equals(functionName)) { + throw new IOException("Expected '" + FUNC + "' function but found " + functionName); + } + String valueCol = factory.getValueOperand(expression, 0); + String countCol = factory.getValueOperand(expression, 1); + String outputLong = factory.getValueOperand(expression, 2); + + // validate expression contains only what we want. + if (null == valueCol) { + throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - expected %s(valueCol,countCol)", expression, FUNC)); + } + + boolean ol = false; + if (outputLong != null) { + ol = Boolean.parseBoolean(outputLong); + } + + init(valueCol, countCol, ol); + } + + private void init(String valueCol, String countCol, boolean outputLong) { + this.valueCol = valueCol; + this.countCol = countCol != null ? countCol : "count(*)"; + this.outputLong = outputLong; + setFunctionName(FUNC); + setIdentifier(FUNC, "(", valueCol, ", " + countCol + ", " + outputLong + ")"); + } + + public void update(Tuple tuple) { + Object c = tuple.get(countCol); + Object o = tuple.get(valueCol); + if (c instanceof Number && o instanceof Number) { + if (parts == null) { + parts = new LinkedList<>(); + } + Number count = (Number) c; + Number value = (Number) o; + parts.add(new Part(count.longValue(), value.doubleValue())); + } + } + + public Metric newInstance() { + return new WeightedSumMetric(valueCol, countCol, outputLong); + } + + public String[] getColumns() { + return new String[]{valueCol, countCol}; + } + + public Number getValue() { + long total = sumCounts(); + double wavg = 0d; + if (total > 0L) { + for (Part next : parts) { + wavg += next.weighted(total); + } + } + return outputLong ? Math.round(wavg) : wavg; + } + + private long sumCounts() { + long total = 0L; + if (parts != null) { + for (Part next : parts) { + total += next.count; + } + } + return total; + } + + @Override + public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { + return new StreamExpression(getFunctionName()).withParameter(valueCol).withParameter(countCol).withParameter(Boolean.toString(outputLong)); + } + + private static final class Part { + private final double value; + private final long count; + + Part(long count, double value) { + this.count = count; + this.value = value; + } + + private double weighted(final long total) { + return ((double) count / total) * value; + } + } +} diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/ParallelFacetStreamOverAliasTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/ParallelFacetStreamOverAliasTest.java new file mode 100644 index 00000000000..25a67f9c039 --- /dev/null +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/ParallelFacetStreamOverAliasTest.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.stream; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.util.Precision; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.client.solrj.SolrServerException; +import org.apache.solr.client.solrj.io.SolrClientCache; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import org.apache.solr.client.solrj.io.stream.metrics.Metric; +import org.apache.solr.client.solrj.request.CollectionAdminRequest; +import org.apache.solr.client.solrj.request.UpdateRequest; +import org.apache.solr.cloud.SolrCloudTestCase; +import org.apache.solr.handler.SolrDefaultStreamFactory; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Verify auto-plist with rollup over a facet expression when using collection alias over multiple collections. + */ +@SolrTestCaseJ4.SuppressSSL +@LuceneTestCase.SuppressCodecs({"Lucene3x", "Lucene40", "Lucene41", "Lucene42", "Lucene45"}) +public class ParallelFacetStreamOverAliasTest extends SolrCloudTestCase { + + private static final String ALIAS_NAME = "SOME_ALIAS_WITH_MANY_COLLS"; + + private static final String id = "id"; + private static final int NUM_COLLECTIONS = 2; // this test requires at least 2 collections, each with multiple shards + private static final int NUM_DOCS_PER_COLLECTION = 40; + private static final int NUM_SHARDS_PER_COLLECTION = 4; + private static final int CARDINALITY = 10; + private static final RandomGenerator rand = new JDKRandomGenerator(5150); + private static List listOfCollections; + private static SolrClientCache solrClientCache; + + @BeforeClass + public static void setupCluster() throws Exception { + System.setProperty("solr.tests.numeric.dv", "true"); + + configureCluster(NUM_COLLECTIONS).withMetrics(false) + .addConfig("conf", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("streaming").resolve("conf")) + .configure(); + cleanup(); + setupCollectionsAndAlias(); + + solrClientCache = new SolrClientCache(); + } + + /** + * setup the testbed with necessary collections, documents, and alias + */ + public static void setupCollectionsAndAlias() throws Exception { + + final NormalDistribution[] dists = new NormalDistribution[CARDINALITY]; + for (int i = 0; i < dists.length; i++) { + dists[i] = new NormalDistribution(rand, i + 1, 1d); + } + + List collections = new ArrayList<>(NUM_COLLECTIONS); + final List errors = new LinkedList<>(); + Stream.iterate(1, n -> n + 1).limit(NUM_COLLECTIONS).forEach(colIdx -> { + final String collectionName = "coll" + colIdx; + collections.add(collectionName); + try { + CollectionAdminRequest.createCollection(collectionName, "conf", NUM_SHARDS_PER_COLLECTION, 1).process(cluster.getSolrClient()); + cluster.waitForActiveCollection(collectionName, NUM_SHARDS_PER_COLLECTION, NUM_SHARDS_PER_COLLECTION); + + // want a variable num of docs per collection so that avg of avg does not work ;-) + final int numDocsInColl = colIdx % 2 == 0 ? NUM_DOCS_PER_COLLECTION / 2 : NUM_DOCS_PER_COLLECTION; + final int limit = NUM_COLLECTIONS == 1 ? NUM_DOCS_PER_COLLECTION * 2 : numDocsInColl; + UpdateRequest ur = new UpdateRequest(); + Stream.iterate(0, n -> n + 1).limit(limit) + .forEach(docId -> ur.add(id, UUID.randomUUID().toString(), + "a_s", "hello" + docId, + "a_i", String.valueOf(docId % CARDINALITY), + "b_i", rand.nextBoolean() ? "1" : "0", + "a_d", String.valueOf(dists[docId % dists.length].sample()))); + ur.commit(cluster.getSolrClient(), collectionName); + } catch (SolrServerException | IOException e) { + errors.add(e); + } + }); + + if (!errors.isEmpty()) { + throw errors.get(0); + } + + listOfCollections = collections; + String aliasedCollectionString = String.join(",", collections); + CollectionAdminRequest.createAlias(ALIAS_NAME, aliasedCollectionString).process(cluster.getSolrClient()); + } + + public static void cleanup() throws Exception { + if (cluster != null && cluster.getSolrClient() != null) { + // cleanup the alias and the collections behind it + CollectionAdminRequest.deleteAlias(ALIAS_NAME).process(cluster.getSolrClient()); + if (listOfCollections != null) { + final List errors = new LinkedList<>(); + listOfCollections.stream().map(CollectionAdminRequest::deleteCollection).forEach(c -> { + try { + c.process(cluster.getSolrClient()); + } catch (SolrServerException | IOException e) { + errors.add(e); + } + }); + if (!errors.isEmpty()) { + throw errors.get(0); + } + } + } + } + + @AfterClass + public static void after() throws Exception { + cleanup(); + + if (solrClientCache != null) { + solrClientCache.close(); + } + } + + /** + * Test parallelized calls to facet expression, one for each collection in the alias + */ + @Test + public void testParallelFacetOverAlias() throws Exception { + + String facetExprTmpl = "" + + "facet(\n" + + " %s,\n" + + " tiered=%s,\n" + + " q=\"*:*\", \n" + + " buckets=\"a_i\", \n" + + " bucketSorts=\"a_i asc\", \n" + + " bucketSizeLimit=100, \n" + + " sum(a_d), avg(a_d), min(a_d), max(a_d), count(*)\n" + + ")\n"; + + compareTieredStreamWithNonTiered(facetExprTmpl, 1); + } + + /** + * Test parallelized calls to facet expression with multiple dimensions, one for each collection in the alias + */ + @Test + public void testParallelFacetMultipleDimensionsOverAlias() throws Exception { + + // notice we're sorting the stream by a metric, but internally, that doesn't work for parallelization + // so the rollup has to sort by dimensions and then apply a final re-sort once the parallel streams are merged + String facetExprTmpl = "" + + "facet(\n" + + " %s,\n" + + " tiered=%s,\n" + + " q=\"*:*\", \n" + + " buckets=\"a_i,b_i\", \n" + /* two dimensions here ~ doubles the number of tuples */ + " bucketSorts=\"sum(a_d) desc\", \n" + + " bucketSizeLimit=100, \n" + + " sum(a_d), avg(a_d), min(a_d), max(a_d), count(*)\n" + + ")\n"; + + compareTieredStreamWithNonTiered(facetExprTmpl, 2); + } + + @Test + public void testParallelFacetSortByDimensions() throws Exception { + + // notice we're sorting the stream by a metric, but internally, that doesn't work for parallelization + // so the rollup has to sort by dimensions and then apply a final re-sort once the parallel streams are merged + String facetExprTmpl = "" + + "facet(\n" + + " %s,\n" + + " tiered=%s,\n" + + " q=\"*:*\", \n" + + " buckets=\"a_i,b_i\", \n" + + " bucketSorts=\"a_i asc, b_i asc\", \n" + + " bucketSizeLimit=100, \n" + + " sum(a_d), avg(a_d), min(a_d), max(a_d), count(*)\n" + + ")\n"; + + compareTieredStreamWithNonTiered(facetExprTmpl, 2); + } + + // execute the provided expression with tiered=true and compare to results of tiered=false + private void compareTieredStreamWithNonTiered(String facetExprTmpl, int dims) throws IOException { + String facetExpr = String.format(Locale.US, facetExprTmpl, ALIAS_NAME, "true"); + + StreamContext streamContext = new StreamContext(); + streamContext.setSolrClientCache(solrClientCache); + StreamFactory factory = new SolrDefaultStreamFactory().withDefaultZkHost(cluster.getZkServer().getZkAddress()); + + TupleStream stream = factory.constructStream(facetExpr); + stream.setStreamContext(streamContext); + + // check the parallel setup logic + assertParallelFacetStreamConfig(stream, dims); + + List plistTuples = getTuples(stream); + assertEquals(CARDINALITY * dims, plistTuples.size()); + + // now re-execute the same expression w/o plist + facetExpr = String.format(Locale.US, facetExprTmpl, ALIAS_NAME, "false"); + stream = factory.constructStream(facetExpr); + stream.setStreamContext(streamContext); + List tuples = getTuples(stream); + assertEquals(CARDINALITY * dims, tuples.size()); + + // results should be identical regardless of tiered=true|false + assertListOfTuplesEquals(plistTuples, tuples); + } + + private void assertParallelFacetStreamConfig(TupleStream stream, int dims) throws IOException { + assertTrue(stream instanceof FacetStream); + FacetStream facetStream = (FacetStream) stream; + TupleStream[] parallelStreams = facetStream.parallelize(listOfCollections); + assertEquals(2, parallelStreams.length); + assertTrue(parallelStreams[0] instanceof FacetStream); + + Optional rollupMetrics = facetStream.getRollupMetrics(facetStream.getMetrics().toArray(new Metric[0])); + assertTrue(rollupMetrics.isPresent()); + assertEquals(5, rollupMetrics.get().length); + Map selectFields = facetStream.getRollupSelectFields(rollupMetrics.get()); + assertNotNull(selectFields); + assertEquals(5 + dims /* num metrics + num dims */, selectFields.size()); + assertEquals("a_i", selectFields.get("a_i")); + assertEquals("max(a_d)", selectFields.get("max(max(a_d))")); + assertEquals("min(a_d)", selectFields.get("min(min(a_d))")); + assertEquals("sum(a_d)", selectFields.get("sum(sum(a_d))")); + assertEquals("avg(a_d)", selectFields.get("wsum(avg(a_d), count(*), false)")); + assertEquals("count(*)", selectFields.get("sum(count(*))")); + if (dims > 1) { + assertEquals("b_i", selectFields.get("b_i")); + } + } + + // assert results are the same, with some sorting and rounding of floating point values + private void assertListOfTuplesEquals(List exp, List act) { + List> expList = exp.stream().map(this::toComparableMap).collect(Collectors.toList()); + List> actList = act.stream().map(this::toComparableMap).collect(Collectors.toList()); + assertEquals(expList, actList); + } + + private SortedMap toComparableMap(Tuple t) { + SortedMap cmap = new TreeMap<>(); + for (Map.Entry e : t.getFields().entrySet()) { + Object value = e.getValue(); + if (value instanceof Double) { + cmap.put(e.getKey(), Precision.round((Double) value, 5)); + } else if (value instanceof Float) { + cmap.put(e.getKey(), Precision.round((Float) value, 3)); + } else { + cmap.put(e.getKey(), e.getValue()); + } + } + return cmap; + } + + List getTuples(TupleStream tupleStream) throws IOException { + List tuples = new ArrayList<>(); + try (tupleStream) { + tupleStream.open(); + for (Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) { + tuples.add(t); + } + } + return tuples; + } +} diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetricTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetricTest.java new file mode 100644 index 00000000000..3b3b78d1ac3 --- /dev/null +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/metrics/WeightedSumMetricTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.stream.metrics; + +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import org.apache.solr.handler.SolrDefaultStreamFactory; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class WeightedSumMetricTest { + + final long[] counts = new long[]{10, 20, 30, 40}; + final double[] avg = new double[]{2, 4, 6, 8}; + + @Test + public void testWsumPojo() throws Exception { + WeightedSumMetric wsum = new WeightedSumMetric("avg", "count"); + assertEquals("wsum(avg, count, false)", wsum.getIdentifier()); + assertArrayEquals(new String[]{"avg", "count"}, wsum.getColumns()); + + StreamFactory factory = new SolrDefaultStreamFactory(); + StreamExpressionParameter expr = wsum.toExpression(factory); + assertTrue(expr instanceof StreamExpression); + wsum = new WeightedSumMetric((StreamExpression) expr, factory); + + double expectedSum = 0d; + for (int i = 0; i < counts.length; i++) { + expectedSum += ((double) counts[i] / 100) * avg[i]; + } + long expectedSumLong = Math.round(expectedSum); + + Number weightedSum = updateMetric(wsum); + assertNotNull(weightedSum); + assertTrue(weightedSum instanceof Double); + assertTrue(weightedSum.doubleValue() == expectedSum); + + wsum = new WeightedSumMetric("avg", "count", true); + assertEquals("wsum(avg, count, true)", wsum.getIdentifier()); + weightedSum = updateMetric(wsum); + assertNotNull(weightedSum); + assertTrue(weightedSum.longValue() == expectedSumLong); + } + + private Number updateMetric(WeightedSumMetric wsum) { + for (int i = 0; i < counts.length; i++) { + Tuple t = new Tuple(); + t.put("avg", avg[i]); + t.put("count", counts[i]); + wsum.update(t); + } + return wsum.getValue(); + } +}