Add early termination support to BucketCollector (#33279)

This commit adds the support to early terminate the collection of a leaf
in the aggregation framework. This change introduces a MultiBucketCollector which
handles CollectionTerminatedException exactly like the Lucene MultiCollector.
Any aggregator can now throw a CollectionTerminatedException without stopping
the collection of a sibling aggregator. This is useful for aggregators that
can infer their result without visiting all documents (e.g.: a min/max aggregation on a match_all query).
This commit is contained in:
Jim Ferenczi 2018-09-03 09:34:35 +02:00 committed by GitHub
parent 3c367a2c46
commit 713c07e14d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 483 additions and 68 deletions

View File

@ -596,7 +596,7 @@ And the response:
] ]
}, },
{ {
"name": "BucketCollector: [[my_scoped_agg, my_global_agg]]", "name": "MultiBucketCollector: [[my_scoped_agg, my_global_agg]]",
"reason": "aggregation", "reason": "aggregation",
"time_in_nanos": 8273 "time_in_nanos": 8273
} }

View File

@ -60,7 +60,7 @@ public class AggregationPhase implements SearchPhase {
} }
context.aggregations().aggregators(aggregators); context.aggregations().aggregators(aggregators);
if (!collectors.isEmpty()) { if (!collectors.isEmpty()) {
Collector collector = BucketCollector.wrap(collectors); Collector collector = MultiBucketCollector.wrap(collectors);
((BucketCollector)collector).preCollection(); ((BucketCollector)collector).preCollection();
if (context.getProfilers() != null) { if (context.getProfilers() != null) {
collector = new InternalProfileCollector(collector, CollectorResult.REASON_AGGREGATION, collector = new InternalProfileCollector(collector, CollectorResult.REASON_AGGREGATION,
@ -97,7 +97,7 @@ public class AggregationPhase implements SearchPhase {
// optimize the global collector based execution // optimize the global collector based execution
if (!globals.isEmpty()) { if (!globals.isEmpty()) {
BucketCollector globalsCollector = BucketCollector.wrap(globals); BucketCollector globalsCollector = MultiBucketCollector.wrap(globals);
Query query = context.buildFilteredQuery(Queries.newMatchAllQuery()); Query query = context.buildFilteredQuery(Queries.newMatchAllQuery());
try { try {

View File

@ -183,7 +183,7 @@ public abstract class AggregatorBase extends Aggregator {
@Override @Override
public final void preCollection() throws IOException { public final void preCollection() throws IOException {
List<BucketCollector> collectors = Arrays.asList(subAggregators); List<BucketCollector> collectors = Arrays.asList(subAggregators);
collectableSubAggregators = BucketCollector.wrap(collectors); collectableSubAggregators = MultiBucketCollector.wrap(collectors);
doPreCollection(); doPreCollection();
collectableSubAggregators.preCollection(); collectableSubAggregators.preCollection();
} }

View File

@ -24,10 +24,6 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector; import org.apache.lucene.search.Collector;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.StreamSupport;
/** /**
* A Collector that can collect data in separate buckets. * A Collector that can collect data in separate buckets.
@ -54,61 +50,6 @@ public abstract class BucketCollector implements Collector {
} }
}; };
/**
* Wrap the given collectors into a single instance.
*/
public static BucketCollector wrap(Iterable<? extends BucketCollector> collectorList) {
final BucketCollector[] collectors =
StreamSupport.stream(collectorList.spliterator(), false).toArray(size -> new BucketCollector[size]);
switch (collectors.length) {
case 0:
return NO_OP_COLLECTOR;
case 1:
return collectors[0];
default:
return new BucketCollector() {
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
List<LeafBucketCollector> leafCollectors = new ArrayList<>(collectors.length);
for (BucketCollector c : collectors) {
leafCollectors.add(c.getLeafCollector(ctx));
}
return LeafBucketCollector.wrap(leafCollectors);
}
@Override
public void preCollection() throws IOException {
for (BucketCollector collector : collectors) {
collector.preCollection();
}
}
@Override
public void postCollection() throws IOException {
for (BucketCollector collector : collectors) {
collector.postCollection();
}
}
@Override
public boolean needsScores() {
for (BucketCollector collector : collectors) {
if (collector.needsScores()) {
return true;
}
}
return false;
}
@Override
public String toString() {
return Arrays.toString(collectors);
}
};
}
}
@Override @Override
public abstract LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException; public abstract LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException;

View File

@ -0,0 +1,207 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.ScoreCachingWrappingScorer;
import org.apache.lucene.search.Scorer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* A {@link BucketCollector} which allows running a bucket collection with several
* {@link BucketCollector}s. It is similar to the {@link MultiCollector} except that the
* {@link #wrap} method filters out the {@link BucketCollector#NO_OP_COLLECTOR}s and not
* the null ones.
*/
public class MultiBucketCollector extends BucketCollector {
/** See {@link #wrap(Iterable)}. */
public static BucketCollector wrap(BucketCollector... collectors) {
return wrap(Arrays.asList(collectors));
}
/**
* Wraps a list of {@link BucketCollector}s with a {@link MultiBucketCollector}. This
* method works as follows:
* <ul>
* <li>Filters out the {@link BucketCollector#NO_OP_COLLECTOR}s collectors, so they are not used
* during search time.
* <li>If the input contains 1 real collector, it is returned.
* <li>Otherwise the method returns a {@link MultiBucketCollector} which wraps the
* non-{@link BucketCollector#NO_OP_COLLECTOR} collectors.
* </ul>
*/
public static BucketCollector wrap(Iterable<? extends BucketCollector> collectors) {
// For the user's convenience, we allow NO_OP collectors to be passed.
// However, to improve performance, these null collectors are found
// and dropped from the array we save for actual collection time.
int n = 0;
for (BucketCollector c : collectors) {
if (c != NO_OP_COLLECTOR) {
n++;
}
}
if (n == 0) {
return NO_OP_COLLECTOR;
} else if (n == 1) {
// only 1 Collector - return it.
BucketCollector col = null;
for (BucketCollector c : collectors) {
if (c != null) {
col = c;
break;
}
}
return col;
} else {
BucketCollector[] colls = new BucketCollector[n];
n = 0;
for (BucketCollector c : collectors) {
if (c != null) {
colls[n++] = c;
}
}
return new MultiBucketCollector(colls);
}
}
private final boolean cacheScores;
private final BucketCollector[] collectors;
private MultiBucketCollector(BucketCollector... collectors) {
this.collectors = collectors;
int numNeedsScores = 0;
for (BucketCollector collector : collectors) {
if (collector.needsScores()) {
numNeedsScores += 1;
}
}
this.cacheScores = numNeedsScores >= 2;
}
@Override
public void preCollection() throws IOException {
for (BucketCollector collector : collectors) {
collector.preCollection();
}
}
@Override
public void postCollection() throws IOException {
for (BucketCollector collector : collectors) {
collector.postCollection();
}
}
@Override
public boolean needsScores() {
for (BucketCollector collector : collectors) {
if (collector.needsScores()) {
return true;
}
}
return false;
}
@Override
public String toString() {
return Arrays.toString(collectors);
}
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException {
final List<LeafBucketCollector> leafCollectors = new ArrayList<>();
for (BucketCollector collector : collectors) {
final LeafBucketCollector leafCollector;
try {
leafCollector = collector.getLeafCollector(context);
} catch (CollectionTerminatedException e) {
// this leaf collector does not need this segment
continue;
}
leafCollectors.add(leafCollector);
}
switch (leafCollectors.size()) {
case 0:
throw new CollectionTerminatedException();
case 1:
return leafCollectors.get(0);
default:
return new MultiLeafBucketCollector(leafCollectors, cacheScores);
}
}
private static class MultiLeafBucketCollector extends LeafBucketCollector {
private final boolean cacheScores;
private final LeafBucketCollector[] collectors;
private int numCollectors;
private MultiLeafBucketCollector(List<LeafBucketCollector> collectors, boolean cacheScores) {
this.collectors = collectors.toArray(new LeafBucketCollector[collectors.size()]);
this.cacheScores = cacheScores;
this.numCollectors = this.collectors.length;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
if (cacheScores) {
scorer = new ScoreCachingWrappingScorer(scorer);
}
for (int i = 0; i < numCollectors; ++i) {
final LeafCollector c = collectors[i];
c.setScorer(scorer);
}
}
private void removeCollector(int i) {
System.arraycopy(collectors, i + 1, collectors, i, numCollectors - i - 1);
--numCollectors;
collectors[numCollectors] = null;
}
@Override
public void collect(int doc, long bucket) throws IOException {
final LeafBucketCollector[] collectors = this.collectors;
int numCollectors = this.numCollectors;
for (int i = 0; i < numCollectors; ) {
final LeafBucketCollector collector = collectors[i];
try {
collector.collect(doc, bucket);
++i;
} catch (CollectionTerminatedException e) {
removeCollector(i);
numCollectors = this.numCollectors;
if (numCollectors == 0) {
throw new CollectionTerminatedException();
}
}
}
}
}
}

View File

@ -33,6 +33,7 @@ import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.BucketCollector; import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException; import java.io.IOException;
@ -90,7 +91,7 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
/** Set the deferred collectors. */ /** Set the deferred collectors. */
@Override @Override
public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) { public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) {
this.collector = BucketCollector.wrap(deferredCollectors); this.collector = MultiBucketCollector.wrap(deferredCollectors);
} }
private void finishLeaf() { private void finishLeaf() {

View File

@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.bucket;
import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.BucketCollector; import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator; import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
@ -59,7 +60,7 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
recordingWrapper.setDeferredCollector(deferredCollectors); recordingWrapper.setDeferredCollector(deferredCollectors);
collectors.add(recordingWrapper); collectors.add(recordingWrapper);
} }
collectableSubAggregators = BucketCollector.wrap(collectors); collectableSubAggregators = MultiBucketCollector.wrap(collectors);
} }
public static boolean descendsFromGlobalAggregator(Aggregator parent) { public static boolean descendsFromGlobalAggregator(Aggregator parent) {

View File

@ -31,6 +31,7 @@ import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.BucketCollector; import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException; import java.io.IOException;
@ -61,7 +62,7 @@ public class MergingBucketsDeferringCollector extends DeferringBucketCollector {
@Override @Override
public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) { public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) {
this.collector = BucketCollector.wrap(deferredCollectors); this.collector = MultiBucketCollector.wrap(deferredCollectors);
} }
@Override @Override

View File

@ -38,6 +38,7 @@ import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator; import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource;
@ -93,7 +94,7 @@ final class CompositeAggregator extends BucketsAggregator {
@Override @Override
protected void doPreCollection() throws IOException { protected void doPreCollection() throws IOException {
List<BucketCollector> collectors = Arrays.asList(subAggregators); List<BucketCollector> collectors = Arrays.asList(subAggregators);
deferredCollectors = BucketCollector.wrap(collectors); deferredCollectors = MultiBucketCollector.wrap(collectors);
collectableSubAggregators = BucketCollector.NO_OP_COLLECTOR; collectableSubAggregators = BucketCollector.NO_OP_COLLECTOR;
} }

View File

@ -33,6 +33,7 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.search.aggregations.BucketCollector; import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector; import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;
import java.io.IOException; import java.io.IOException;
@ -76,7 +77,7 @@ public class BestDocsDeferringCollector extends DeferringBucketCollector impleme
/** Set the deferred collectors. */ /** Set the deferred collectors. */
@Override @Override
public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) { public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) {
this.deferred = BucketCollector.wrap(deferredCollectors); this.deferred = MultiBucketCollector.wrap(deferredCollectors);
} }
@Override @Override

View File

@ -0,0 +1,262 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
public class MultiBucketCollectorTests extends ESTestCase {
private static class FakeScorer extends Scorer {
float score;
int doc = -1;
FakeScorer() {
super(null);
}
@Override
public int docID() {
return doc;
}
@Override
public float score() {
return score;
}
@Override
public DocIdSetIterator iterator() {
throw new UnsupportedOperationException();
}
@Override
public Weight getWeight() {
throw new UnsupportedOperationException();
}
@Override
public Collection<ChildScorer> getChildren() {
throw new UnsupportedOperationException();
}
}
private static class TerminateAfterBucketCollector extends BucketCollector {
private int count = 0;
private final int terminateAfter;
private final BucketCollector in;
TerminateAfterBucketCollector(BucketCollector in, int terminateAfter) {
this.in = in;
this.terminateAfter = terminateAfter;
}
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException {
if (count >= terminateAfter) {
throw new CollectionTerminatedException();
}
final LeafBucketCollector leafCollector = in.getLeafCollector(context);
return new LeafBucketCollectorBase(leafCollector, null) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (count >= terminateAfter) {
throw new CollectionTerminatedException();
}
super.collect(doc, bucket);
count++;
}
};
}
@Override
public boolean needsScores() {
return false;
}
@Override
public void preCollection() {}
@Override
public void postCollection() {}
}
private static class TotalHitCountBucketCollector extends BucketCollector {
private int count = 0;
TotalHitCountBucketCollector() {
}
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext context) {
return new LeafBucketCollector() {
@Override
public void collect(int doc, long bucket) throws IOException {
count++;
}
};
}
@Override
public boolean needsScores() {
return false;
}
@Override
public void preCollection() {}
@Override
public void postCollection() {}
int getTotalHits() {
return count;
}
}
private static class SetScorerBucketCollector extends BucketCollector {
private final BucketCollector in;
private final AtomicBoolean setScorerCalled;
SetScorerBucketCollector(BucketCollector in, AtomicBoolean setScorerCalled) {
this.in = in;
this.setScorerCalled = setScorerCalled;
}
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException {
final LeafBucketCollector leafCollector = in.getLeafCollector(context);
return new LeafBucketCollectorBase(leafCollector, null) {
@Override
public void setScorer(Scorer scorer) throws IOException {
super.setScorer(scorer);
setScorerCalled.set(true);
}
};
}
@Override
public boolean needsScores() {
return false;
}
@Override
public void preCollection() {}
@Override
public void postCollection() {}
}
public void testCollectionTerminatedExceptionHandling() throws IOException {
final int iters = atLeast(3);
for (int iter = 0; iter < iters; ++iter) {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
final int numDocs = randomIntBetween(100, 1000);
final Document doc = new Document();
for (int i = 0; i < numDocs; ++i) {
w.addDocument(doc);
}
final IndexReader reader = w.getReader();
w.close();
final IndexSearcher searcher = newSearcher(reader);
Map<TotalHitCountBucketCollector, Integer> expectedCounts = new HashMap<>();
List<BucketCollector> collectors = new ArrayList<>();
final int numCollectors = randomIntBetween(1, 5);
for (int i = 0; i < numCollectors; ++i) {
final int terminateAfter = random().nextInt(numDocs + 10);
final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter;
TotalHitCountBucketCollector collector = new TotalHitCountBucketCollector();
expectedCounts.put(collector, expectedCount);
collectors.add(new TerminateAfterBucketCollector(collector, terminateAfter));
}
searcher.search(new MatchAllDocsQuery(), MultiBucketCollector.wrap(collectors));
for (Map.Entry<TotalHitCountBucketCollector, Integer> expectedCount : expectedCounts.entrySet()) {
assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits());
}
reader.close();
dir.close();
}
}
public void testSetScorerAfterCollectionTerminated() throws IOException {
BucketCollector collector1 = new TotalHitCountBucketCollector();
BucketCollector collector2 = new TotalHitCountBucketCollector();
AtomicBoolean setScorerCalled1 = new AtomicBoolean();
collector1 = new SetScorerBucketCollector(collector1, setScorerCalled1);
AtomicBoolean setScorerCalled2 = new AtomicBoolean();
collector2 = new SetScorerBucketCollector(collector2, setScorerCalled2);
collector1 = new TerminateAfterBucketCollector(collector1, 1);
collector2 = new TerminateAfterBucketCollector(collector2, 2);
Scorer scorer = new FakeScorer();
List<BucketCollector> collectors = Arrays.asList(collector1, collector2);
Collections.shuffle(collectors, random());
BucketCollector collector = MultiBucketCollector.wrap(collectors);
LeafBucketCollector leafCollector = collector.getLeafCollector(null);
leafCollector.setScorer(scorer);
assertTrue(setScorerCalled1.get());
assertTrue(setScorerCalled2.get());
leafCollector.collect(0);
leafCollector.collect(1);
setScorerCalled1.set(false);
setScorerCalled2.set(false);
leafCollector.setScorer(scorer);
assertFalse(setScorerCalled1.get());
assertTrue(setScorerCalled2.get());
expectThrows(CollectionTerminatedException.class, () -> {
leafCollector.collect(1);
});
setScorerCalled1.set(false);
setScorerCalled2.set(false);
leafCollector.setScorer(scorer);
assertFalse(setScorerCalled1.get());
assertFalse(setScorerCalled2.get());
}
}