Save memory when parent and child are not on top (#57892) (#57944)

Reworks the `parent` and `child` aggregation are not at the top level
using the optimization from #55873. Instead of wrapping all
non-top-level `parent` and `child` aggregators we now handle being a
child aggregator in the aggregator, specifically by adding recording
which global ordinals show up in the parent and then checking if they
match the child.
This commit is contained in:
Nik Everett 2020-06-10 16:25:10 -04:00 committed by GitHub
parent 9eb8085ac0
commit 0a2bd10758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 111 additions and 46 deletions

View File

@ -79,12 +79,8 @@ public class ChildrenAggregatorFactory extends ValuesSourceAggregatorFactory {
}
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
if (collectsFromSingleBucket) {
return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
parentFilter, valuesSource, maxOrd, metadata);
} else {
return asMultiBucketAggregator(this, searchContext, parent);
}
return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}
@Override

View File

@ -40,8 +40,8 @@ public class ChildrenToParentAggregator extends ParentJoinAggregator {
public ChildrenToParentAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, metadata);
long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}
@Override

View File

@ -80,12 +80,8 @@ public class ParentAggregatorFactory extends ValuesSourceAggregatorFactory {
}
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
if (collectsFromSingleBucket) {
return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
parentFilter, valuesSource, maxOrd, metadata);
} else {
return asMultiBucketAggregator(this, searchContext, children);
}
return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}
@Override

View File

@ -33,12 +33,12 @@ import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.internal.SearchContext;
@ -68,6 +68,7 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
Query outFilter,
ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd,
boolean collectsFromSingleBucket,
Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, metadata);
@ -81,8 +82,9 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
this.outFilter = context.searcher().createWeight(context.searcher().rewrite(outFilter), ScoreMode.COMPLETE_NO_SCORES, 1f);
this.valuesSource = valuesSource;
boolean singleAggregator = parent == null;
collectionStrategy = singleAggregator ?
new DenseCollectionStrategy(maxOrd, context.bigArrays()) : new SparseCollectionStrategy(context.bigArrays());
collectionStrategy = singleAggregator && collectsFromSingleBucket
? new DenseCollectionStrategy(maxOrd, context.bigArrays())
: new SparseCollectionStrategy(context.bigArrays(), collectsFromSingleBucket);
}
@Override
@ -95,19 +97,18 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), inFilter.scorerSupplier(ctx));
return new LeafBucketCollector() {
@Override
public void collect(int docId, long bucket) throws IOException {
assert bucket == 0;
public void collect(int docId, long owningBucketOrd) throws IOException {
if (parentDocs.get(docId) && globalOrdinals.advanceExact(docId)) {
int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
collectionStrategy.addGlobalOrdinal(globalOrdinal);
collectionStrategy.add(owningBucketOrd, globalOrdinal);
}
}
};
}
@Override
protected final void doPostCollection() throws IOException {
protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {
IndexReader indexReader = context().searcher().getIndexReader();
for (LeafReaderContext ctx : indexReader.leaves()) {
Scorer childDocsScorer = outFilter.scorer(ctx);
@ -137,11 +138,21 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
if (liveDocs != null && liveDocs.get(docId) == false) {
continue;
}
if (globalOrdinals.advanceExact(docId)) {
int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
if (collectionStrategy.existsGlobalOrdinal(globalOrdinal)) {
collectBucket(sub, docId, 0);
if (false == globalOrdinals.advanceExact(docId)) {
continue;
}
int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
/*
* Check if we contain every ordinal. It's almost certainly be
* faster to replay all the matching ordinals and filter them down
* to just those listed in ordsToCollect, but we don't have a data
* structure that maps a primitive long to a list of primitive
* longs.
*/
for (long owningBucketOrd: ordsToCollect) {
if (collectionStrategy.exists(owningBucketOrd, globalOrdinal)) {
collectBucket(sub, docId, owningBucketOrd);
}
}
}
@ -160,8 +171,8 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
* {@code ParentJoinAggregator#outFilter} also have the ordinal.
*/
protected interface CollectionStrategy extends Releasable {
void addGlobalOrdinal(int globalOrdinal);
boolean existsGlobalOrdinal(int globalOrdinal);
void add(long owningBucketOrd, int globalOrdinal);
boolean exists(long owningBucketOrd, int globalOrdinal);
}
/**
@ -178,12 +189,14 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
}
@Override
public void addGlobalOrdinal(int globalOrdinal) {
public void add(long owningBucketOrd, int globalOrdinal) {
assert owningBucketOrd == 0;
ordsBits.set(globalOrdinal);
}
@Override
public boolean existsGlobalOrdinal(int globalOrdinal) {
public boolean exists(long owningBucketOrd, int globalOrdinal) {
assert owningBucketOrd == 0;
return ordsBits.get(globalOrdinal);
}
@ -200,20 +213,20 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
* when only some docs might match.
*/
protected class SparseCollectionStrategy implements CollectionStrategy {
private final LongHash ordsHash;
private final LongKeyedBucketOrds ordsHash;
public SparseCollectionStrategy(BigArrays bigArrays) {
ordsHash = new LongHash(1, bigArrays);
public SparseCollectionStrategy(BigArrays bigArrays, boolean collectsFromSingleBucket) {
ordsHash = LongKeyedBucketOrds.build(bigArrays, collectsFromSingleBucket);
}
@Override
public void addGlobalOrdinal(int globalOrdinal) {
ordsHash.add(globalOrdinal);
public void add(long owningBucketOrd, int globalOrdinal) {
ordsHash.add(owningBucketOrd, globalOrdinal);
}
@Override
public boolean existsGlobalOrdinal(int globalOrdinal) {
return ordsHash.find(globalOrdinal) >= 0;
public boolean exists(long owningBucketOrd, int globalOrdinal) {
return ordsHash.find(owningBucketOrd, globalOrdinal) >= 0;
}
@Override

View File

@ -36,8 +36,8 @@ public class ParentToChildrenAggregator extends ParentJoinAggregator {
public ParentToChildrenAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, metadata);
long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}
@Override

View File

@ -59,7 +59,6 @@ import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;
import java.io.IOException;
import java.util.ArrayList;
@ -313,8 +312,7 @@ public class ChildrenToParentAggregatorTests extends AggregatorTestCase {
throws IOException {
ParentAggregationBuilder aggregationBuilder = new ParentAggregationBuilder("_name", CHILD_TYPE);
aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG)
.field("number"));
aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").field("number"));
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");
@ -326,9 +324,9 @@ public class ChildrenToParentAggregatorTests extends AggregatorTestCase {
private void testCaseTermsParentTerms(Query query, IndexSearcher indexSearcher, Consumer<LongTerms> verify)
throws IOException {
AggregationBuilder aggregationBuilder =
new TermsAggregationBuilder("subvalue_terms").userValueTypeHint(ValueType.LONG).field("subNumber").
new TermsAggregationBuilder("subvalue_terms").field("subNumber").
subAggregation(new ParentAggregationBuilder("to_parent", CHILD_TYPE).
subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG).field("number")));
subAggregation(new TermsAggregationBuilder("value_terms").field("number")));
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");

View File

@ -22,6 +22,7 @@ package org.elasticsearch.join.aggregations;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
@ -52,7 +53,10 @@ import org.elasticsearch.join.ParentJoinPlugin;
import org.elasticsearch.join.mapper.MetaJoinFieldMapper;
import org.elasticsearch.join.mapper.ParentJoinFieldMapper;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
@ -64,6 +68,7 @@ import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -124,12 +129,68 @@ public class ParentToChildrenAggregatorTests extends AggregatorTestCase {
directory.close();
}
public void testParentChildAsSubAgg() throws IOException {
try (Directory directory = newDirectory()) {
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
final Map<String, Tuple<Integer, Integer>> expectedParentChildRelations = setupIndex(indexWriter);
indexWriter.close();
try (
IndexReader indexReader = ElasticsearchDirectoryReader.wrap(
DirectoryReader.open(directory),
new ShardId(new Index("foo", "_na_"), 1)
)
) {
IndexSearcher indexSearcher = newSearcher(indexReader, false, true);
AggregationBuilder request = new TermsAggregationBuilder("t").field("kwd")
.subAggregation(
new ChildrenAggregationBuilder("children", CHILD_TYPE).subAggregation(
new MinAggregationBuilder("min").field("number")
)
);
long expectedEvenChildCount = 0;
double expectedEvenMin = Double.MAX_VALUE;
long expectedOddChildCount = 0;
double expectedOddMin = Double.MAX_VALUE;
for (Map.Entry<String, Tuple<Integer, Integer>> e : expectedParentChildRelations.entrySet()) {
if (Integer.valueOf(e.getKey().substring("parent".length())) % 2 == 0) {
expectedEvenChildCount += e.getValue().v1();
expectedEvenMin = Math.min(expectedEvenMin, e.getValue().v2());
} else {
expectedOddChildCount += e.getValue().v1();
expectedOddMin = Math.min(expectedOddMin, e.getValue().v2());
}
}
StringTerms result = search(indexSearcher, new MatchAllDocsQuery(), request, longField("number"), keywordField("kwd"));
StringTerms.Bucket evenBucket = result.getBucketByKey("even");
InternalChildren evenChildren = evenBucket.getAggregations().get("children");
InternalMin evenMin = evenChildren.getAggregations().get("min");
assertThat(evenChildren.getDocCount(), equalTo(expectedEvenChildCount));
assertThat(evenMin.getValue(), equalTo(expectedEvenMin));
if (expectedOddChildCount > 0) {
StringTerms.Bucket oddBucket = result.getBucketByKey("odd");
InternalChildren oddChildren = oddBucket.getAggregations().get("children");
InternalMin oddMin = oddChildren.getAggregations().get("min");
assertThat(oddChildren.getDocCount(), equalTo(expectedOddChildCount));
assertThat(oddMin.getValue(), equalTo(expectedOddMin));
} else {
assertNull(result.getBucketByKey("odd"));
}
}
}
}
private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter iw) throws IOException {
Map<String, Tuple<Integer, Integer>> expectedValues = new HashMap<>();
int numParents = randomIntBetween(1, 10);
for (int i = 0; i < numParents; i++) {
String parent = "parent" + i;
iw.addDocument(createParentDocument(parent));
iw.addDocument(createParentDocument(parent, i % 2 == 0 ? "even" : "odd"));
int numChildren = randomIntBetween(1, 10);
int minValue = Integer.MAX_VALUE;
for (int c = 0; c < numChildren; c++) {
@ -142,9 +203,10 @@ public class ParentToChildrenAggregatorTests extends AggregatorTestCase {
return expectedValues;
}
private static List<Field> createParentDocument(String id) {
private static List<Field> createParentDocument(String id, String kwd) {
return Arrays.asList(
new StringField(IdFieldMapper.NAME, Uid.encodeId(id), Field.Store.NO),
new SortedSetDocValuesField("kwd", new BytesRef(kwd)),
new StringField("join_field", PARENT_TYPE, Field.Store.NO),
createJoinField(PARENT_TYPE, id)
);