Reuse CompensatedSum object in agg collect loops (#49548)

The new CompensatedSum class is a nice DRY refactor, but it had the unanticipated
side effect of creating a lot of object allocations in the aggregation hot collection
loop: one object per visited document, per aggregator. In some places it
created two objects per document, per aggregator (weighted avg, geo centroid, etc.),
since multiple compensations were being maintained.

This PR moves the object creation out of the hot loop: each CompensatedSum is now
created once per segment, and its internal state is reset on each pass through
the loop (see the sketch below).
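
In condensed form, the change looks like the following sketch. This is a simplified, hypothetical illustration rather than the actual collector code: plain arrays stand in for the BigArrays-backed per-bucket state, and valueOf(doc) stands in for the doc-values iteration.

    // Before: the accumulator is allocated inside the per-document loop,
    // so every (document, aggregator) pair pays for a new object.
    void collectSegmentBefore(int[] docs, double[] sums, double[] compensations, int bucket) {
        for (int doc : docs) {
            CompensatedSum kahan = new CompensatedSum(sums[bucket], compensations[bucket]);
            kahan.add(valueOf(doc));
            sums[bucket] = kahan.value();
            compensations[bucket] = kahan.delta();
        }
    }

    // After: one accumulator per segment, reset to the bucket's stored
    // (sum, compensation) pair each time through. No allocation on the hot path.
    void collectSegmentAfter(int[] docs, double[] sums, double[] compensations, int bucket) {
        final CompensatedSum kahan = new CompensatedSum(0, 0);  // once per segment
        for (int doc : docs) {
            kahan.reset(sums[bucket], compensations[bucket]);
            kahan.add(valueOf(doc));
            sums[bucket] = kahan.value();
            compensations[bucket] = kahan.delta();
        }
    }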
Authored and committed by Zachary Tong, 2019-11-25 16:45:51 -05:00
commit 99e313695f (parent 2fd58bb845)
7 changed files with 54 additions and 32 deletions

AvgAggregator.java

@@ -73,6 +73,8 @@ class AvgAggregator extends NumericMetricsAggregator.SingleValue {
         }
         final BigArrays bigArrays = context.bigArrays();
         final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
+        final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
+
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -87,7 +89,8 @@ class AvgAggregator extends NumericMetricsAggregator.SingleValue {
                     // accurate than naive summation.
                     double sum = sums.get(bucket);
                     double compensation = compensations.get(bucket);
-                    CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
+                    kahanSummation.reset(sum, compensation);
+
                     for (int i = 0; i < valueCount; i++) {
                         double value = values.nextValue();

CompensatedSum.java

@@ -68,6 +68,14 @@ public class CompensatedSum {
         return add(value, NO_CORRECTION);
     }
 
+    /**
+     * Resets the internal state to use the new value and compensation delta
+     */
+    public void reset(double value, double delta) {
+        this.value = value;
+        this.delta = delta;
+    }
+
     /**
      * Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors.
      */
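
For reference, the arithmetic behind CompensatedSum is standard Kahan (compensated) summation. Below is a minimal sketch of such an accumulator, assuming value holds the running sum and delta the accumulated rounding error; the real class has more to it (for instance, handling of non-finite values and an add(value, delta) overload), so treat this as an illustration, not the Elasticsearch implementation.

    public class KahanSketch {
        private double value;  // running sum
        private double delta;  // low-order bits lost to double rounding so far

        public KahanSketch(double value, double delta) {
            this.value = value;
            this.delta = delta;
        }

        // The method added in this PR: rewind to a stored state so one
        // instance can be reused across documents and buckets.
        public void reset(double value, double delta) {
            this.value = value;
            this.delta = delta;
        }

        public KahanSketch add(double v) {
            double corrected = v + delta;            // re-apply previously lost bits
            double updated = value + corrected;      // naive add, may round
            delta = corrected - (updated - value);   // what this addition just lost
            value = updated;
            return this;
        }

        public double value() { return value; }
        public double delta() { return delta; }
    }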

ExtendedStatsAggregator.java

@@ -90,6 +90,8 @@ class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue {
         }
         final BigArrays bigArrays = context.bigArrays();
         final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
+        final CompensatedSum compensatedSum = new CompensatedSum(0, 0);
+        final CompensatedSum compensatedSumOfSqr = new CompensatedSum(0, 0);
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -117,11 +119,11 @@ class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue {
                     // which is more accurate than naive summation.
                     double sum = sums.get(bucket);
                     double compensation = compensations.get(bucket);
-                    CompensatedSum compensatedSum = new CompensatedSum(sum, compensation);
+                    compensatedSum.reset(sum, compensation);
                     double sumOfSqr = sumOfSqrs.get(bucket);
                     double compensationOfSqr = compensationOfSqrs.get(bucket);
-                    CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr);
+                    compensatedSumOfSqr.reset(sumOfSqr, compensationOfSqr);
                     for (int i = 0; i < valuesCount; i++) {
                         double value = values.nextValue();

GeoCentroidAggregator.java

@@ -68,6 +68,9 @@ final class GeoCentroidAggregator extends MetricsAggregator {
         }
         final BigArrays bigArrays = context.bigArrays();
         final MultiGeoPointValues values = valuesSource.geoPointValues(ctx);
+        final CompensatedSum compensatedSumLat = new CompensatedSum(0, 0);
+        final CompensatedSum compensatedSumLon = new CompensatedSum(0, 0);
+
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -88,8 +91,8 @@ final class GeoCentroidAggregator extends MetricsAggregator {
                     double compensationLat = latCompensations.get(bucket);
                     double sumLon = lonSum.get(bucket);
                     double compensationLon = lonCompensations.get(bucket);
-                    CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat);
-                    CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon);
+                    compensatedSumLat.reset(sumLat, compensationLat);
+                    compensatedSumLon.reset(sumLon, compensationLon);
                     // update the sum
                     for (int i = 0; i < valueCount; ++i) {

StatsAggregator.java

@@ -81,6 +81,8 @@ class StatsAggregator extends NumericMetricsAggregator.MultiValue {
         }
         final BigArrays bigArrays = context.bigArrays();
         final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
+        final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
+
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -105,7 +107,7 @@ class StatsAggregator extends NumericMetricsAggregator.MultiValue {
                     // accurate than naive summation.
                     double sum = sums.get(bucket);
                     double compensation = compensations.get(bucket);
-                    CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
+                    kahanSummation.reset(sum, compensation);
                     for (int i = 0; i < valuesCount; i++) {
                         double value = values.nextValue();

SumAggregator.java

@@ -69,6 +69,7 @@ class SumAggregator extends NumericMetricsAggregator.SingleValue {
         }
         final BigArrays bigArrays = context.bigArrays();
         final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
+        final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -81,7 +82,7 @@ class SumAggregator extends NumericMetricsAggregator.SingleValue {
                     // accurate than naive summation.
                     double sum = sums.get(bucket);
                     double compensation = compensations.get(bucket);
-                    CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
+                    kahanSummation.reset(sum, compensation);
                     for (int i = 0; i < valuesCount; i++) {
                         double value = values.nextValue();

WeightedAvgAggregator.java

@@ -46,8 +46,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
     private final MultiValuesSource.NumericMultiValuesSource valuesSources;
 
     private DoubleArray weights;
-    private DoubleArray sums;
-    private DoubleArray sumCompensations;
+    private DoubleArray valueSums;
+    private DoubleArray valueCompensations;
     private DoubleArray weightCompensations;
     private DocValueFormat format;
@@ -60,8 +60,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
         if (valuesSources != null) {
             final BigArrays bigArrays = context.bigArrays();
             weights = bigArrays.newDoubleArray(1, true);
-            sums = bigArrays.newDoubleArray(1, true);
-            sumCompensations = bigArrays.newDoubleArray(1, true);
+            valueSums = bigArrays.newDoubleArray(1, true);
+            valueCompensations = bigArrays.newDoubleArray(1, true);
             weightCompensations = bigArrays.newDoubleArray(1, true);
         }
     }
@@ -80,13 +80,15 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
         final BigArrays bigArrays = context.bigArrays();
         final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx);
         final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx);
+        final CompensatedSum compensatedValueSum = new CompensatedSum(0, 0);
+        final CompensatedSum compensatedWeightSum = new CompensatedSum(0, 0);
 
         return new LeafBucketCollectorBase(sub, docValues) {
            @Override
            public void collect(int doc, long bucket) throws IOException {
                weights = bigArrays.grow(weights, bucket + 1);
-               sums = bigArrays.grow(sums, bucket + 1);
-               sumCompensations = bigArrays.grow(sumCompensations, bucket + 1);
+               valueSums = bigArrays.grow(valueSums, bucket + 1);
+               valueCompensations = bigArrays.grow(valueCompensations, bucket + 1);
                weightCompensations = bigArrays.grow(weightCompensations, bucket + 1);
 
                if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) {
@@ -102,42 +104,43 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
                    final int numValues = docValues.docValueCount();
                    assert numValues > 0;
 
+                   double valueSum = valueSums.get(bucket);
+                   double valueCompensation = valueCompensations.get(bucket);
+                   compensatedValueSum.reset(valueSum, valueCompensation);
+
+                   double weightSum = weights.get(bucket);
+                   double weightCompensation = weightCompensations.get(bucket);
+                   compensatedWeightSum.reset(weightSum, weightCompensation);
+
                    for (int i = 0; i < numValues; i++) {
-                       kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket);
-                       kahanSum(weight, weights, weightCompensations, bucket);
+                       compensatedValueSum.add(docValues.nextValue() * weight);
+                       compensatedWeightSum.add(weight);
                    }
+
+                   valueSums.set(bucket, compensatedValueSum.value());
+                   valueCompensations.set(bucket, compensatedValueSum.delta());
+                   weights.set(bucket, compensatedWeightSum.value());
+                   weightCompensations.set(bucket, compensatedWeightSum.delta());
                }
            }
        };
    }
 
-    private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) {
-        // Compute the sum of double values with Kahan summation algorithm which is more
-        // accurate than naive summation.
-        double sum = values.get(bucket);
-        double compensation = compensations.get(bucket);
-        CompensatedSum kahanSummation = new CompensatedSum(sum, compensation)
-            .add(value);
-        values.set(bucket, kahanSummation.value());
-        compensations.set(bucket, kahanSummation.delta());
-    }
-
    @Override
    public double metric(long owningBucketOrd) {
-       if (valuesSources == null || owningBucketOrd >= sums.size()) {
+       if (valuesSources == null || owningBucketOrd >= valueSums.size()) {
            return Double.NaN;
        }
-       return sums.get(owningBucketOrd) / weights.get(owningBucketOrd);
+       return valueSums.get(owningBucketOrd) / weights.get(owningBucketOrd);
    }
 
    @Override
    public InternalAggregation buildAggregation(long bucket) {
-       if (valuesSources == null || bucket >= sums.size()) {
+       if (valuesSources == null || bucket >= valueSums.size()) {
            return buildEmptyAggregation();
        }
-       return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
+       return new InternalWeightedAvg(name, valueSums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
    }
 
    @Override
@@ -147,7 +150,7 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
 
    @Override
    public void doClose() {
-       Releasables.close(weights, sums, sumCompensations, weightCompensations);
+       Releasables.close(weights, valueSums, valueCompensations, weightCompensations);
    }
 }
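
The weighted-average collector above is the busiest case: it advances two compensated sums in lockstep, one for value * weight and one for weight, and the bucket's metric is their ratio. Roughly, it behaves like the sketch below, where plain arrays stand in for the BigArrays-backed DoubleArray fields and the method signature is illustrative rather than actual Elasticsearch code.

    static void collectWeighted(double[] docValues, double weight, int bucket,
                                double[] valueSums, double[] valueCompensations,
                                double[] weights, double[] weightCompensations,
                                CompensatedSum compensatedValueSum,
                                CompensatedSum compensatedWeightSum) {
        // Load this bucket's stored state into the reusable accumulators.
        compensatedValueSum.reset(valueSums[bucket], valueCompensations[bucket]);
        compensatedWeightSum.reset(weights[bucket], weightCompensations[bucket]);
        for (double v : docValues) {
            compensatedValueSum.add(v * weight);   // numerator: sum(value * weight)
            compensatedWeightSum.add(weight);      // denominator: sum(weight)
        }
        // Store the state back; the bucket's weighted average is
        // valueSums[bucket] / weights[bucket].
        valueSums[bucket] = compensatedValueSum.value();
        valueCompensations[bucket] = compensatedValueSum.delta();
        weights[bucket] = compensatedWeightSum.value();
        weightCompensations[bucket] = compensatedWeightSum.delta();
    }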