Refactor and DRY up Kahan Sum algorithm (#48558) (#48959)

This commit is contained in:
Mark Tozzi 2019-11-11 15:09:19 -05:00 committed by GitHub
parent c45470f84f
commit d9e569278f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 240 additions and 132 deletions

View File

@ -87,20 +87,15 @@ class AvgAggregator extends NumericMetricsAggregator.SingleValue {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
for (int i = 0; i < valueCount; i++) {
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
kahanSummation.add(value);
}
}
sums.set(bucket, sum);
compensations.set(bucket, compensation);
sums.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
}
}
};

View File

@ -0,0 +1,93 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics;
/**
* Used to calculate sums using the Kahan summation algorithm.
*
* <p>The Kahan summation algorithm (also known as compensated summation) reduces the numerical errors that
* occur when adding a sequence of finite precision floating point numbers. Numerical errors arise due to
* truncation and rounding. These errors can lead to numerical instability.
*
* @see <a href="http://en.wikipedia.org/wiki/Kahan_summation_algorithm">Kahan Summation Algorithm</a>
*/
public class CompensatedSum {
private static final double NO_CORRECTION = 0.0;
private double value;
private double delta;
/**
* Used to calculate sums using the Kahan summation algorithm.
*
* @param value the sum
* @param delta correction term
*/
public CompensatedSum(double value, double delta) {
this.value = value;
this.delta = delta;
}
/**
* The value of the sum.
*/
public double value() {
return value;
}
/**
* The correction term.
*/
public double delta() {
return delta;
}
/**
* Increments the Kahan sum by adding a value without a correction term.
*/
public CompensatedSum add(double value) {
return add(value, NO_CORRECTION);
}
/**
* Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors.
*/
public CompensatedSum add(double value, double delta) {
// If the value is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
if (Double.isFinite(value) == false) {
this.value = value + this.value;
}
if (Double.isFinite(this.value)) {
double correctedSum = value + (this.delta + delta);
double updatedValue = this.value + correctedSum;
this.delta = correctedSum - (updatedValue - this.value);
this.value = updatedValue;
}
return this;
}
}

View File

@ -117,34 +117,24 @@ class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue {
// which is more accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum compensatedSum = new CompensatedSum(sum, compensation);
double sumOfSqr = sumOfSqrs.get(bucket);
double compensationOfSqr = compensationOfSqrs.get(bucket);
CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
sumOfSqr += value * value;
} else {
if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
if (Double.isFinite(sumOfSqr)) {
double correctedOfSqr = value * value - compensationOfSqr;
double newSumOfSqr = sumOfSqr + correctedOfSqr;
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
sumOfSqr = newSumOfSqr;
}
}
compensatedSum.add(value);
compensatedSumOfSqr.add(value * value);
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.set(bucket, sum);
compensations.set(bucket, compensation);
sumOfSqrs.set(bucket, sumOfSqr);
compensationOfSqrs.set(bucket, compensationOfSqr);
sums.set(bucket, compensatedSum.value());
compensations.set(bucket, compensatedSum.delta());
sumOfSqrs.set(bucket, compensatedSumOfSqr.value());
compensationOfSqrs.set(bucket, compensatedSumOfSqr.delta());
mins.set(bucket, min);
maxes.set(bucket, max);
}

View File

@ -88,24 +88,21 @@ final class GeoCentroidAggregator extends MetricsAggregator {
double sumLon = lonSum.get(bucket);
double compensationLon = lonCompensations.get(bucket);
CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat);
CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon);
// update the sum
for (int i = 0; i < valueCount; ++i) {
GeoPoint value = values.nextValue();
//latitude
double correctedLat = value.getLat() - compensationLat;
double newSumLat = sumLat + correctedLat;
compensationLat = (newSumLat - sumLat) - correctedLat;
sumLat = newSumLat;
compensatedSumLat.add(value.getLat());
//longitude
double correctedLon = value.getLon() - compensationLon;
double newSumLon = sumLon + correctedLon;
compensationLon = (newSumLon - sumLon) - correctedLon;
sumLon = newSumLon;
compensatedSumLon.add(value.getLon());
}
lonSum.set(bucket, sumLon);
lonCompensations.set(bucket, compensationLon);
latSum.set(bucket, sumLat);
latCompensations.set(bucket, compensationLat);
lonSum.set(bucket, compensatedSumLon.value());
lonCompensations.set(bucket, compensatedSumLon.delta());
latSum.set(bucket, compensatedSumLat.value());
latCompensations.set(bucket, compensatedSumLat.delta());
}
}
};

View File

@ -88,24 +88,16 @@ public class InternalAvg extends InternalNumericMetricsAggregation.SingleValue i
@Override
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
long count = 0;
double sum = 0;
double compensation = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
for (InternalAggregation aggregation : aggregations) {
InternalAvg avg = (InternalAvg) aggregation;
count += avg.count;
if (Double.isFinite(avg.sum) == false) {
sum += avg.sum;
} else if (Double.isFinite(sum)) {
double corrected = avg.sum - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
kahanSummation.add(avg.sum);
}
}
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
return new InternalAvg(getName(), kahanSummation.value(), count, format, pipelineAggregators(), getMetaData());
}
@Override

View File

@ -149,8 +149,8 @@ public class InternalStats extends InternalNumericMetricsAggregation.MultiValue
long count = 0;
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
double sum = 0;
double compensation = 0;
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
for (InternalAggregation aggregation : aggregations) {
InternalStats stats = (InternalStats) aggregation;
count += stats.getCount();
@ -158,17 +158,9 @@ public class InternalStats extends InternalNumericMetricsAggregation.MultiValue
max = Math.max(max, stats.getMax());
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double value = stats.getSum();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
kahanSummation.add(stats.getSum());
}
}
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
return new InternalStats(name, count, kahanSummation.value(), min, max, format, pipelineAggregators(), getMetaData());
}
static class Fields {

View File

@ -74,20 +74,12 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue i
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = 0;
double compensation = 0;
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
for (InternalAggregation aggregation : aggregations) {
double value = ((InternalSum) aggregation).sum;
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
kahanSummation.add(value);
}
}
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
return new InternalSum(name, kahanSummation.value(), format, pipelineAggregators(), getMetaData());
}
@Override

View File

@ -88,37 +88,21 @@ public class InternalWeightedAvg extends InternalNumericMetricsAggregation.Singl
@Override
public InternalWeightedAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
double weight = 0;
double sum = 0;
double sumCompensation = 0;
double weightCompensation = 0;
CompensatedSum sumCompensation = new CompensatedSum(0, 0);
CompensatedSum weightCompensation = new CompensatedSum(0, 0);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
for (InternalAggregation aggregation : aggregations) {
InternalWeightedAvg avg = (InternalWeightedAvg) aggregation;
// If the weight is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
if (Double.isFinite(avg.weight) == false) {
weight += avg.weight;
} else if (Double.isFinite(weight)) {
double corrected = avg.weight - weightCompensation;
double newWeight = weight + corrected;
weightCompensation = (newWeight - weight) - corrected;
weight = newWeight;
weightCompensation.add(avg.weight);
sumCompensation.add(avg.sum);
}
// If the avg is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
if (Double.isFinite(avg.sum) == false) {
sum += avg.sum;
} else if (Double.isFinite(sum)) {
double corrected = avg.sum - sumCompensation;
double newSum = sum + corrected;
sumCompensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalWeightedAvg(getName(), sum, weight, format, pipelineAggregators(), getMetaData());
return new InternalWeightedAvg(getName(), sumCompensation.value(), weightCompensation.value(),
format, pipelineAggregators(), getMetaData());
}
@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(CommonFields.VALUE.getPreferredName(), weight != 0 ? getValue() : null);

View File

@ -105,22 +105,16 @@ class StatsAggregator extends NumericMetricsAggregator.MultiValue {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
kahanSummation.add(value);
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.set(bucket, sum);
compensations.set(bucket, compensation);
sums.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
mins.set(bucket, min);
maxes.set(bucket, max);
}

View File

@ -81,19 +81,15 @@ class SumAggregator extends NumericMetricsAggregator.SingleValue {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
kahanSummation.add(value);
}
}
compensations.set(bucket, compensation);
sums.set(bucket, sum);
compensations.set(bucket, kahanSummation.delta());
sums.set(bucket, kahanSummation.value());
}
}
};

View File

@ -117,16 +117,11 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
double sum = values.get(bucket);
double compensation = compensations.get(bucket);
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
values.set(bucket, sum);
compensations.set(bucket, compensation);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation)
.add(value);
values.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
}
@Override

View File

@ -0,0 +1,88 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.test.ESTestCase;
import org.junit.Assert;
public class CompensatedSumTests extends ESTestCase {
/**
* When adding a series of numbers the order of the numbers should not impact the results.
*
* <p>This test shows that a naive summation comes up with a different result than Kahan
* Summation when you start with either a smaller or larger number in some cases and
* helps prove our Kahan Summation is working.
*/
public void testAdd() {
final CompensatedSum smallSum = new CompensatedSum(0.001, 0.0);
final CompensatedSum largeSum = new CompensatedSum(1000, 0.0);
CompensatedSum compensatedResult1 = new CompensatedSum(0.001, 0.0);
CompensatedSum compensatedResult2 = new CompensatedSum(1000, 0.0);
double naiveResult1 = smallSum.value();
double naiveResult2 = largeSum.value();
for (int i = 0; i < 10; i++) {
compensatedResult1.add(smallSum.value());
compensatedResult2.add(smallSum.value());
naiveResult1 += smallSum.value();
naiveResult2 += smallSum.value();
}
compensatedResult1.add(largeSum.value());
compensatedResult2.add(smallSum.value());
naiveResult1 += largeSum.value();
naiveResult2 += smallSum.value();
// Kahan summation gave the same result no matter what order we added
Assert.assertEquals(1000.011, compensatedResult1.value(), 0.0);
Assert.assertEquals(1000.011, compensatedResult2.value(), 0.0);
// naive addition gave a small floating point error
Assert.assertEquals(1000.011, naiveResult1, 0.0);
Assert.assertEquals(1000.0109999999997, naiveResult2, 0.0);
Assert.assertEquals(compensatedResult1.value(), compensatedResult2.value(), 0.0);
Assert.assertEquals(naiveResult1, naiveResult2, 0.0001);
Assert.assertNotEquals(naiveResult1, naiveResult2, 0.0);
}
public void testDelta() {
CompensatedSum compensatedResult1 = new CompensatedSum(0.001, 0.0);
for (int i = 0; i < 10; i++) {
compensatedResult1.add(0.001);
}
Assert.assertEquals(0.011, compensatedResult1.value(), 0.0);
Assert.assertEquals(Double.parseDouble("8.673617379884035E-19"), compensatedResult1.delta(), 0.0);
}
public void testInfiniteAndNaN() {
CompensatedSum compensatedResult1 = new CompensatedSum(0, 0);
double[] doubles = {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.NaN};
for (double d : doubles) {
compensatedResult1.add(d);
}
Assert.assertTrue(Double.isNaN(compensatedResult1.value()));
}
}