HADOOP-18426. Use weighted calculation for MutableStat mean/variance to fix accuracy. (#4844). Contributed by Erik Krogen.

Co-authored-by: Shuyan Zhang <zqingchai@gmail.com>
Signed-off-by: He Xiaoqiao <hexiaoqiao@apache.org>
This commit is contained in:
Erik Krogen 2022-09-06 22:49:56 -07:00 committed by GitHub
parent cc41ad63f9
commit c664f953c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 40 deletions

View File

@ -27,33 +27,29 @@ import org.apache.hadoop.classification.InterfaceAudience;
public class SampleStat { public class SampleStat {
private final MinMax minmax = new MinMax(); private final MinMax minmax = new MinMax();
private long numSamples = 0; private long numSamples = 0;
private double a0, a1, s0, s1, total; private double mean, s;
/** /**
* Construct a new running sample stat * Construct a new running sample stat
*/ */
public SampleStat() {
  // Begin with an empty accumulator: a zero running mean and a zero
  // running sum of weighted squared deviations.
  this.mean = 0.0;
  this.s = 0.0;
}
public void reset() { public void reset() {
numSamples = 0; numSamples = 0;
a0 = s0 = 0.0; mean = 0.0;
total = 0.0; s = 0.0;
minmax.reset(); minmax.reset();
} }
// We want to reuse the object, sometimes. // We want to reuse the object, sometimes.
void reset(long numSamples, double a0, double a1, double s0, double s1, void reset(long numSamples1, double mean1, double s1, MinMax minmax1) {
double total, MinMax minmax) { numSamples = numSamples1;
this.numSamples = numSamples; mean = mean1;
this.a0 = a0; s = s1;
this.a1 = a1; minmax.reset(minmax1);
this.s0 = s0;
this.s1 = s1;
this.total = total;
this.minmax.reset(minmax);
} }
/** /**
@ -61,7 +57,7 @@ public class SampleStat {
* @param other the destination to hold our values * @param other the destination to hold our values
*/ */
public void copyTo(SampleStat other) { public void copyTo(SampleStat other) {
other.reset(numSamples, a0, a1, s0, s1, total, minmax); other.reset(numSamples, mean, s, minmax);
} }
/** /**
@ -78,24 +74,22 @@ public class SampleStat {
* Add a batch of samples, given as a count and their sum, to the running stat.
* Note, min/max is not evaluated using this method.
* @param nSamples number of samples
* @param xTotal the partial sum, i.e. the total of the nSamples samples
* @return self
*/
public SampleStat add(long nSamples, double x) { public SampleStat add(long nSamples, double xTotal) {
numSamples += nSamples; numSamples += nSamples;
total += x;
if (numSamples == 1) { // use the weighted incremental version of Welford's algorithm to get
a0 = a1 = x; // numerical stability while treating the samples as being weighted
s0 = 0.0; // by nSamples
} // see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
else {
// The Welford method for numerical stability double x = xTotal / nSamples;
a1 = a0 + (x - a0) / numSamples; double meanOld = mean;
s1 = s0 + (x - a0) * (x - a1);
a0 = a1; mean += ((double) nSamples / numSamples) * (x - meanOld);
s0 = s1; s += nSamples * (x - meanOld) * (x - mean);
}
return this; return this;
} }
@ -110,21 +104,21 @@ public class SampleStat {
* @return the total of all samples added * @return the total of all samples added
*/ */
public double total() { public double total() {
return total; return mean * numSamples;
} }
/** /**
* @return the arithmetic mean of the samples * @return the arithmetic mean of the samples
*/ */
public double mean() { public double mean() {
return numSamples > 0 ? (total / numSamples) : 0.0; return numSamples > 0 ? mean : 0.0;
} }
/** /**
* @return the variance of the samples * @return the variance of the samples
*/ */
public double variance() { public double variance() {
return numSamples > 1 ? s1 / (numSamples - 1) : 0.0; return numSamples > 1 ? s / (numSamples - 1) : 0.0;
} }
/** /**

View File

@ -29,6 +29,8 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Random; import java.util.Random;
@ -36,6 +38,7 @@ import java.util.concurrent.CountDownLatch;
import org.apache.hadoop.metrics2.MetricsRecordBuilder; import org.apache.hadoop.metrics2.MetricsRecordBuilder;
import org.apache.hadoop.metrics2.util.Quantile; import org.apache.hadoop.metrics2.util.Quantile;
import org.apache.hadoop.thirdparty.com.google.common.math.Stats;
import org.junit.Test; import org.junit.Test;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -47,7 +50,7 @@ public class TestMutableMetrics {
private static final Logger LOG = private static final Logger LOG =
LoggerFactory.getLogger(TestMutableMetrics.class); LoggerFactory.getLogger(TestMutableMetrics.class);
private final double EPSILON = 1e-42; private static final double EPSILON = 1e-42;
/** /**
* Test the snapshot method * Test the snapshot method
@ -306,19 +309,56 @@ public class TestMutableMetrics {
/**
* Tests that when using {@link MutableStat#add(long, long)}, even with a high
* sample count, the mean does not lose accuracy. This also validates that
* the std dev is correct, assuming that all samples within a single bulk add
* have the same value.
*/
@Test public void testMutableStatWithBulkAdd() { @Test
public void testMutableStatWithBulkAdd() {
List<Long> samples = new ArrayList<>();
for (int i = 0; i < 1000; i++) {
samples.add(1000L);
}
for (int i = 0; i < 1000; i++) {
samples.add(2000L);
}
Stats stats = Stats.of(samples);
for (int bulkSize : new int[] {1, 10, 100, 1000}) {
MetricsRecordBuilder rb = mockMetricsRecordBuilder();
MetricsRegistry registry = new MetricsRegistry("test");
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
for (int i = 0; i < samples.size(); i += bulkSize) {
stat.add(bulkSize, samples
.subList(i, i + bulkSize)
.stream()
.mapToLong(Long::longValue)
.sum()
);
}
registry.snapshot(rb, false);
assertCounter("TestNumOps", 2000L, rb);
assertGauge("TestAvgVal", stats.mean(), rb);
assertGauge("TestStdevVal", stats.sampleStandardDeviation(), rb);
}
}
@Test
public void testLargeMutableStatAdd() {
MetricsRecordBuilder rb = mockMetricsRecordBuilder(); MetricsRecordBuilder rb = mockMetricsRecordBuilder();
MetricsRegistry registry = new MetricsRegistry("test"); MetricsRegistry registry = new MetricsRegistry("test");
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", false); MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
stat.add(1000, 1000); long sample = 1000000000000009L;
stat.add(1000, 2000); for (int i = 0; i < 100; i++) {
stat.add(1, sample);
}
registry.snapshot(rb, false); registry.snapshot(rb, false);
assertCounter("TestNumOps", 2000L, rb); assertCounter("TestNumOps", 100L, rb);
assertGauge("TestAvgVal", 1.5, rb); assertGauge("TestAvgVal", (double) sample, rb);
assertGauge("TestStdevVal", 0.0, rb);
} }
/** /**