Fix NPE on composite aggregation with sub-aggregations that need scores (#28129)
The composite aggregation defers the collection of sub-aggregations to a second pass that visits only the documents that fall into the top buckets. However, the scorer for sub-aggregations was not set on this second pass, causing an NPE whenever a sub-aggregation tried to access the score. This change creates a scorer for the second pass and ensures that sub-aggregations can safely use it to read the score of the collected documents.
This commit is contained in:
parent
ee7eac8dc1
commit
bd11e6c441
|
@ -23,6 +23,9 @@ import org.apache.lucene.index.LeafReaderContext;
|
|||
import org.apache.lucene.search.CollectionTerminatedException;
|
||||
import org.apache.lucene.search.DocIdSet;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.util.RoaringDocIdSet;
|
||||
import org.elasticsearch.search.aggregations.Aggregator;
|
||||
import org.elasticsearch.search.aggregations.AggregatorFactories;
|
||||
|
@ -87,6 +90,12 @@ final class CompositeAggregator extends BucketsAggregator {
|
|||
|
||||
// Replay all documents that contain at least one top bucket (collected during the first pass).
|
||||
grow(keys.size()+1);
|
||||
final boolean needsScores = needsScores();
|
||||
Weight weight = null;
|
||||
if (needsScores) {
|
||||
Query query = context.query();
|
||||
weight = context.searcher().createNormalizedWeight(query, true);
|
||||
}
|
||||
for (LeafContext context : contexts) {
|
||||
DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
|
||||
if (docIdSetIterator == null) {
|
||||
|
@ -95,7 +104,21 @@ final class CompositeAggregator extends BucketsAggregator {
|
|||
final CompositeValuesSource.Collector collector =
|
||||
array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
|
||||
int docID;
|
||||
DocIdSetIterator scorerIt = null;
|
||||
if (needsScores) {
|
||||
Scorer scorer = weight.scorer(context.ctx);
|
||||
// We don't need to check if the scorer is null
|
||||
// since we are sure that there are documents to replay (docIdSetIterator is not empty).
|
||||
scorerIt = scorer.iterator();
|
||||
context.subCollector.setScorer(scorer);
|
||||
}
|
||||
while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
if (needsScores) {
|
||||
assert scorerIt.docID() < docID;
|
||||
scorerIt.advance(docID);
|
||||
// aggregations should only be replayed on matching documents
|
||||
assert scorerIt.docID() == docID;
|
||||
}
|
||||
collector.collect(docID);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,8 @@ import org.elasticsearch.index.mapper.Mapper;
|
|||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
||||
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
|
||||
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
|
||||
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.test.IndexSettingsModule;
|
||||
import org.joda.time.DateTimeZone;
|
||||
|
@ -1065,8 +1067,73 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
private void testSearchCase(Query query,
|
||||
Sort sort,
|
||||
public void testWithKeywordAndTopHits() throws Exception {
|
||||
final List<Map<String, List<Object>>> dataset = new ArrayList<>();
|
||||
dataset.addAll(
|
||||
Arrays.asList(
|
||||
createDocument("keyword", "a"),
|
||||
createDocument("keyword", "c"),
|
||||
createDocument("keyword", "a"),
|
||||
createDocument("keyword", "d"),
|
||||
createDocument("keyword", "c")
|
||||
)
|
||||
);
|
||||
final Sort sort = new Sort(new SortedSetSortField("keyword", false));
|
||||
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
|
||||
() -> {
|
||||
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
|
||||
.field("keyword");
|
||||
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
|
||||
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
|
||||
}, (result) -> {
|
||||
assertEquals(3, result.getBuckets().size());
|
||||
assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString());
|
||||
assertEquals(2L, result.getBuckets().get(0).getDocCount());
|
||||
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
|
||||
assertNotNull(topHits);
|
||||
assertEquals(topHits.getHits().getHits().length, 2);
|
||||
assertEquals(topHits.getHits().getTotalHits(), 2L);
|
||||
assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString());
|
||||
assertEquals(2L, result.getBuckets().get(1).getDocCount());
|
||||
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
|
||||
assertNotNull(topHits);
|
||||
assertEquals(topHits.getHits().getHits().length, 2);
|
||||
assertEquals(topHits.getHits().getTotalHits(), 2L);
|
||||
assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString());
|
||||
assertEquals(1L, result.getBuckets().get(2).getDocCount());
|
||||
topHits = result.getBuckets().get(2).getAggregations().get("top_hits");
|
||||
assertNotNull(topHits);
|
||||
assertEquals(topHits.getHits().getHits().length, 1);
|
||||
assertEquals(topHits.getHits().getTotalHits(), 1L);;
|
||||
}
|
||||
);
|
||||
|
||||
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
|
||||
() -> {
|
||||
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
|
||||
.field("keyword");
|
||||
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
|
||||
.aggregateAfter(Collections.singletonMap("keyword", "a"))
|
||||
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
|
||||
}, (result) -> {
|
||||
assertEquals(2, result.getBuckets().size());
|
||||
assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString());
|
||||
assertEquals(2L, result.getBuckets().get(0).getDocCount());
|
||||
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
|
||||
assertNotNull(topHits);
|
||||
assertEquals(topHits.getHits().getHits().length, 2);
|
||||
assertEquals(topHits.getHits().getTotalHits(), 2L);
|
||||
assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString());
|
||||
assertEquals(1L, result.getBuckets().get(1).getDocCount());
|
||||
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
|
||||
assertNotNull(topHits);
|
||||
assertEquals(topHits.getHits().getHits().length, 1);
|
||||
assertEquals(topHits.getHits().getTotalHits(), 1L);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
private void testSearchCase(Query query, Sort sort,
|
||||
List<Map<String, List<Object>>> dataset,
|
||||
Supplier<CompositeAggregationBuilder> create,
|
||||
Consumer<InternalComposite> verify) throws IOException {
|
||||
|
@ -1107,7 +1174,7 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
|
|||
IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
|
||||
CompositeAggregationBuilder aggregationBuilder = create.get();
|
||||
if (sort != null) {
|
||||
CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
|
||||
CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
|
||||
assertTrue(aggregator.canEarlyTerminate());
|
||||
}
|
||||
final InternalComposite composite;
|
||||
|
|
|
@ -103,16 +103,27 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
|||
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
||||
}
|
||||
|
||||
/** Create a factory for the given aggregation builder. */
|
||||
|
||||
protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
|
||||
IndexSearcher indexSearcher,
|
||||
IndexSettings indexSettings,
|
||||
MultiBucketConsumer bucketConsumer,
|
||||
MappedFieldType... fieldTypes) throws IOException {
|
||||
return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
|
||||
}
|
||||
|
||||
/** Create a factory for the given aggregation builder. */
|
||||
protected AggregatorFactory<?> createAggregatorFactory(Query query,
|
||||
AggregationBuilder aggregationBuilder,
|
||||
IndexSearcher indexSearcher,
|
||||
IndexSettings indexSettings,
|
||||
MultiBucketConsumer bucketConsumer,
|
||||
MappedFieldType... fieldTypes) throws IOException {
|
||||
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
|
||||
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
|
||||
when(searchContext.aggregations())
|
||||
.thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
|
||||
when(searchContext.query()).thenReturn(query);
|
||||
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
|
||||
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
|
||||
MapperService mapperService = mapperServiceMock();
|
||||
|
@ -146,28 +157,38 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
|||
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
||||
}
|
||||
|
||||
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
|
||||
protected <A extends Aggregator> A createAggregator(Query query,
|
||||
AggregationBuilder aggregationBuilder,
|
||||
IndexSearcher indexSearcher,
|
||||
IndexSettings indexSettings,
|
||||
MappedFieldType... fieldTypes) throws IOException {
|
||||
return createAggregator(aggregationBuilder, indexSearcher, indexSettings,
|
||||
return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings,
|
||||
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
||||
}
|
||||
|
||||
protected <A extends Aggregator> A createAggregator(Query query, AggregationBuilder aggregationBuilder,
|
||||
IndexSearcher indexSearcher,
|
||||
MultiBucketConsumer bucketConsumer,
|
||||
MappedFieldType... fieldTypes) throws IOException {
|
||||
return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
|
||||
}
|
||||
|
||||
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
|
||||
IndexSearcher indexSearcher,
|
||||
IndexSettings indexSettings,
|
||||
MultiBucketConsumer bucketConsumer,
|
||||
MappedFieldType... fieldTypes) throws IOException {
|
||||
return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
|
||||
return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
|
||||
}
|
||||
|
||||
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
|
||||
protected <A extends Aggregator> A createAggregator(Query query,
|
||||
AggregationBuilder aggregationBuilder,
|
||||
IndexSearcher indexSearcher,
|
||||
IndexSettings indexSettings,
|
||||
MultiBucketConsumer bucketConsumer,
|
||||
MappedFieldType... fieldTypes) throws IOException {
|
||||
@SuppressWarnings("unchecked")
|
||||
A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
|
||||
A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
|
||||
.create(null, true);
|
||||
return aggregator;
|
||||
}
|
||||
|
@ -262,7 +283,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
|||
int maxBucket,
|
||||
MappedFieldType... fieldTypes) throws IOException {
|
||||
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
|
||||
C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
|
||||
C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
|
||||
a.preCollection();
|
||||
searcher.search(query, a);
|
||||
a.postCollection();
|
||||
|
@ -310,11 +331,11 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
|||
Query rewritten = searcher.rewrite(query);
|
||||
Weight weight = searcher.createWeight(rewritten, true, 1f);
|
||||
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
|
||||
C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
|
||||
C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
|
||||
|
||||
for (ShardSearcher subSearcher : subSearchers) {
|
||||
MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
|
||||
C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes);
|
||||
C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes);
|
||||
a.preCollection();
|
||||
subSearcher.search(weight, a);
|
||||
a.postCollection();
|
||||
|
|
Loading…
Reference in New Issue