Calculate sums with the Kahan summation algorithm in aggregations (#27807) (#27848)

This commit is contained in:
kel 2018-01-22 19:42:56 +08:00 committed by Adrien Grand
parent 700d9ecc95
commit 452c36c552
17 changed files with 557 additions and 37 deletions
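
Every file in this change applies the same compensated-summation pattern: alongside each running sum it keeps a compensation term holding the low-order bits that the last addition rounded away, and re-injects them on the next addition. A minimal standalone sketch of that step, with a hypothetical kahanSum helper that is not part of the patch:

    static double kahanSum(double[] values) {
        double sum = 0;
        double compensation = 0;                       // error lost by previous additions
        for (double value : values) {
            double corrected = value - compensation;   // re-inject the lost low-order bits
            double newSum = sum + corrected;           // may round away low-order bits again
            compensation = (newSum - sum) - corrected; // algebraically 0; captures the rounding error
            sum = newSum;
        }
        return sum;
    }

For example, kahanSum(new double[] {1e16, 1.0, 1.0, -1e16}) returns 2.0, while naive left-to-right addition returns 0.0 because 1e16 + 1.0 rounds back to 1e16.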

View File: AvgAggregator.java

@ -44,6 +44,7 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DocValueFormat format;
public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
@ -55,6 +56,7 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}
}
@ -76,15 +78,29 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
public void collect(int doc, long bucket) throws IOException {
counts = bigArrays.grow(counts, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);
if (values.advanceExact(doc)) {
final int valueCount = values.docValueCount();
counts.increment(bucket, valueCount);
double sum = 0;
// Compute the sum of double values with the Kahan summation algorithm, which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
for (int i = 0; i < valueCount; i++) {
sum += values.nextValue();
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
}
}
};
@ -113,7 +129,7 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
@Override
public void doClose() {
Releasables.close(counts, sums);
Releasables.close(counts, sums, compensations);
}
}
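
The two isFinite checks above keep NaN and infinity handling identical to naive summation: a non-finite value is added directly so it propagates, and once the running sum is non-finite the compensation step is skipped (adding a finite value to an infinite sum would not change it anyway). Without the guards, the compensation term itself goes non-finite and corrupts the result, as this hypothetical snippet shows:

    double sum = 0;
    double compensation = 0;
    for (double value : new double[] { Double.MAX_VALUE, Double.MAX_VALUE, 1.0 }) {
        double corrected = value - compensation;   // 3rd pass: 1.0 - Infinity = -Infinity
        double newSum = sum + corrected;           // 2nd pass overflows to Infinity,
        compensation = (newSum - sum) - corrected; // ... leaving compensation = Infinity
        sum = newSum;                              // 3rd pass: Infinity + -Infinity = NaN
    }
    // sum ends as NaN, whereas naive addition of the same values yields Infinity.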

View File: InternalAvg.java

@ -91,9 +91,20 @@ public class InternalAvg extends InternalNumericMetricsAggregation.SingleValue i
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
long count = 0;
double sum = 0;
double compensation = 0;
// Compute the sum of double values with the Kahan summation algorithm, which is more
// accurate than naive summation.
for (InternalAggregation aggregation : aggregations) {
count += ((InternalAvg) aggregation).count;
sum += ((InternalAvg) aggregation).sum;
InternalAvg avg = (InternalAvg) aggregation;
count += avg.count;
if (Double.isFinite(avg.sum) == false) {
sum += avg.sum;
} else if (Double.isFinite(sum)) {
double corrected = avg.sum - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
}
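
The reduce step merges per-shard partial sums with the same compensated addition used at collect time, and the same branch structure keeps non-finite partials behaving exactly as naive summation would. A hypothetical trace of three partial sums through the loop above:

    double sum = 0;
    double compensation = 0;
    for (double value : new double[] { Double.POSITIVE_INFINITY, 1.0, Double.NEGATIVE_INFINITY }) {
        if (Double.isFinite(value) == false) {
            sum += value;                              // 1st: Infinity; 3rd: Infinity + -Infinity = NaN
        } else if (Double.isFinite(sum)) {             // false for 1.0: sum is already Infinity,
            double corrected = value - compensation;   // so the Kahan step is skipped and the
            double newSum = sum + corrected;           // infinite sum is left unchanged, just as
            compensation = (newSum - sum) - corrected; // Infinity + 1.0 would leave it
            sum = newSum;
        }
    }
    // sum == NaN, matching naive left-to-right addition of the three partials.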

View File: InternalStats.java

@ -152,12 +152,23 @@ public class InternalStats extends InternalNumericMetricsAggregation.MultiValue
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
double sum = 0;
double compensation = 0;
for (InternalAggregation aggregation : aggregations) {
InternalStats stats = (InternalStats) aggregation;
count += stats.getCount();
min = Math.min(min, stats.getMin());
max = Math.max(max, stats.getMax());
sum += stats.getSum();
// Compute the sum of double values with the Kahan summation algorithm, which is more
// accurate than naive summation.
double value = stats.getSum();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
}

View File: StatsAggregator.java

@ -45,6 +45,7 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DoubleArray mins;
DoubleArray maxes;
@ -59,6 +60,7 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
@ -88,6 +90,7 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
@ -97,16 +100,28 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum of double values with the Kahan summation algorithm, which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
mins.set(bucket, min);
maxes.set(bucket, max);
}
@ -164,6 +179,6 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
@Override
public void doClose() {
Releasables.close(counts, maxes, mins, sums);
Releasables.close(counts, maxes, mins, sums, compensations);
}
}
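
Note the switch from sums.increment(bucket, sum) to sums.set(bucket, sum): the loop now seeds sum (and compensation) from the per-bucket arrays, so the stored total is already folded into the local variable, and increment() would count it twice across repeated collect() calls. A hedged sketch of the bookkeeping, using the arrays above:

    double sum = sums.get(bucket);                   // resume the running sum for this bucket
    double compensation = compensations.get(bucket); // and its accumulated error term
    // ... compensated loop over this document's values updates sum/compensation ...
    sums.set(bucket, sum);                           // overwrite; increment(bucket, sum) would
    compensations.set(bucket, compensation);         // add the previously stored sum in again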

View File: ExtendedStatsAggregator.java

@ -49,9 +49,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DoubleArray mins;
DoubleArray maxes;
DoubleArray sumOfSqrs;
DoubleArray compensationOfSqrs;
public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter,
SearchContext context, Aggregator parent, double sigma, List<PipelineAggregator> pipelineAggregators,
@ -65,11 +67,13 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
sumOfSqrs = bigArrays.newDoubleArray(1, true);
compensationOfSqrs = bigArrays.newDoubleArray(1, true);
}
}
@ -95,9 +99,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
sumOfSqrs = bigArrays.resize(sumOfSqrs, overSize);
compensationOfSqrs = bigArrays.resize(compensationOfSqrs, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
maxes.fill(from, overSize, Double.NEGATIVE_INFINITY);
}
@ -105,19 +111,40 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double sumOfSqr = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum and sum of squares of double values with the Kahan summation
// algorithm, which is more accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
double sumOfSqr = sumOfSqrs.get(bucket);
double compensationOfSqr = compensationOfSqrs.get(bucket);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
sumOfSqr += value * value;
if (Double.isFinite(value) == false) {
sum += value;
sumOfSqr += value * value;
} else {
if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
if (Double.isFinite(sumOfSqr)) {
double correctedOfSqr = value * value - compensationOfSqr;
double newSumOfSqr = sumOfSqr + correctedOfSqr;
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
sumOfSqr = newSumOfSqr;
}
}
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sumOfSqrs.increment(bucket, sumOfSqr);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
sumOfSqrs.set(bucket, sumOfSqr);
compensationOfSqrs.set(bucket, compensationOfSqr);
mins.set(bucket, min);
maxes.set(bucket, max);
}
@ -196,6 +223,6 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
@Override
public void doClose() {
Releasables.close(counts, maxes, mins, sumOfSqrs, sums);
Releasables.close(counts, maxes, mins, sumOfSqrs, compensationOfSqrs, sums, compensations);
}
}
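
Extended stats maintains two independent compensated accumulators, one for the values and one for their squares, because each sum loses different low-order bits. The commit inlines the step for both; a hypothetical helper, not part of the patch, makes the shared pattern explicit:

    // state[0] holds the running sum, state[1] its compensation term
    static void kahanAdd(double[] state, double value) {
        if (Double.isFinite(value) == false) {
            state[0] += value;                          // let NaN/infinities propagate naively
        } else if (Double.isFinite(state[0])) {
            double corrected = value - state[1];
            double newSum = state[0] + corrected;
            state[1] = (newSum - state[0]) - corrected;
            state[0] = newSum;
        }
    }

    // per value: kahanAdd(sumState, value); kahanAdd(sqrState, value * value);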

View File: InternalExtendedStats.java

@ -45,7 +45,7 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
private final double sigma;
public InternalExtendedStats(String name, long count, double sum, double min, double max, double sumOfSqrs, double sigma,
DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
super(name, count, sum, min, max, formatter, pipelineAggregators, metaData);
this.sumOfSqrs = sumOfSqrs;
this.sigma = sigma;
@ -142,16 +142,25 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
@Override
public InternalExtendedStats doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
double sumOfSqrs = 0;
double compensationOfSqrs = 0;
for (InternalAggregation aggregation : aggregations) {
InternalExtendedStats stats = (InternalExtendedStats) aggregation;
if (stats.sigma != sigma) {
throw new IllegalStateException("Cannot reduce other stats aggregations that have a different sigma");
}
sumOfSqrs += stats.getSumOfSquares();
double value = stats.getSumOfSquares();
if (Double.isFinite(value) == false) {
sumOfSqrs += value;
} else if (Double.isFinite(sumOfSqrs)) {
double correctedOfSqrs = value - compensationOfSqrs;
double newSumOfSqrs = sumOfSqrs + correctedOfSqrs;
compensationOfSqrs = (newSumOfSqrs - sumOfSqrs) - correctedOfSqrs;
sumOfSqrs = newSumOfSqrs;
}
}
final InternalStats stats = super.doReduce(aggregations, reduceContext);
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma,
format, pipelineAggregators(), getMetaData());
}
static class Fields {

View File: InternalSum.java

@ -35,7 +35,7 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue i
private final double sum;
public InternalSum(String name, double sum, DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
this.sum = sum;
this.format = formatter;
@ -73,9 +73,20 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue i
@Override
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
// Compute the sum of double values with the Kahan summation algorithm, which is more
// accurate than naive summation.
double sum = 0;
double compensation = 0;
for (InternalAggregation aggregation : aggregations) {
sum += ((InternalSum) aggregation).sum;
double value = ((InternalSum) aggregation).sum;
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
}

View File: SumAggregator.java

@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
private final DocValueFormat format;
private DoubleArray sums;
private DoubleArray compensations;
SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
this.format = formatter;
if (valuesSource != null) {
sums = context.bigArrays().newDoubleArray(1, true);
compensations = context.bigArrays().newDoubleArray(1, true);
}
}
@ -71,13 +73,27 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
@Override
public void collect(int doc, long bucket) throws IOException {
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);
if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
double sum = 0;
// Compute the sum of double values with the Kahan summation algorithm, which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
for (int i = 0; i < valuesCount; i++) {
sum += values.nextValue();
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
sums.increment(bucket, sum);
compensations.set(bucket, compensation);
sums.set(bucket, sum);
}
}
};
@ -106,6 +122,6 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
@Override
public void doClose() {
Releasables.close(sums);
Releasables.close(sums, compensations);
}
}

View File: ExtendedStatsAggregatorTests.java

@ -20,6 +20,7 @@
package org.elasticsearch.search.aggregations.metrics;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
@ -38,6 +39,8 @@ import org.elasticsearch.search.aggregations.metrics.stats.extended.InternalExte
import java.io.IOException;
import java.util.function.Consumer;
import static java.util.Collections.singleton;
public class ExtendedStatsAggregatorTests extends AggregatorTestCase {
private static final double TOLERANCE = 1e-5;
@ -132,6 +135,68 @@ public class ExtendedStatsAggregatorTests extends AggregatorTestCase {
);
}
public void testSummationAccuracy() throws IOException {
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifyStatsOfDoubles(values, 13.5, 16.21, 0d);
// Summing up an array which contains NaN and infinities, expecting the same result as naive summation
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
double sumOfSqrs = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
sumOfSqrs += values[i] * values[i];
}
verifyStatsOfDoubles(values, sum, sumOfSqrs, TOLERANCE);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifyStatsOfDoubles(largeValues, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifyStatsOfDoubles(largeValues, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0d);
}
private void verifyStatsOfDoubles(double[] values, double expectedSum,
double expectedSumOfSqrs, double delta) throws IOException {
MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
final String fieldName = "field";
ft.setName(fieldName);
double max = Double.NEGATIVE_INFINITY;
double min = Double.POSITIVE_INFINITY;
for (double value : values) {
max = Math.max(max, value);
min = Math.min(min, value);
}
double expectedMax = max;
double expectedMin = min;
testCase(ft,
iw -> {
for (double value : values) {
iw.addDocument(singleton(new NumericDocValuesField(fieldName, NumericUtils.doubleToSortableLong(value))));
}
},
stats -> {
assertEquals(values.length, stats.getCount());
assertEquals(expectedSum / values.length, stats.getAvg(), delta);
assertEquals(expectedSum, stats.getSum(), delta);
assertEquals(expectedSumOfSqrs, stats.getSumOfSquares(), delta);
assertEquals(expectedMax, stats.getMax(), 0d);
assertEquals(expectedMin, stats.getMin(), 0d);
}
);
}
public void testCase(MappedFieldType ft,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalExtendedStats> verify) throws IOException {

View File: InternalExtendedStatsTests.java

@ -21,6 +21,7 @@ package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.search.aggregations.metrics.stats.extended.ExtendedStats.Bounds;
import org.elasticsearch.search.aggregations.metrics.stats.extended.InternalExtendedStats;
@ -28,6 +29,7 @@ import org.elasticsearch.search.aggregations.metrics.stats.extended.ParsedExtend
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -188,4 +190,44 @@ public class InternalExtendedStatsTests extends InternalAggregationTestCase<Inte
}
return new InternalExtendedStats(name, count, sum, min, max, sumOfSqrs, sigma, formatter, pipelineAggregators, metaData);
}
public void testSummationAccuracy() {
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifySumOfSqrsOfDoubles(values, 13.5, 0d);
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifySumOfSqrsOfDoubles(values, sum, TOLERANCE);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifySumOfSqrsOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifySumOfSqrsOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}
private void verifySumOfSqrsOfDoubles(double[] values, double expectedSumOfSqrs, double delta) {
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
double sigma = randomDouble();
for (double sumOfSqrs : values) {
aggregations.add(new InternalExtendedStats("dummy1", 1, 0.0, 0.0, 0.0, sumOfSqrs, sigma, null, null, null));
}
InternalExtendedStats stats = new InternalExtendedStats("dummy", 1, 0.0, 0.0, 0.0, 0.0, sigma, null, null, null);
InternalExtendedStats reduced = stats.doReduce(aggregations, null);
assertEquals(expectedSumOfSqrs, reduced.getSumOfSquares(), delta);
}
}

View File: InternalStatsTests.java

@ -23,6 +23,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.search.aggregations.metrics.stats.InternalStats;
import org.elasticsearch.search.aggregations.metrics.stats.ParsedStats;
@ -30,6 +31,7 @@ import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -48,7 +50,7 @@ public class InternalStatsTests extends InternalAggregationTestCase<InternalStat
}
protected InternalStats createInstance(String name, long count, double sum, double min, double max, DocValueFormat formatter,
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
return new InternalStats(name, count, sum, min, max, formatter, pipelineAggregators, metaData);
}
@ -74,6 +76,54 @@ public class InternalStatsTests extends InternalAggregationTestCase<InternalStat
assertEquals(expectedMax, reduced.getMax(), 0d);
}
public void testSummationAccuracy() {
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifyStatsOfDoubles(values, 13.5, 0.9, 0d);
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifyStatsOfDoubles(values, sum, sum / n, TOLERANCE);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifyStatsOfDoubles(largeValues, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifyStatsOfDoubles(largeValues, Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d);
}
private void verifyStatsOfDoubles(double[] values, double expectedSum, double expectedAvg, double delta) {
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
double max = Double.NEGATIVE_INFINITY;
double min = Double.POSITIVE_INFINITY;
for (double value : values) {
max = Math.max(max, value);
min = Math.min(min, value);
aggregations.add(new InternalStats("dummy1", 1, value, value, value, null, null, null));
}
InternalStats internalStats = new InternalStats("dummy2", 0, 0.0, 2.0, 0.0, null, null, null);
InternalStats reduced = internalStats.doReduce(aggregations, null);
assertEquals("dummy2", reduced.getName());
assertEquals(values.length, reduced.getCount());
assertEquals(expectedSum, reduced.getSum(), delta);
assertEquals(expectedAvg, reduced.getAvg(), delta);
assertEquals(min, reduced.getMin(), 0d);
assertEquals(max, reduced.getMax(), 0d);
}
@Override
protected void assertFromXContent(InternalStats aggregation, ParsedAggregation parsedAggregation) {
assertTrue(parsedAggregation instanceof ParsedStats);

View File: InternalSumTests.java

@ -20,12 +20,14 @@ package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.search.aggregations.metrics.sum.InternalSum;
import org.elasticsearch.search.aggregations.metrics.sum.ParsedSum;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -34,7 +36,7 @@ public class InternalSumTests extends InternalAggregationTestCase<InternalSum> {
@Override
protected InternalSum createTestInstance(String name, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
double value = frequently() ? randomDouble() : randomFrom(new Double[] { Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY });
double value = frequently() ? randomDouble() : randomFrom(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.NaN);
DocValueFormat formatter = randomFrom(new DocValueFormat.Decimal("###.##"), DocValueFormat.BOOLEAN, DocValueFormat.RAW);
return new InternalSum(name, value, formatter, pipelineAggregators, metaData);
}
@ -50,6 +52,47 @@ public class InternalSumTests extends InternalAggregationTestCase<InternalSum> {
assertEquals(expectedSum, reduced.getValue(), 0.0001d);
}
public void testSummationAccuracy() {
// Summing up a normal array, expecting an accurate value
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifySummationOfDoubles(values, 13.5, 0d);
// Summing up an array which contains NaN and infinities, expecting the same result as naive summation
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifySummationOfDoubles(values, sum, TOLERANCE);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifySummationOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifySummationOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}
private void verifySummationOfDoubles(double[] values, double expected, double delta) {
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
for (double value : values) {
aggregations.add(new InternalSum("dummy1", value, null, null, null));
}
InternalSum internalSum = new InternalSum("dummy", 0, null, null, null);
InternalSum reduced = internalSum.doReduce(aggregations, null);
assertEquals(expected, reduced.value(), delta);
}
@Override
protected void assertFromXContent(InternalSum sum, ParsedAggregation parsedAggregation) {
ParsedSum parsed = ((ParsedSum) parsedAggregation);

View File: StatsAggregatorTests.java

@ -19,6 +19,7 @@
package org.elasticsearch.search.aggregations.metrics;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
@ -36,6 +37,8 @@ import org.elasticsearch.search.aggregations.metrics.stats.StatsAggregationBuild
import java.io.IOException;
import java.util.function.Consumer;
import static java.util.Collections.singleton;
public class StatsAggregatorTests extends AggregatorTestCase {
static final double TOLERANCE = 1e-10;
@ -113,6 +116,66 @@ public class StatsAggregatorTests extends AggregatorTestCase {
);
}
public void testSummationAccuracy() throws IOException {
// Summing up a normal array, expecting an accurate value
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifySummationOfDoubles(values, 15.3, 0.9, 0d);
// Summing up an array which contains NaN and infinities, expecting the same result as naive summation
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifySummationOfDoubles(values, sum, sum / n, TOLERANCE);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifySummationOfDoubles(largeValues, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifySummationOfDoubles(largeValues, Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d);
}
private void verifySummationOfDoubles(double[] values, double expectedSum,
double expectedAvg, double delta) throws IOException {
MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
ft.setName("field");
double max = Double.NEGATIVE_INFINITY;
double min = Double.POSITIVE_INFINITY;
for (double value : values) {
max = Math.max(max, value);
min = Math.min(min, value);
}
double expectedMax = max;
double expectedMin = min;
testCase(ft,
iw -> {
for (double value : values) {
iw.addDocument(singleton(new NumericDocValuesField("field", NumericUtils.doubleToSortableLong(value))));
}
},
stats -> {
assertEquals(values.length, stats.getCount());
assertEquals(expectedAvg, stats.getAvg(), delta);
assertEquals(expectedSum, stats.getSum(), delta);
assertEquals(expectedMax, stats.getMax(), 0d);
assertEquals(expectedMin, stats.getMin(), 0d);
}
);
}
public void testCase(MappedFieldType ft,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalStats> verify) throws IOException {

View File: SumAggregatorTests.java

@ -34,6 +34,7 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
@ -107,7 +108,7 @@ public class SumAggregatorTests extends AggregatorTestCase {
}
public void testStringField() throws IOException {
IllegalStateException e = expectThrows(IllegalStateException.class , () -> {
IllegalStateException e = expectThrows(IllegalStateException.class, () -> {
testCase(new MatchAllDocsQuery(), iw -> {
iw.addDocument(singleton(new SortedDocValuesField(FIELD_NAME, new BytesRef("1"))));
}, count -> assertEquals(0L, count.getValue(), 0d));
@ -116,10 +117,59 @@ public class SumAggregatorTests extends AggregatorTestCase {
"Re-index with correct docvalues type.", e.getMessage());
}
public void testSummationAccuracy() throws IOException {
// Summing up a normal array, expecting an accurate value
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifySummationOfDoubles(values, 15.3, 0d);
// Summing up an array which contains NaN and infinities, expecting the same result as naive summation
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifySummationOfDoubles(values, sum, 1e-10);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifySummationOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifySummationOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}
private void verifySummationOfDoubles(double[] values, double expected, double delta) throws IOException {
testCase(new MatchAllDocsQuery(),
iw -> {
for (double value : values) {
iw.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, NumericUtils.doubleToSortableLong(value))));
}
},
result -> assertEquals(expected, result.getValue(), delta),
NumberFieldMapper.NumberType.DOUBLE
);
}
private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> indexer,
Consumer<Sum> verify) throws IOException {
testCase(query, indexer, verify, NumberFieldMapper.NumberType.LONG);
}
private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> indexer,
Consumer<Sum> verify,
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
indexer.accept(indexWriter);
@ -128,7 +178,7 @@ public class SumAggregatorTests extends AggregatorTestCase {
try (IndexReader indexReader = DirectoryReader.open(directory)) {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType.setName(FIELD_NAME);
fieldType.setHasDocValues(true);

View File: AvgAggregatorTests.java

@ -30,13 +30,11 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregator;
import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg;
import java.io.IOException;
import java.util.Arrays;
@ -103,8 +101,59 @@ public class AvgAggregatorTests extends AggregatorTestCase {
});
}
private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOException> buildIndex, Consumer<InternalAvg> verify)
throws IOException {
public void testSummationAccuracy() throws IOException {
// Summing up a normal array, expecting an accurate value
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifyAvgOfDoubles(values, 0.9, 0d);
// Summing up an array which contains NaN and infinities, expecting the same result as naive summation
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifyAvgOfDoubles(values, sum / n, 1e-10);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifyAvgOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifyAvgOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}
private void verifyAvgOfDoubles(double[] values, double expected, double delta) throws IOException {
testCase(new MatchAllDocsQuery(),
iw -> {
for (double value : values) {
iw.addDocument(singleton(new NumericDocValuesField("number", NumericUtils.doubleToSortableLong(value))));
}
},
avg -> assertEquals(expected, avg.getValue(), delta),
NumberFieldMapper.NumberType.DOUBLE
);
}
private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalAvg> verify) throws IOException {
testCase(query, buildIndex, verify, NumberFieldMapper.NumberType.LONG);
}
private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalAvg> verify,
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
Directory directory = newDirectory();
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
buildIndex.accept(indexWriter);
@ -114,7 +163,7 @@ public class AvgAggregatorTests extends AggregatorTestCase {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
AvgAggregationBuilder aggregationBuilder = new AvgAggregationBuilder("_name").field("number");
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType.setName("number");
AvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);

View File: InternalAvgTests.java

@ -21,10 +21,12 @@ package org.elasticsearch.search.aggregations.metrics.avg;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -56,6 +58,45 @@ public class InternalAvgTests extends InternalAggregationTestCase<InternalAvg> {
assertEquals(sum / counts, reduced.value(), 0.0000001);
}
public void testSummationAccuracy() {
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifyAvgOfDoubles(values, 0.9, 0d);
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifyAvgOfDoubles(values, sum / n, TOLERANCE);
// Summing up some big double values, expecting an infinite result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifyAvgOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifyAvgOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}
private void verifyAvgOfDoubles(double[] values, double expected, double delta) {
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
for (double value : values) {
aggregations.add(new InternalAvg("dummy1", value, 1, null, null, null));
}
InternalAvg internalAvg = new InternalAvg("dummy2", 0, 0, null, null, null);
InternalAvg reduced = internalAvg.doReduce(aggregations, null);
assertEquals(expected, reduced.getValue(), delta);
}
@Override
protected void assertFromXContent(InternalAvg avg, ParsedAggregation parsedAggregation) {
ParsedAvg parsed = ((ParsedAvg) parsedAggregation);

View File: InternalAggregationTestCase.java

@ -150,6 +150,7 @@ import static org.hamcrest.Matchers.equalTo;
public abstract class InternalAggregationTestCase<T extends InternalAggregation> extends AbstractWireSerializingTestCase<T> {
public static final int DEFAULT_MAX_BUCKETS = 100000;
protected static final double TOLERANCE = 1e-10;
private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(
new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables());