Replace the TreeMap in the composite aggregation (#36675)

The `composite` aggregation uses a TreeMap to keep track of the best buckets.
This ensures a log(n) time cost to insert new buckets but also to retrieve buckets
that are already present in the map. In order to speed up the retrieval of buckets
this change replaces the TreeMap with a priority queue and a HashMap. The insertion
cost is still log(n) but the retrieval of buckets through the HashMap is now done in constant
time. This optimization can bring significant improvement since each document needs
to check if its associated buckets are already present in the current best buckets.
This commit is contained in:
Jim Ferenczi 2019-01-03 09:51:35 +01:00 committed by GitHub
parent a40f0545e6
commit 78ba1889cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 205 additions and 81 deletions

View File

@ -335,10 +335,11 @@ setup:
--- ---
"Composite aggregation and array size": "Composite aggregation and array size":
- skip: - skip:
version: " - 6.3.99" version: " - 6.99.99"
reason: starting in 6.4 the composite sources do not allocate arrays eagerly. reason: starting in 7.0 the composite aggregation throws an execption if the provided size is greater than search.max_buckets.
- do: - do:
catch: /.*Trying to create too many buckets.*/
search: search:
rest_total_hits_as_int: true rest_total_hits_as_int: true
index: test index: test
@ -356,8 +357,3 @@ setup:
} }
} }
] ]
- match: {hits.total: 6}
- length: { aggregations.test.buckets: 2 }
- length: { aggregations.test.after_key: 1 }
- match: { aggregations.test.after_key.keyword: "foo" }

View File

@ -118,6 +118,10 @@ public class MultiBucketConsumerService {
public int getCount() { public int getCount() {
return count; return count;
} }
public int getLimit() {
return limit;
}
} }
public MultiBucketConsumer create() { public MultiBucketConsumer create() {

View File

@ -18,8 +18,6 @@
*/ */
package org.elasticsearch.search.aggregations; package org.elasticsearch.search.aggregations;
import java.util.function.IntConsumer;
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer; import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer;
/** /**
@ -60,7 +58,7 @@ public class SearchContextAggregations {
* Returns a consumer for multi bucket aggregation that checks the total number of buckets * Returns a consumer for multi bucket aggregation that checks the total number of buckets
* created in the response * created in the response
*/ */
public IntConsumer multiBucketConsumer() { public MultiBucketConsumer multiBucketConsumer() {
return multiBucketConsumer; return multiBucketConsumer;
} }

View File

@ -114,6 +114,24 @@ class BinaryValuesSource extends SingleDimensionValuesSource<BytesRef> {
return compareValues(currentValue, afterValue); return compareValues(currentValue, afterValue);
} }
@Override
int hashCode(int slot) {
if (missingBucket && values.get(slot) == null) {
return 0;
} else {
return values.get(slot).hashCode();
}
}
@Override
int hashCodeCurrent() {
if (missingBucket && currentValue == null) {
return 0;
} else {
return currentValue.hashCode();
}
}
int compareValues(BytesRef v1, BytesRef v2) { int compareValues(BytesRef v1, BytesRef v2) {
return v1.compareTo(v2) * reverseMul; return v1.compareTo(v2) * reverseMul;
} }

View File

@ -40,6 +40,7 @@ import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketCollector; import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator; import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource;
@ -54,6 +55,8 @@ import java.util.Map;
import java.util.function.LongUnaryOperator; import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
final class CompositeAggregator extends BucketsAggregator { final class CompositeAggregator extends BucketsAggregator {
private final int size; private final int size;
private final SortedDocsProducer sortedDocsProducer; private final SortedDocsProducer sortedDocsProducer;
@ -78,9 +81,15 @@ final class CompositeAggregator extends BucketsAggregator {
this.reverseMuls = Arrays.stream(sourceConfigs).mapToInt(CompositeValuesSourceConfig::reverseMul).toArray(); this.reverseMuls = Arrays.stream(sourceConfigs).mapToInt(CompositeValuesSourceConfig::reverseMul).toArray();
this.formats = Arrays.stream(sourceConfigs).map(CompositeValuesSourceConfig::format).collect(Collectors.toList()); this.formats = Arrays.stream(sourceConfigs).map(CompositeValuesSourceConfig::format).collect(Collectors.toList());
this.sources = new SingleDimensionValuesSource[sourceConfigs.length]; this.sources = new SingleDimensionValuesSource[sourceConfigs.length];
// check that the provided size is not greater than the search.max_buckets setting
int bucketLimit = context.aggregations().multiBucketConsumer().getLimit();
if (size > bucketLimit) {
throw new MultiBucketConsumerService.TooManyBucketsException("Trying to create too many buckets. Must be less than or equal" +
" to: [" + bucketLimit + "] but was [" + size + "]. This limit can be set by changing the [" + MAX_BUCKET_SETTING.getKey() +
"] cluster level setting.", bucketLimit);
}
for (int i = 0; i < sourceConfigs.length; i++) { for (int i = 0; i < sourceConfigs.length; i++) {
this.sources[i] = createValuesSource(context.bigArrays(), context.searcher().getIndexReader(), this.sources[i] = createValuesSource(context.bigArrays(), context.searcher().getIndexReader(), sourceConfigs[i], size);
context.query(), sourceConfigs[i], size, i);
} }
this.queue = new CompositeValuesCollectorQueue(context.bigArrays(), sources, size, rawAfterKey); this.queue = new CompositeValuesCollectorQueue(context.bigArrays(), sources, size, rawAfterKey);
this.sortedDocsProducer = sources[0].createSortedDocsProducerOrNull(context.searcher().getIndexReader(), context.query()); this.sortedDocsProducer = sources[0].createSortedDocsProducerOrNull(context.searcher().getIndexReader(), context.query());
@ -88,8 +97,11 @@ final class CompositeAggregator extends BucketsAggregator {
@Override @Override
protected void doClose() { protected void doClose() {
Releasables.close(queue); try {
Releasables.close(sources); Releasables.close(queue);
} finally {
Releasables.close(sources);
}
} }
@Override @Override
@ -116,12 +128,12 @@ final class CompositeAggregator extends BucketsAggregator {
int num = Math.min(size, queue.size()); int num = Math.min(size, queue.size());
final InternalComposite.InternalBucket[] buckets = new InternalComposite.InternalBucket[num]; final InternalComposite.InternalBucket[] buckets = new InternalComposite.InternalBucket[num];
int pos = 0; while (queue.size() > 0) {
for (int slot : queue.getSortedSlot()) { int slot = queue.pop();
CompositeKey key = queue.toCompositeKey(slot); CompositeKey key = queue.toCompositeKey(slot);
InternalAggregations aggs = bucketAggregations(slot); InternalAggregations aggs = bucketAggregations(slot);
int docCount = queue.getDocCount(slot); int docCount = queue.getDocCount(slot);
buckets[pos++] = new InternalComposite.InternalBucket(sourceNames, formats, key, reverseMuls, docCount, aggs); buckets[queue.size()] = new InternalComposite.InternalBucket(sourceNames, formats, key, reverseMuls, docCount, aggs);
} }
CompositeKey lastBucket = num > 0 ? buckets[num-1].getRawKey() : null; CompositeKey lastBucket = num > 0 ? buckets[num-1].getRawKey() : null;
return new InternalComposite(name, size, sourceNames, formats, Arrays.asList(buckets), lastBucket, reverseMuls, return new InternalComposite(name, size, sourceNames, formats, Arrays.asList(buckets), lastBucket, reverseMuls,
@ -259,13 +271,13 @@ final class CompositeAggregator extends BucketsAggregator {
}; };
} }
private SingleDimensionValuesSource<?> createValuesSource(BigArrays bigArrays, IndexReader reader, Query query, private SingleDimensionValuesSource<?> createValuesSource(BigArrays bigArrays, IndexReader reader,
CompositeValuesSourceConfig config, int sortRank, int size) { CompositeValuesSourceConfig config, int size) {
final int reverseMul = config.reverseMul(); final int reverseMul = config.reverseMul();
if (config.valuesSource() instanceof ValuesSource.Bytes.WithOrdinals && reader instanceof DirectoryReader) { if (config.valuesSource() instanceof ValuesSource.Bytes.WithOrdinals && reader instanceof DirectoryReader) {
ValuesSource.Bytes.WithOrdinals vs = (ValuesSource.Bytes.WithOrdinals) config.valuesSource(); ValuesSource.Bytes.WithOrdinals vs = (ValuesSource.Bytes.WithOrdinals) config.valuesSource();
SingleDimensionValuesSource<?> source = new GlobalOrdinalValuesSource( return new GlobalOrdinalValuesSource(
bigArrays, bigArrays,
config.fieldType(), config.fieldType(),
vs::globalOrdinalsValues, vs::globalOrdinalsValues,
@ -274,25 +286,6 @@ final class CompositeAggregator extends BucketsAggregator {
size, size,
reverseMul reverseMul
); );
if (sortRank == 0 && source.createSortedDocsProducerOrNull(reader, query) != null) {
// this the leading source and we can optimize it with the sorted docs producer but
// we don't want to use global ordinals because the number of visited documents
// should be low and global ordinals need one lookup per visited term.
Releasables.close(source);
return new BinaryValuesSource(
bigArrays,
this::addRequestCircuitBreakerBytes,
config.fieldType(),
vs::bytesValues,
config.format(),
config.missingBucket(),
size,
reverseMul
);
} else {
return source;
}
} else if (config.valuesSource() instanceof ValuesSource.Bytes) { } else if (config.valuesSource() instanceof ValuesSource.Bytes) {
ValuesSource.Bytes vs = (ValuesSource.Bytes) config.valuesSource(); ValuesSource.Bytes vs = (ValuesSource.Bytes) config.valuesSource();
return new BinaryValuesSource( return new BinaryValuesSource(

View File

@ -77,4 +77,11 @@ class CompositeKey implements Writeable {
public int hashCode() { public int hashCode() {
return Arrays.hashCode(values); return Arrays.hashCode(values);
} }
@Override
public String toString() {
return "CompositeKey{" +
"values=" + Arrays.toString(values) +
'}';
}
} }

View File

@ -20,6 +20,7 @@
package org.elasticsearch.search.aggregations.bucket.composite; package org.elasticsearch.search.aggregations.bucket.composite;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.PriorityQueue;
import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
@ -27,19 +28,40 @@ import org.elasticsearch.common.util.IntArray;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import java.io.IOException; import java.io.IOException;
import java.util.Set; import java.util.HashMap;
import java.util.TreeMap; import java.util.Map;
/** /**
* A specialized queue implementation for composite buckets * A specialized {@link PriorityQueue} implementation for composite buckets.
*/ */
final class CompositeValuesCollectorQueue implements Releasable { final class CompositeValuesCollectorQueue extends PriorityQueue<Integer> implements Releasable {
private class Slot {
int value;
Slot(int initial) {
this.value = initial;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Slot slot = (Slot) o;
return CompositeValuesCollectorQueue.this.equals(value, slot.value);
}
@Override
public int hashCode() {
return CompositeValuesCollectorQueue.this.hashCode(value);
}
}
// the slot for the current candidate // the slot for the current candidate
private static final int CANDIDATE_SLOT = Integer.MAX_VALUE; private static final int CANDIDATE_SLOT = Integer.MAX_VALUE;
private final BigArrays bigArrays; private final BigArrays bigArrays;
private final int maxSize; private final int maxSize;
private final TreeMap<Integer, Integer> keys; private final Map<Slot, Integer> map;
private final SingleDimensionValuesSource<?>[] arrays; private final SingleDimensionValuesSource<?>[] arrays;
private IntArray docCounts; private IntArray docCounts;
private boolean afterKeyIsSet = false; private boolean afterKeyIsSet = false;
@ -52,10 +74,11 @@ final class CompositeValuesCollectorQueue implements Releasable {
* @param afterKey composite key * @param afterKey composite key
*/ */
CompositeValuesCollectorQueue(BigArrays bigArrays, SingleDimensionValuesSource<?>[] sources, int size, CompositeKey afterKey) { CompositeValuesCollectorQueue(BigArrays bigArrays, SingleDimensionValuesSource<?>[] sources, int size, CompositeKey afterKey) {
super(size);
this.bigArrays = bigArrays; this.bigArrays = bigArrays;
this.maxSize = size; this.maxSize = size;
this.arrays = sources; this.arrays = sources;
this.keys = new TreeMap<>(this::compare); this.map = new HashMap<>(size);
if (afterKey != null) { if (afterKey != null) {
assert afterKey.size() == sources.length; assert afterKey.size() == sources.length;
afterKeyIsSet = true; afterKeyIsSet = true;
@ -66,25 +89,16 @@ final class CompositeValuesCollectorQueue implements Releasable {
this.docCounts = bigArrays.newIntArray(1, false); this.docCounts = bigArrays.newIntArray(1, false);
} }
/** @Override
* The current size of the queue. protected boolean lessThan(Integer a, Integer b) {
*/ return compare(a, b) > 0;
int size() {
return keys.size();
} }
/** /**
* Whether the queue is full or not. * Whether the queue is full or not.
*/ */
boolean isFull() { boolean isFull() {
return keys.size() == maxSize; return size() >= maxSize;
}
/**
* Returns a sorted {@link Set} view of the slots contained in this queue.
*/
Set<Integer> getSortedSlot() {
return keys.keySet();
} }
/** /**
@ -92,7 +106,7 @@ final class CompositeValuesCollectorQueue implements Releasable {
* the slot if the candidate is already in the queue or null if the candidate is not present. * the slot if the candidate is already in the queue or null if the candidate is not present.
*/ */
Integer compareCurrent() { Integer compareCurrent() {
return keys.get(CANDIDATE_SLOT); return map.get(new Slot(CANDIDATE_SLOT));
} }
/** /**
@ -106,7 +120,7 @@ final class CompositeValuesCollectorQueue implements Releasable {
* Returns the upper value (inclusive) of the leading source. * Returns the upper value (inclusive) of the leading source.
*/ */
Comparable getUpperValueLeadSource() throws IOException { Comparable getUpperValueLeadSource() throws IOException {
return size() >= maxSize ? arrays[0].toComparable(keys.lastKey()) : null; return size() >= maxSize ? arrays[0].toComparable(top()) : null;
} }
/** /**
* Returns the document count in <code>slot</code>. * Returns the document count in <code>slot</code>.
@ -127,12 +141,17 @@ final class CompositeValuesCollectorQueue implements Releasable {
} }
/** /**
* Compares the values in <code>slot1</code> with <code>slot2</code>. * Compares the values in <code>slot1</code> with the values in <code>slot2</code>.
*/ */
int compare(int slot1, int slot2) { int compare(int slot1, int slot2) {
assert slot2 != CANDIDATE_SLOT;
for (int i = 0; i < arrays.length; i++) { for (int i = 0; i < arrays.length; i++) {
int cmp = (slot1 == CANDIDATE_SLOT) ? arrays[i].compareCurrent(slot2) : final int cmp;
arrays[i].compare(slot1, slot2); if (slot1 == CANDIDATE_SLOT) {
cmp = arrays[i].compareCurrent(slot2);
} else {
cmp = arrays[i].compare(slot1, slot2);
}
if (cmp != 0) { if (cmp != 0) {
return cmp; return cmp;
} }
@ -140,6 +159,36 @@ final class CompositeValuesCollectorQueue implements Releasable {
return 0; return 0;
} }
/**
* Returns true if the values in <code>slot1</code> are equals to the value in <code>slot2</code>.
*/
boolean equals(int slot1, int slot2) {
assert slot2 != CANDIDATE_SLOT;
for (int i = 0; i < arrays.length; i++) {
final int cmp;
if (slot1 == CANDIDATE_SLOT) {
cmp = arrays[i].compareCurrent(slot2);
} else {
cmp = arrays[i].compare(slot1, slot2);
}
if (cmp != 0) {
return false;
}
}
return true;
}
/**
* Returns a hash code value for the values in <code>slot</code>.
*/
int hashCode(int slot) {
int result = 1;
for (int i = 0; i < arrays.length; i++) {
result = 31 * result + (slot == CANDIDATE_SLOT ? arrays[i].hashCodeCurrent() : arrays[i].hashCode(slot));
}
return result;
}
/** /**
* Compares the after values with the values in <code>slot</code>. * Compares the after values with the values in <code>slot</code>.
*/ */
@ -209,28 +258,28 @@ final class CompositeValuesCollectorQueue implements Releasable {
// this key is greater than the top value collected in the previous round, skip it // this key is greater than the top value collected in the previous round, skip it
return -1; return -1;
} }
if (keys.size() >= maxSize) { if (size() >= maxSize
// the tree map is full, check if the candidate key should be kept // the tree map is full, check if the candidate key should be kept
if (compare(CANDIDATE_SLOT, keys.lastKey()) > 0) { && compare(CANDIDATE_SLOT, top()) > 0) {
// the candidate key is not competitive, skip it // the candidate key is not competitive, skip it
return -1; return -1;
}
} }
// the candidate key is competitive // the candidate key is competitive
final int newSlot; final int newSlot;
if (keys.size() >= maxSize) { if (size() >= maxSize) {
// the tree map is full, we replace the last key with this candidate // the queue is full, we replace the last key with this candidate
int slot = keys.pollLastEntry().getKey(); int slot = pop();
map.remove(new Slot(slot));
// and we recycle the deleted slot // and we recycle the deleted slot
newSlot = slot; newSlot = slot;
} else { } else {
newSlot = keys.size(); newSlot = size();
assert newSlot < maxSize;
} }
// move the candidate key to its new slot // move the candidate key to its new slot
copyCurrent(newSlot); copyCurrent(newSlot);
keys.put(newSlot, newSlot); map.put(new Slot(newSlot), newSlot);
add(newSlot);
return newSlot; return newSlot;
} }

View File

@ -103,6 +103,24 @@ class DoubleValuesSource extends SingleDimensionValuesSource<Double> {
return compareValues(currentValue, afterValue); return compareValues(currentValue, afterValue);
} }
@Override
int hashCode(int slot) {
if (missingBucket && bits.get(slot) == false) {
return 0;
} else {
return Double.hashCode(values.get(slot));
}
}
@Override
int hashCodeCurrent() {
if (missingCurrentValue) {
return 0;
} else {
return Double.hashCode(currentValue);
}
}
private int compareValues(double v1, double v2) { private int compareValues(double v1, double v2) {
return Double.compare(v1, v2) * reverseMul; return Double.compare(v1, v2) * reverseMul;
} }

View File

@ -88,6 +88,16 @@ class GlobalOrdinalValuesSource extends SingleDimensionValuesSource<BytesRef> {
return cmp * reverseMul; return cmp * reverseMul;
} }
@Override
int hashCode(int slot) {
return Long.hashCode(values.get(slot));
}
@Override
int hashCodeCurrent() {
return Long.hashCode(currentValue);
}
@Override @Override
void setAfter(Comparable value) { void setAfter(Comparable value) {
if (missingBucket && value == null) { if (missingBucket && value == null) {

View File

@ -120,6 +120,24 @@ class LongValuesSource extends SingleDimensionValuesSource<Long> {
return compareValues(currentValue, afterValue); return compareValues(currentValue, afterValue);
} }
@Override
int hashCode(int slot) {
if (missingBucket && bits.get(slot) == false) {
return 0;
} else {
return Long.hashCode(values.get(slot));
}
}
@Override
int hashCodeCurrent() {
if (missingCurrentValue) {
return 0;
} else {
return Long.hashCode(currentValue);
}
}
private int compareValues(long v1, long v2) { private int compareValues(long v1, long v2) {
return Long.compare(v1, v2) * reverseMul; return Long.compare(v1, v2) * reverseMul;
} }

View File

@ -99,6 +99,16 @@ abstract class SingleDimensionValuesSource<T extends Comparable<T>> implements R
*/ */
abstract int compareCurrentWithAfter(); abstract int compareCurrentWithAfter();
/**
* Returns a hash code value for the provided <code>slot</code>.
*/
abstract int hashCode(int slot);
/**
* Returns a hash code value for the current value.
*/
abstract int hashCodeCurrent();
/** /**
* Sets the after value for this source. Values that compares smaller are filtered. * Sets the after value for this source. Values that compares smaller are filtered.
*/ */

View File

@ -296,13 +296,16 @@ public class CompositeValuesCollectorQueueTests extends AggregatorTestCase {
} }
} }
assertEquals(size, Math.min(queue.size(), expected.length - pos)); assertEquals(size, Math.min(queue.size(), expected.length - pos));
int ptr = 0; int ptr = pos + (queue.size() - 1);
for (int slot : queue.getSortedSlot()) {
CompositeKey key = queue.toCompositeKey(slot);
assertThat(key, equalTo(expected[ptr++]));
last = key;
}
pos += queue.size(); pos += queue.size();
last = null;
while (queue.size() > pos) {
CompositeKey key = queue.toCompositeKey(queue.pop());
if (last == null) {
last = key;
}
assertThat(key, equalTo(expected[ptr--]));
}
} }
} }
reader.close(); reader.close();