Fix NPE on composite aggregation with sub-aggregations that need scores (#28129)

The composite aggregation defers the collection of sub-aggregations to a second pass that visits documents only if they
appear in the top buckets. Though the scorer for sub-aggregations is not set on this second pass and generates an NPE if any sub-aggregation
tries to access the score. This change creates a scorer for the second pass and makes sure that sub-aggs can use it safely to check the score of
the collected documents.
This commit is contained in:
Jim Ferenczi 2018-01-15 18:30:38 +01:00 committed by GitHub
parent ee7eac8dc1
commit bd11e6c441
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 123 additions and 12 deletions

View File

@ -23,6 +23,9 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSet; import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.RoaringDocIdSet; import org.apache.lucene.util.RoaringDocIdSet;
import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactories;
@ -87,6 +90,12 @@ final class CompositeAggregator extends BucketsAggregator {
// Replay all documents that contain at least one top bucket (collected during the first pass). // Replay all documents that contain at least one top bucket (collected during the first pass).
grow(keys.size()+1); grow(keys.size()+1);
final boolean needsScores = needsScores();
Weight weight = null;
if (needsScores) {
Query query = context.query();
weight = context.searcher().createNormalizedWeight(query, true);
}
for (LeafContext context : contexts) { for (LeafContext context : contexts) {
DocIdSetIterator docIdSetIterator = context.docIdSet.iterator(); DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
if (docIdSetIterator == null) { if (docIdSetIterator == null) {
@ -95,7 +104,21 @@ final class CompositeAggregator extends BucketsAggregator {
final CompositeValuesSource.Collector collector = final CompositeValuesSource.Collector collector =
array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector)); array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
int docID; int docID;
DocIdSetIterator scorerIt = null;
if (needsScores) {
Scorer scorer = weight.scorer(context.ctx);
// We don't need to check if the scorer is null
// since we are sure that there are documents to replay (docIdSetIterator it not empty).
scorerIt = scorer.iterator();
context.subCollector.setScorer(scorer);
}
while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (needsScores) {
assert scorerIt.docID() < docID;
scorerIt.advance(docID);
// aggregations should only be replayed on matching documents
assert scorerIt.docID() == docID;
}
collector.collect(docID); collector.collect(docID);
} }
} }

View File

@ -50,6 +50,8 @@ import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.IndexSettingsModule; import org.elasticsearch.test.IndexSettingsModule;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
@ -1065,8 +1067,73 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
); );
} }
private void testSearchCase(Query query, public void testWithKeywordAndTopHits() throws Exception {
Sort sort, final List<Map<String, List<Object>>> dataset = new ArrayList<>();
dataset.addAll(
Arrays.asList(
createDocument("keyword", "a"),
createDocument("keyword", "c"),
createDocument("keyword", "a"),
createDocument("keyword", "d"),
createDocument("keyword", "c")
)
);
final Sort sort = new Sort(new SortedSetSortField("keyword", false));
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
() -> {
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
.field("keyword");
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
}, (result) -> {
assertEquals(3, result.getBuckets().size());
assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString());
assertEquals(2L, result.getBuckets().get(0).getDocCount());
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 2);
assertEquals(topHits.getHits().getTotalHits(), 2L);
assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString());
assertEquals(2L, result.getBuckets().get(1).getDocCount());
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 2);
assertEquals(topHits.getHits().getTotalHits(), 2L);
assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString());
assertEquals(1L, result.getBuckets().get(2).getDocCount());
topHits = result.getBuckets().get(2).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 1);
assertEquals(topHits.getHits().getTotalHits(), 1L);;
}
);
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
() -> {
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
.field("keyword");
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
.aggregateAfter(Collections.singletonMap("keyword", "a"))
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
}, (result) -> {
assertEquals(2, result.getBuckets().size());
assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString());
assertEquals(2L, result.getBuckets().get(0).getDocCount());
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 2);
assertEquals(topHits.getHits().getTotalHits(), 2L);
assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString());
assertEquals(1L, result.getBuckets().get(1).getDocCount());
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 1);
assertEquals(topHits.getHits().getTotalHits(), 1L);
}
);
}
private void testSearchCase(Query query, Sort sort,
List<Map<String, List<Object>>> dataset, List<Map<String, List<Object>>> dataset,
Supplier<CompositeAggregationBuilder> create, Supplier<CompositeAggregationBuilder> create,
Consumer<InternalComposite> verify) throws IOException { Consumer<InternalComposite> verify) throws IOException {
@ -1107,7 +1174,7 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null); IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
CompositeAggregationBuilder aggregationBuilder = create.get(); CompositeAggregationBuilder aggregationBuilder = create.get();
if (sort != null) { if (sort != null) {
CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES); CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
assertTrue(aggregator.canEarlyTerminate()); assertTrue(aggregator.canEarlyTerminate());
} }
final InternalComposite composite; final InternalComposite composite;

View File

@ -103,16 +103,27 @@ public abstract class AggregatorTestCase extends ESTestCase {
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes); new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
} }
/** Create a factory for the given aggregation builder. */
protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder, protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher, IndexSearcher indexSearcher,
IndexSettings indexSettings, IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer, MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException { MappedFieldType... fieldTypes) throws IOException {
return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
}
/** Create a factory for the given aggregation builder. */
protected AggregatorFactory<?> createAggregatorFactory(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings); SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService(); CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
when(searchContext.aggregations()) when(searchContext.aggregations())
.thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer)); .thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
when(searchContext.query()).thenReturn(query);
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService)); when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests: // TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
MapperService mapperService = mapperServiceMock(); MapperService mapperService = mapperServiceMock();
@ -146,28 +157,38 @@ public abstract class AggregatorTestCase extends ESTestCase {
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes); new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
} }
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder, protected <A extends Aggregator> A createAggregator(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher, IndexSearcher indexSearcher,
IndexSettings indexSettings, IndexSettings indexSettings,
MappedFieldType... fieldTypes) throws IOException { MappedFieldType... fieldTypes) throws IOException {
return createAggregator(aggregationBuilder, indexSearcher, indexSettings, return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings,
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes); new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
} }
protected <A extends Aggregator> A createAggregator(Query query, AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
}
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder, protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher, IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer, MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException { MappedFieldType... fieldTypes) throws IOException {
return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes); return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
} }
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder, protected <A extends Aggregator> A createAggregator(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher, IndexSearcher indexSearcher,
IndexSettings indexSettings, IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer, MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException { MappedFieldType... fieldTypes) throws IOException {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes) A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
.create(null, true); .create(null, true);
return aggregator; return aggregator;
} }
@ -262,7 +283,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
int maxBucket, int maxBucket,
MappedFieldType... fieldTypes) throws IOException { MappedFieldType... fieldTypes) throws IOException {
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket); MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes); C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
a.preCollection(); a.preCollection();
searcher.search(query, a); searcher.search(query, a);
a.postCollection(); a.postCollection();
@ -310,11 +331,11 @@ public abstract class AggregatorTestCase extends ESTestCase {
Query rewritten = searcher.rewrite(query); Query rewritten = searcher.rewrite(query);
Weight weight = searcher.createWeight(rewritten, true, 1f); Weight weight = searcher.createWeight(rewritten, true, 1f);
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket); MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes); C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
for (ShardSearcher subSearcher : subSearchers) { for (ShardSearcher subSearcher : subSearchers) {
MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket); MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes); C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes);
a.preCollection(); a.preCollection();
subSearcher.search(weight, a); subSearcher.search(weight, a);
a.postCollection(); a.postCollection();