Cache the score of the parent document in the nested agg (#36019)

The nested agg can defer the collection of children if it is nested
under another aggregation. In such a case, accessing the score in the children
aggregation throws an error because the scorer has already advanced to the next
parent. This change fixes this error by caching the score of the parent in the
nested aggregation. Children aggregations that work on nested documents will be
able to access the _score. Also note that the _score in this case is always the
parent's score; there is no way to retrieve the score of a nested doc in aggregations.

Closes #35985
Closes #34555
This commit is contained in:
Jim Ferenczi 2018-11-29 14:35:25 +01:00 committed by GitHub
parent 85cdf4f913
commit ecd29089a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 15 deletions

View File

@ -26,6 +26,7 @@ import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
@ -141,7 +142,9 @@ public class NestedAggregator extends BucketsAggregator implements SingleBucketA
final DocIdSetIterator childDocs;
final LongArrayList bucketBuffer = new LongArrayList();
Scorable scorer;
int currentParentDoc = -1;
final CachedScorable cachedScorer = new CachedScorable();
BufferingNestedLeafBucketCollector(LeafBucketCollector sub, BitSet parentDocs, DocIdSetIterator childDocs) {
super(sub, null);
@ -150,6 +153,12 @@ public class NestedAggregator extends BucketsAggregator implements SingleBucketA
this.childDocs = childDocs;
}
/**
 * Remembers the incoming scorer in {@code this.scorer} and hands {@code cachedScorer}
 * to {@code super.setScorer} in its place.
 *
 * @param scorer the scorer supplied for the current segment; kept so that
 *               {@code collect} can read the parent's score from it
 * @throws IOException if {@code super.setScorer} throws
 */
@Override
public void setScorer(Scorable scorer) throws IOException {
// Keep the real scorer: collect() calls scorer.score() on it when scores are needed.
this.scorer = scorer;
// Pass the cached wrapper instead of the real scorer. Its score/doc fields are filled in
// by this collector, so the value stays readable after the real scorer has moved on.
// NOTE(review): presumably super forwards this to the sub-collector -- confirm.
super.setScorer(cachedScorer);
}
@Override
public void collect(int parentDoc, long bucket) throws IOException {
// if parentDoc is 0 then this means that this parent doesn't have child docs (b/c these appear always before the parent
@ -160,7 +169,12 @@ public class NestedAggregator extends BucketsAggregator implements SingleBucketA
if (currentParentDoc != parentDoc) {
processBufferedChildBuckets();
if (scoreMode().needsScores()) {
// cache the score of the current parent
cachedScorer.score = scorer.score();
}
currentParentDoc = parentDoc;
}
bucketBuffer.add(bucket);
}
@ -178,6 +192,7 @@ public class NestedAggregator extends BucketsAggregator implements SingleBucketA
}
for (; childDocId < currentParentDoc; childDocId = childDocs.nextDoc()) {
cachedScorer.doc = childDocId;
final long[] buffer = bucketBuffer.buffer;
final int size = bucketBuffer.size();
for (int i = 0; i < size; i++) {
@ -186,6 +201,19 @@ public class NestedAggregator extends BucketsAggregator implements SingleBucketA
}
bucketBuffer.clear();
}
}
/**
 * A {@link Scorable} that does no scoring of its own: {@link #docID()} and
 * {@link #score()} simply return whatever was last written to the {@code doc}
 * and {@code score} fields. The enclosing collector assigns {@code score} when it
 * caches the parent's score and assigns {@code doc} for each child document it replays.
 */
private static class CachedScorable extends Scorable {
    /** Score value to report; written by the owning collector. */
    float score;
    /** Document id to report; written by the owning collector. */
    int doc;

    @Override
    public int docID() {
        return doc;
    }

    @Override
    public final float score() {
        return score;
    }
}

View File

@ -48,6 +48,7 @@ import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.mapper.SeqNoFieldMapper;
import org.elasticsearch.index.mapper.TypeFieldMapper;
import org.elasticsearch.index.mapper.Uid;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.script.Script;
@ -63,6 +64,7 @@ import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.global.InternalGlobal;
import org.elasticsearch.search.aggregations.bucket.nested.InternalNested;
@ -1048,21 +1050,23 @@ public class TermsAggregatorTests extends AggregatorTestCase {
fieldType.setHasDocValues(true);
fieldType.setName("nested_value");
try (IndexReader indexReader = wrap(DirectoryReader.open(directory))) {
InternalNested result = search(newSearcher(indexReader, false, true),
// match root document only
new DocValuesFieldExistsQuery(PRIMARY_TERM_NAME), nested, fieldType);
InternalMultiBucketAggregation<?, ?> terms = result.getAggregations().get("terms");
assertThat(terms.getBuckets().size(), equalTo(9));
int ptr = 9;
for (MultiBucketsAggregation.Bucket bucket : terms.getBuckets()) {
InternalTopHits topHits = bucket.getAggregations().get("top_hits");
assertThat(topHits.getHits().totalHits, equalTo((long) ptr));
if (withScore) {
assertThat(topHits.getHits().getMaxScore(), equalTo(1f));
} else {
assertThat(topHits.getHits().getMaxScore(), equalTo(Float.NaN));
}
--ptr;
{
InternalNested result = search(newSearcher(indexReader, false, true),
// match root document only
new DocValuesFieldExistsQuery(PRIMARY_TERM_NAME), nested, fieldType);
InternalMultiBucketAggregation<?, ?> terms = result.getAggregations().get("terms");
assertNestedTopHitsScore(terms, withScore);
}
{
FilterAggregationBuilder filter = new FilterAggregationBuilder("filter", new MatchAllQueryBuilder())
.subAggregation(nested);
InternalFilter result = search(newSearcher(indexReader, false, true),
// match root document only
new DocValuesFieldExistsQuery(PRIMARY_TERM_NAME), filter, fieldType);
InternalNested nestedResult = result.getAggregations().get("nested");
InternalMultiBucketAggregation<?, ?> terms = nestedResult.getAggregations().get("terms");
assertNestedTopHitsScore(terms, withScore);
}
}
}
@ -1071,6 +1075,21 @@ public class TermsAggregatorTests extends AggregatorTestCase {
}
}
/**
 * Checks the {@code top_hits} sub-aggregation of every bucket in {@code terms}.
 * Expects exactly 9 buckets whose total hit counts run 9, 8, ..., 1 in iteration order.
 *
 * @param terms     the multi-bucket aggregation whose buckets are inspected
 * @param withScore when {@code true} each bucket's max score must be {@code 1f};
 *                  otherwise it must be {@code NaN}
 */
private void assertNestedTopHitsScore(InternalMultiBucketAggregation<?, ?> terms, boolean withScore) {
    assertThat(terms.getBuckets().size(), equalTo(9));
    // The expected max score is the same for every bucket, so compute it once.
    final float expectedMaxScore = withScore ? 1f : Float.NaN;
    // Hit counts are expected to decrease by one from bucket to bucket, starting at 9.
    long expectedTotalHits = 9;
    for (MultiBucketsAggregation.Bucket bucket : terms.getBuckets()) {
        InternalTopHits topHits = bucket.getAggregations().get("top_hits");
        assertThat(topHits.getHits().totalHits, equalTo(expectedTotalHits));
        assertThat(topHits.getHits().getMaxScore(), equalTo(expectedMaxScore));
        expectedTotalHits--;
    }
}
public void testOrderByPipelineAggregation() throws Exception {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {