Save memory when parent and child are not on top (#57892) (#57944)

Reworks how the `parent` and `child` aggregations work when they are not at
the top level, using the optimization from #55873. Instead of wrapping all
non-top-level `parent` and `child` aggregators, we now handle being a
child aggregator in the aggregator itself, specifically by recording
which global ordinals show up in the parent and then checking if they
match the child.
This commit is contained in:
Nik Everett 2020-06-10 16:25:10 -04:00 committed by GitHub
parent 9eb8085ac0
commit 0a2bd10758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 111 additions and 46 deletions

View File

@ -79,12 +79,8 @@ public class ChildrenAggregatorFactory extends ValuesSourceAggregatorFactory {
} }
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource; WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher()); long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
if (collectsFromSingleBucket) {
return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter, return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
parentFilter, valuesSource, maxOrd, metadata); parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
} else {
return asMultiBucketAggregator(this, searchContext, parent);
}
} }
@Override @Override

View File

@ -40,8 +40,8 @@ public class ChildrenToParentAggregator extends ParentJoinAggregator {
public ChildrenToParentAggregator(String name, AggregatorFactories factories, public ChildrenToParentAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter, SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource, Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd, Map<String, Object> metadata) throws IOException { long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, metadata); super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
} }
@Override @Override

View File

@ -80,12 +80,8 @@ public class ParentAggregatorFactory extends ValuesSourceAggregatorFactory {
} }
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource; WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher()); long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
if (collectsFromSingleBucket) {
return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter, return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
parentFilter, valuesSource, maxOrd, metadata); parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
} else {
return asMultiBucketAggregator(this, searchContext, children);
}
} }
@Override @Override

View File

@ -33,12 +33,12 @@ import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator; import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator; import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
@ -68,6 +68,7 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
Query outFilter, Query outFilter,
ValuesSource.Bytes.WithOrdinals valuesSource, ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd, long maxOrd,
boolean collectsFromSingleBucket,
Map<String, Object> metadata) throws IOException { Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, metadata); super(name, factories, context, parent, metadata);
@ -81,8 +82,9 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
this.outFilter = context.searcher().createWeight(context.searcher().rewrite(outFilter), ScoreMode.COMPLETE_NO_SCORES, 1f); this.outFilter = context.searcher().createWeight(context.searcher().rewrite(outFilter), ScoreMode.COMPLETE_NO_SCORES, 1f);
this.valuesSource = valuesSource; this.valuesSource = valuesSource;
boolean singleAggregator = parent == null; boolean singleAggregator = parent == null;
collectionStrategy = singleAggregator ? collectionStrategy = singleAggregator && collectsFromSingleBucket
new DenseCollectionStrategy(maxOrd, context.bigArrays()) : new SparseCollectionStrategy(context.bigArrays()); ? new DenseCollectionStrategy(maxOrd, context.bigArrays())
: new SparseCollectionStrategy(context.bigArrays(), collectsFromSingleBucket);
} }
@Override @Override
@ -95,19 +97,18 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), inFilter.scorerSupplier(ctx)); final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), inFilter.scorerSupplier(ctx));
return new LeafBucketCollector() { return new LeafBucketCollector() {
@Override @Override
public void collect(int docId, long bucket) throws IOException { public void collect(int docId, long owningBucketOrd) throws IOException {
assert bucket == 0;
if (parentDocs.get(docId) && globalOrdinals.advanceExact(docId)) { if (parentDocs.get(docId) && globalOrdinals.advanceExact(docId)) {
int globalOrdinal = (int) globalOrdinals.nextOrd(); int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS; assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
collectionStrategy.addGlobalOrdinal(globalOrdinal); collectionStrategy.add(owningBucketOrd, globalOrdinal);
} }
} }
}; };
} }
@Override @Override
protected final void doPostCollection() throws IOException { protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {
IndexReader indexReader = context().searcher().getIndexReader(); IndexReader indexReader = context().searcher().getIndexReader();
for (LeafReaderContext ctx : indexReader.leaves()) { for (LeafReaderContext ctx : indexReader.leaves()) {
Scorer childDocsScorer = outFilter.scorer(ctx); Scorer childDocsScorer = outFilter.scorer(ctx);
@ -137,11 +138,21 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
if (liveDocs != null && liveDocs.get(docId) == false) { if (liveDocs != null && liveDocs.get(docId) == false) {
continue; continue;
} }
if (globalOrdinals.advanceExact(docId)) { if (false == globalOrdinals.advanceExact(docId)) {
continue;
}
int globalOrdinal = (int) globalOrdinals.nextOrd(); int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS; assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
if (collectionStrategy.existsGlobalOrdinal(globalOrdinal)) { /*
collectBucket(sub, docId, 0); * Check if we contain every ordinal. It would almost certainly be
* faster to replay all the matching ordinals and filter them down
* to just those listed in ordsToCollect, but we don't have a data
* structure that maps a primitive long to a list of primitive
* longs.
*/
for (long owningBucketOrd: ordsToCollect) {
if (collectionStrategy.exists(owningBucketOrd, globalOrdinal)) {
collectBucket(sub, docId, owningBucketOrd);
} }
} }
} }
@ -160,8 +171,8 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
* {@code ParentJoinAggregator#outFilter} also have the ordinal. * {@code ParentJoinAggregator#outFilter} also have the ordinal.
*/ */
protected interface CollectionStrategy extends Releasable { protected interface CollectionStrategy extends Releasable {
void addGlobalOrdinal(int globalOrdinal); void add(long owningBucketOrd, int globalOrdinal);
boolean existsGlobalOrdinal(int globalOrdinal); boolean exists(long owningBucketOrd, int globalOrdinal);
} }
/** /**
@ -178,12 +189,14 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
} }
@Override @Override
public void addGlobalOrdinal(int globalOrdinal) { public void add(long owningBucketOrd, int globalOrdinal) {
assert owningBucketOrd == 0;
ordsBits.set(globalOrdinal); ordsBits.set(globalOrdinal);
} }
@Override @Override
public boolean existsGlobalOrdinal(int globalOrdinal) { public boolean exists(long owningBucketOrd, int globalOrdinal) {
assert owningBucketOrd == 0;
return ordsBits.get(globalOrdinal); return ordsBits.get(globalOrdinal);
} }
@ -200,20 +213,20 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
* when only some docs might match. * when only some docs might match.
*/ */
protected class SparseCollectionStrategy implements CollectionStrategy { protected class SparseCollectionStrategy implements CollectionStrategy {
private final LongHash ordsHash; private final LongKeyedBucketOrds ordsHash;
public SparseCollectionStrategy(BigArrays bigArrays) { public SparseCollectionStrategy(BigArrays bigArrays, boolean collectsFromSingleBucket) {
ordsHash = new LongHash(1, bigArrays); ordsHash = LongKeyedBucketOrds.build(bigArrays, collectsFromSingleBucket);
} }
@Override @Override
public void addGlobalOrdinal(int globalOrdinal) { public void add(long owningBucketOrd, int globalOrdinal) {
ordsHash.add(globalOrdinal); ordsHash.add(owningBucketOrd, globalOrdinal);
} }
@Override @Override
public boolean existsGlobalOrdinal(int globalOrdinal) { public boolean exists(long owningBucketOrd, int globalOrdinal) {
return ordsHash.find(globalOrdinal) >= 0; return ordsHash.find(owningBucketOrd, globalOrdinal) >= 0;
} }
@Override @Override

View File

@ -36,8 +36,8 @@ public class ParentToChildrenAggregator extends ParentJoinAggregator {
public ParentToChildrenAggregator(String name, AggregatorFactories factories, public ParentToChildrenAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter, SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource, Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd, Map<String, Object> metadata) throws IOException { long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, metadata); super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
} }
@Override @Override

View File

@ -59,7 +59,6 @@ import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin; import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -313,8 +312,7 @@ public class ChildrenToParentAggregatorTests extends AggregatorTestCase {
throws IOException { throws IOException {
ParentAggregationBuilder aggregationBuilder = new ParentAggregationBuilder("_name", CHILD_TYPE); ParentAggregationBuilder aggregationBuilder = new ParentAggregationBuilder("_name", CHILD_TYPE);
aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG) aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").field("number"));
.field("number"));
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG); MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number"); fieldType.setName("number");
@ -326,9 +324,9 @@ public class ChildrenToParentAggregatorTests extends AggregatorTestCase {
private void testCaseTermsParentTerms(Query query, IndexSearcher indexSearcher, Consumer<LongTerms> verify) private void testCaseTermsParentTerms(Query query, IndexSearcher indexSearcher, Consumer<LongTerms> verify)
throws IOException { throws IOException {
AggregationBuilder aggregationBuilder = AggregationBuilder aggregationBuilder =
new TermsAggregationBuilder("subvalue_terms").userValueTypeHint(ValueType.LONG).field("subNumber"). new TermsAggregationBuilder("subvalue_terms").field("subNumber").
subAggregation(new ParentAggregationBuilder("to_parent", CHILD_TYPE). subAggregation(new ParentAggregationBuilder("to_parent", CHILD_TYPE).
subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG).field("number"))); subAggregation(new TermsAggregationBuilder("value_terms").field("number")));
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG); MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number"); fieldType.setName("number");

View File

@ -22,6 +22,7 @@ package org.elasticsearch.join.aggregations;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
@ -52,7 +53,10 @@ import org.elasticsearch.join.ParentJoinPlugin;
import org.elasticsearch.join.mapper.MetaJoinFieldMapper; import org.elasticsearch.join.mapper.MetaJoinFieldMapper;
import org.elasticsearch.join.mapper.ParentJoinFieldMapper; import org.elasticsearch.join.mapper.ParentJoinFieldMapper;
import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin; import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
@ -64,6 +68,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -124,12 +129,68 @@ public class ParentToChildrenAggregatorTests extends AggregatorTestCase {
directory.close(); directory.close();
} }
/**
 * Checks the {@code children} aggregation when it runs as a sub-aggregation of a
 * {@code terms} aggregation instead of at the top level.
 * <p>
 * The request buckets documents by the {@code kwd} field and, under each term bucket,
 * runs a {@code children} aggregation with a {@code min} on {@code number}. The expected
 * doc count and minimum for each bucket are derived from the map returned by
 * {@code setupIndex}, split by the parity of the numeric suffix of each parent id.
 */
public void testParentChildAsSubAgg() throws IOException {
try (Directory directory = newDirectory()) {
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
final Map<String, Tuple<Integer, Integer>> expectedParentChildRelations = setupIndex(indexWriter);
indexWriter.close();
try (
IndexReader indexReader = ElasticsearchDirectoryReader.wrap(
DirectoryReader.open(directory),
new ShardId(new Index("foo", "_na_"), 1)
)
) {
IndexSearcher indexSearcher = newSearcher(indexReader, false, true);
// terms(kwd) -> children(CHILD_TYPE) -> min(number)
AggregationBuilder request = new TermsAggregationBuilder("t").field("kwd")
.subAggregation(
new ChildrenAggregationBuilder("children", CHILD_TYPE).subAggregation(
new MinAggregationBuilder("min").field("number")
)
);
long expectedEvenChildCount = 0;
double expectedEvenMin = Double.MAX_VALUE;
long expectedOddChildCount = 0;
double expectedOddMin = Double.MAX_VALUE;
// Keys have the form "parent<i>"; the parity of <i> selects the "even" or "odd" bucket.
// NOTE(review): v1 is treated as the child count and v2 as the minimum child value;
// confirm against setupIndex, whose body is not fully visible here.
for (Map.Entry<String, Tuple<Integer, Integer>> e : expectedParentChildRelations.entrySet()) {
if (Integer.valueOf(e.getKey().substring("parent".length())) % 2 == 0) {
expectedEvenChildCount += e.getValue().v1();
expectedEvenMin = Math.min(expectedEvenMin, e.getValue().v2());
} else {
expectedOddChildCount += e.getValue().v1();
expectedOddMin = Math.min(expectedOddMin, e.getValue().v2());
}
}
StringTerms result = search(indexSearcher, new MatchAllDocsQuery(), request, longField("number"), keywordField("kwd"));
StringTerms.Bucket evenBucket = result.getBucketByKey("even");
InternalChildren evenChildren = evenBucket.getAggregations().get("children");
InternalMin evenMin = evenChildren.getAggregations().get("min");
assertThat(evenChildren.getDocCount(), equalTo(expectedEvenChildCount));
assertThat(evenMin.getValue(), equalTo(expectedEvenMin));
// No odd-numbered parent contributed any children (e.g. only "parent0" was indexed),
// so the terms aggregation is expected to have no "odd" bucket at all.
if (expectedOddChildCount > 0) {
StringTerms.Bucket oddBucket = result.getBucketByKey("odd");
InternalChildren oddChildren = oddBucket.getAggregations().get("children");
InternalMin oddMin = oddChildren.getAggregations().get("min");
assertThat(oddChildren.getDocCount(), equalTo(expectedOddChildCount));
assertThat(oddMin.getValue(), equalTo(expectedOddMin));
} else {
assertNull(result.getBucketByKey("odd"));
}
}
}
}
private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter iw) throws IOException { private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter iw) throws IOException {
Map<String, Tuple<Integer, Integer>> expectedValues = new HashMap<>(); Map<String, Tuple<Integer, Integer>> expectedValues = new HashMap<>();
int numParents = randomIntBetween(1, 10); int numParents = randomIntBetween(1, 10);
for (int i = 0; i < numParents; i++) { for (int i = 0; i < numParents; i++) {
String parent = "parent" + i; String parent = "parent" + i;
iw.addDocument(createParentDocument(parent)); iw.addDocument(createParentDocument(parent, i % 2 == 0 ? "even" : "odd"));
int numChildren = randomIntBetween(1, 10); int numChildren = randomIntBetween(1, 10);
int minValue = Integer.MAX_VALUE; int minValue = Integer.MAX_VALUE;
for (int c = 0; c < numChildren; c++) { for (int c = 0; c < numChildren; c++) {
@ -142,9 +203,10 @@ public class ParentToChildrenAggregatorTests extends AggregatorTestCase {
return expectedValues; return expectedValues;
} }
private static List<Field> createParentDocument(String id) { private static List<Field> createParentDocument(String id, String kwd) {
return Arrays.asList( return Arrays.asList(
new StringField(IdFieldMapper.NAME, Uid.encodeId(id), Field.Store.NO), new StringField(IdFieldMapper.NAME, Uid.encodeId(id), Field.Store.NO),
new SortedSetDocValuesField("kwd", new BytesRef(kwd)),
new StringField("join_field", PARENT_TYPE, Field.Store.NO), new StringField("join_field", PARENT_TYPE, Field.Store.NO),
createJoinField(PARENT_TYPE, id) createJoinField(PARENT_TYPE, id)
); );