Fix NPE on composite aggregation with sub-aggregations that need scores (#28129)
The composite aggregation defers the collection of sub-aggregations to a second pass that visits documents only if they appear in the top buckets. However, the scorer for sub-aggregations is not set on this second pass, which causes an NPE if any sub-aggregation tries to access the score. This change creates a scorer for the second pass and makes sure that sub-aggregations can use it safely to check the score of the collected documents.
This commit is contained in:
parent
ee7eac8dc1
commit
bd11e6c441
|
@ -23,6 +23,9 @@ import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.search.CollectionTerminatedException;
|
import org.apache.lucene.search.CollectionTerminatedException;
|
||||||
import org.apache.lucene.search.DocIdSet;
|
import org.apache.lucene.search.DocIdSet;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.lucene.search.Scorer;
|
||||||
|
import org.apache.lucene.search.Weight;
|
||||||
import org.apache.lucene.util.RoaringDocIdSet;
|
import org.apache.lucene.util.RoaringDocIdSet;
|
||||||
import org.elasticsearch.search.aggregations.Aggregator;
|
import org.elasticsearch.search.aggregations.Aggregator;
|
||||||
import org.elasticsearch.search.aggregations.AggregatorFactories;
|
import org.elasticsearch.search.aggregations.AggregatorFactories;
|
||||||
|
@ -87,6 +90,12 @@ final class CompositeAggregator extends BucketsAggregator {
|
||||||
|
|
||||||
// Replay all documents that contain at least one top bucket (collected during the first pass).
|
// Replay all documents that contain at least one top bucket (collected during the first pass).
|
||||||
grow(keys.size()+1);
|
grow(keys.size()+1);
|
||||||
|
final boolean needsScores = needsScores();
|
||||||
|
Weight weight = null;
|
||||||
|
if (needsScores) {
|
||||||
|
Query query = context.query();
|
||||||
|
weight = context.searcher().createNormalizedWeight(query, true);
|
||||||
|
}
|
||||||
for (LeafContext context : contexts) {
|
for (LeafContext context : contexts) {
|
||||||
DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
|
DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
|
||||||
if (docIdSetIterator == null) {
|
if (docIdSetIterator == null) {
|
||||||
|
@ -95,7 +104,21 @@ final class CompositeAggregator extends BucketsAggregator {
|
||||||
final CompositeValuesSource.Collector collector =
|
final CompositeValuesSource.Collector collector =
|
||||||
array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
|
array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
|
||||||
int docID;
|
int docID;
|
||||||
|
DocIdSetIterator scorerIt = null;
|
||||||
|
if (needsScores) {
|
||||||
|
Scorer scorer = weight.scorer(context.ctx);
|
||||||
|
// We don't need to check if the scorer is null
|
||||||
|
// since we are sure that there are documents to replay (docIdSetIterator is not empty).
|
||||||
|
scorerIt = scorer.iterator();
|
||||||
|
context.subCollector.setScorer(scorer);
|
||||||
|
}
|
||||||
while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||||
|
if (needsScores) {
|
||||||
|
assert scorerIt.docID() < docID;
|
||||||
|
scorerIt.advance(docID);
|
||||||
|
// aggregations should only be replayed on matching documents
|
||||||
|
assert scorerIt.docID() == docID;
|
||||||
|
}
|
||||||
collector.collect(docID);
|
collector.collect(docID);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,6 +50,8 @@ import org.elasticsearch.index.mapper.Mapper;
|
||||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||||
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
||||||
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
|
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
|
||||||
import org.elasticsearch.search.sort.SortOrder;
|
import org.elasticsearch.search.sort.SortOrder;
|
||||||
import org.elasticsearch.test.IndexSettingsModule;
|
import org.elasticsearch.test.IndexSettingsModule;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
|
@ -1065,8 +1067,73 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void testSearchCase(Query query,
|
public void testWithKeywordAndTopHits() throws Exception {
|
||||||
Sort sort,
|
final List<Map<String, List<Object>>> dataset = new ArrayList<>();
|
||||||
|
dataset.addAll(
|
||||||
|
Arrays.asList(
|
||||||
|
createDocument("keyword", "a"),
|
||||||
|
createDocument("keyword", "c"),
|
||||||
|
createDocument("keyword", "a"),
|
||||||
|
createDocument("keyword", "d"),
|
||||||
|
createDocument("keyword", "c")
|
||||||
|
)
|
||||||
|
);
|
||||||
|
final Sort sort = new Sort(new SortedSetSortField("keyword", false));
|
||||||
|
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
|
||||||
|
() -> {
|
||||||
|
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
|
||||||
|
.field("keyword");
|
||||||
|
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
|
||||||
|
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
|
||||||
|
}, (result) -> {
|
||||||
|
assertEquals(3, result.getBuckets().size());
|
||||||
|
assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString());
|
||||||
|
assertEquals(2L, result.getBuckets().get(0).getDocCount());
|
||||||
|
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
|
||||||
|
assertNotNull(topHits);
|
||||||
|
assertEquals(topHits.getHits().getHits().length, 2);
|
||||||
|
assertEquals(topHits.getHits().getTotalHits(), 2L);
|
||||||
|
assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString());
|
||||||
|
assertEquals(2L, result.getBuckets().get(1).getDocCount());
|
||||||
|
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
|
||||||
|
assertNotNull(topHits);
|
||||||
|
assertEquals(topHits.getHits().getHits().length, 2);
|
||||||
|
assertEquals(topHits.getHits().getTotalHits(), 2L);
|
||||||
|
assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString());
|
||||||
|
assertEquals(1L, result.getBuckets().get(2).getDocCount());
|
||||||
|
topHits = result.getBuckets().get(2).getAggregations().get("top_hits");
|
||||||
|
assertNotNull(topHits);
|
||||||
|
assertEquals(topHits.getHits().getHits().length, 1);
|
||||||
|
assertEquals(topHits.getHits().getTotalHits(), 1L);;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
|
||||||
|
() -> {
|
||||||
|
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
|
||||||
|
.field("keyword");
|
||||||
|
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
|
||||||
|
.aggregateAfter(Collections.singletonMap("keyword", "a"))
|
||||||
|
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
|
||||||
|
}, (result) -> {
|
||||||
|
assertEquals(2, result.getBuckets().size());
|
||||||
|
assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString());
|
||||||
|
assertEquals(2L, result.getBuckets().get(0).getDocCount());
|
||||||
|
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
|
||||||
|
assertNotNull(topHits);
|
||||||
|
assertEquals(topHits.getHits().getHits().length, 2);
|
||||||
|
assertEquals(topHits.getHits().getTotalHits(), 2L);
|
||||||
|
assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString());
|
||||||
|
assertEquals(1L, result.getBuckets().get(1).getDocCount());
|
||||||
|
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
|
||||||
|
assertNotNull(topHits);
|
||||||
|
assertEquals(topHits.getHits().getHits().length, 1);
|
||||||
|
assertEquals(topHits.getHits().getTotalHits(), 1L);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testSearchCase(Query query, Sort sort,
|
||||||
List<Map<String, List<Object>>> dataset,
|
List<Map<String, List<Object>>> dataset,
|
||||||
Supplier<CompositeAggregationBuilder> create,
|
Supplier<CompositeAggregationBuilder> create,
|
||||||
Consumer<InternalComposite> verify) throws IOException {
|
Consumer<InternalComposite> verify) throws IOException {
|
||||||
|
@ -1107,7 +1174,7 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
|
||||||
IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
|
IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
|
||||||
CompositeAggregationBuilder aggregationBuilder = create.get();
|
CompositeAggregationBuilder aggregationBuilder = create.get();
|
||||||
if (sort != null) {
|
if (sort != null) {
|
||||||
CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
|
CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
|
||||||
assertTrue(aggregator.canEarlyTerminate());
|
assertTrue(aggregator.canEarlyTerminate());
|
||||||
}
|
}
|
||||||
final InternalComposite composite;
|
final InternalComposite composite;
|
||||||
|
|
|
@ -103,16 +103,27 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
||||||
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a factory for the given aggregation builder. */
|
|
||||||
protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
|
protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
|
||||||
IndexSearcher indexSearcher,
|
IndexSearcher indexSearcher,
|
||||||
IndexSettings indexSettings,
|
IndexSettings indexSettings,
|
||||||
MultiBucketConsumer bucketConsumer,
|
MultiBucketConsumer bucketConsumer,
|
||||||
MappedFieldType... fieldTypes) throws IOException {
|
MappedFieldType... fieldTypes) throws IOException {
|
||||||
|
return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Create a factory for the given aggregation builder. */
|
||||||
|
protected AggregatorFactory<?> createAggregatorFactory(Query query,
|
||||||
|
AggregationBuilder aggregationBuilder,
|
||||||
|
IndexSearcher indexSearcher,
|
||||||
|
IndexSettings indexSettings,
|
||||||
|
MultiBucketConsumer bucketConsumer,
|
||||||
|
MappedFieldType... fieldTypes) throws IOException {
|
||||||
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
|
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
|
||||||
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
|
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
|
||||||
when(searchContext.aggregations())
|
when(searchContext.aggregations())
|
||||||
.thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
|
.thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
|
||||||
|
when(searchContext.query()).thenReturn(query);
|
||||||
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
|
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
|
||||||
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
|
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
|
||||||
MapperService mapperService = mapperServiceMock();
|
MapperService mapperService = mapperServiceMock();
|
||||||
|
@ -146,28 +157,38 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
||||||
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
|
protected <A extends Aggregator> A createAggregator(Query query,
|
||||||
|
AggregationBuilder aggregationBuilder,
|
||||||
IndexSearcher indexSearcher,
|
IndexSearcher indexSearcher,
|
||||||
IndexSettings indexSettings,
|
IndexSettings indexSettings,
|
||||||
MappedFieldType... fieldTypes) throws IOException {
|
MappedFieldType... fieldTypes) throws IOException {
|
||||||
return createAggregator(aggregationBuilder, indexSearcher, indexSettings,
|
return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings,
|
||||||
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected <A extends Aggregator> A createAggregator(Query query, AggregationBuilder aggregationBuilder,
|
||||||
|
IndexSearcher indexSearcher,
|
||||||
|
MultiBucketConsumer bucketConsumer,
|
||||||
|
MappedFieldType... fieldTypes) throws IOException {
|
||||||
|
return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
|
||||||
|
}
|
||||||
|
|
||||||
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
|
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
|
||||||
IndexSearcher indexSearcher,
|
IndexSearcher indexSearcher,
|
||||||
|
IndexSettings indexSettings,
|
||||||
MultiBucketConsumer bucketConsumer,
|
MultiBucketConsumer bucketConsumer,
|
||||||
MappedFieldType... fieldTypes) throws IOException {
|
MappedFieldType... fieldTypes) throws IOException {
|
||||||
return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
|
return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
|
protected <A extends Aggregator> A createAggregator(Query query,
|
||||||
|
AggregationBuilder aggregationBuilder,
|
||||||
IndexSearcher indexSearcher,
|
IndexSearcher indexSearcher,
|
||||||
IndexSettings indexSettings,
|
IndexSettings indexSettings,
|
||||||
MultiBucketConsumer bucketConsumer,
|
MultiBucketConsumer bucketConsumer,
|
||||||
MappedFieldType... fieldTypes) throws IOException {
|
MappedFieldType... fieldTypes) throws IOException {
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
|
A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
|
||||||
.create(null, true);
|
.create(null, true);
|
||||||
return aggregator;
|
return aggregator;
|
||||||
}
|
}
|
||||||
|
@ -262,7 +283,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
||||||
int maxBucket,
|
int maxBucket,
|
||||||
MappedFieldType... fieldTypes) throws IOException {
|
MappedFieldType... fieldTypes) throws IOException {
|
||||||
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
|
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
|
||||||
C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
|
C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
|
||||||
a.preCollection();
|
a.preCollection();
|
||||||
searcher.search(query, a);
|
searcher.search(query, a);
|
||||||
a.postCollection();
|
a.postCollection();
|
||||||
|
@ -310,11 +331,11 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
||||||
Query rewritten = searcher.rewrite(query);
|
Query rewritten = searcher.rewrite(query);
|
||||||
Weight weight = searcher.createWeight(rewritten, true, 1f);
|
Weight weight = searcher.createWeight(rewritten, true, 1f);
|
||||||
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
|
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
|
||||||
C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
|
C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
|
||||||
|
|
||||||
for (ShardSearcher subSearcher : subSearchers) {
|
for (ShardSearcher subSearcher : subSearchers) {
|
||||||
MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
|
MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
|
||||||
C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes);
|
C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes);
|
||||||
a.preCollection();
|
a.preCollection();
|
||||||
subSearcher.search(weight, a);
|
subSearcher.search(weight, a);
|
||||||
a.postCollection();
|
a.postCollection();
|
||||||
|
|
Loading…
Reference in New Issue