HADOOP-18426. Use weighted calculation for MutableStat mean/variance to fix accuracy. (#4844). Contributed by Erik Krogen.
Co-authored-by: Shuyan Zhang <zqingchai@gmail.com> Signed-off-by: He Xiaoqiao <hexiaoqiao@apache.org>
This commit is contained in:
parent
cc41ad63f9
commit
c664f953c9
|
@ -27,33 +27,29 @@ import org.apache.hadoop.classification.InterfaceAudience;
|
||||||
public class SampleStat {
|
public class SampleStat {
|
||||||
private final MinMax minmax = new MinMax();
|
private final MinMax minmax = new MinMax();
|
||||||
private long numSamples = 0;
|
private long numSamples = 0;
|
||||||
private double a0, a1, s0, s1, total;
|
private double mean, s;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct a new running sample stat
|
* Construct a new running sample stat
|
||||||
*/
|
*/
|
||||||
public SampleStat() {
|
public SampleStat() {
|
||||||
a0 = s0 = 0.0;
|
mean = 0.0;
|
||||||
total = 0.0;
|
s = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void reset() {
|
public void reset() {
|
||||||
numSamples = 0;
|
numSamples = 0;
|
||||||
a0 = s0 = 0.0;
|
mean = 0.0;
|
||||||
total = 0.0;
|
s = 0.0;
|
||||||
minmax.reset();
|
minmax.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
// We want to reuse the object, sometimes.
|
// We want to reuse the object, sometimes.
|
||||||
void reset(long numSamples, double a0, double a1, double s0, double s1,
|
void reset(long numSamples1, double mean1, double s1, MinMax minmax1) {
|
||||||
double total, MinMax minmax) {
|
numSamples = numSamples1;
|
||||||
this.numSamples = numSamples;
|
mean = mean1;
|
||||||
this.a0 = a0;
|
s = s1;
|
||||||
this.a1 = a1;
|
minmax.reset(minmax1);
|
||||||
this.s0 = s0;
|
|
||||||
this.s1 = s1;
|
|
||||||
this.total = total;
|
|
||||||
this.minmax.reset(minmax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -61,7 +57,7 @@ public class SampleStat {
|
||||||
* @param other the destination to hold our values
|
* @param other the destination to hold our values
|
||||||
*/
|
*/
|
||||||
public void copyTo(SampleStat other) {
|
public void copyTo(SampleStat other) {
|
||||||
other.reset(numSamples, a0, a1, s0, s1, total, minmax);
|
other.reset(numSamples, mean, s, minmax);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -78,24 +74,22 @@ public class SampleStat {
|
||||||
* Add some sample and a partial sum to the running stat.
|
* Add some sample and a partial sum to the running stat.
|
||||||
* Note, min/max is not evaluated using this method.
|
* Note, min/max is not evaluated using this method.
|
||||||
* @param nSamples number of samples
|
* @param nSamples number of samples
|
||||||
* @param x the partial sum
|
* @param xTotal the partial sum
|
||||||
* @return self
|
* @return self
|
||||||
*/
|
*/
|
||||||
public SampleStat add(long nSamples, double x) {
|
public SampleStat add(long nSamples, double xTotal) {
|
||||||
numSamples += nSamples;
|
numSamples += nSamples;
|
||||||
total += x;
|
|
||||||
|
|
||||||
if (numSamples == 1) {
|
// use the weighted incremental version of Welford's algorithm to get
|
||||||
a0 = a1 = x;
|
// numerical stability while treating the samples as being weighted
|
||||||
s0 = 0.0;
|
// by nSamples
|
||||||
}
|
// see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
|
||||||
else {
|
|
||||||
// The Welford method for numerical stability
|
double x = xTotal / nSamples;
|
||||||
a1 = a0 + (x - a0) / numSamples;
|
double meanOld = mean;
|
||||||
s1 = s0 + (x - a0) * (x - a1);
|
|
||||||
a0 = a1;
|
mean += ((double) nSamples / numSamples) * (x - meanOld);
|
||||||
s0 = s1;
|
s += nSamples * (x - meanOld) * (x - mean);
|
||||||
}
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,21 +104,21 @@ public class SampleStat {
|
||||||
* @return the total of all samples added
|
* @return the total of all samples added
|
||||||
*/
|
*/
|
||||||
public double total() {
|
public double total() {
|
||||||
return total;
|
return mean * numSamples;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return the arithmetic mean of the samples
|
* @return the arithmetic mean of the samples
|
||||||
*/
|
*/
|
||||||
public double mean() {
|
public double mean() {
|
||||||
return numSamples > 0 ? (total / numSamples) : 0.0;
|
return numSamples > 0 ? mean : 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return the variance of the samples
|
* @return the variance of the samples
|
||||||
*/
|
*/
|
||||||
public double variance() {
|
public double variance() {
|
||||||
return numSamples > 1 ? s1 / (numSamples - 1) : 0.0;
|
return numSamples > 1 ? s / (numSamples - 1) : 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -29,6 +29,8 @@ import static org.mockito.Mockito.times;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Map.Entry;
|
import java.util.Map.Entry;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
@ -36,6 +38,7 @@ import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
import org.apache.hadoop.metrics2.MetricsRecordBuilder;
|
import org.apache.hadoop.metrics2.MetricsRecordBuilder;
|
||||||
import org.apache.hadoop.metrics2.util.Quantile;
|
import org.apache.hadoop.metrics2.util.Quantile;
|
||||||
|
import org.apache.hadoop.thirdparty.com.google.common.math.Stats;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
@ -47,7 +50,7 @@ public class TestMutableMetrics {
|
||||||
|
|
||||||
private static final Logger LOG =
|
private static final Logger LOG =
|
||||||
LoggerFactory.getLogger(TestMutableMetrics.class);
|
LoggerFactory.getLogger(TestMutableMetrics.class);
|
||||||
private final double EPSILON = 1e-42;
|
private static final double EPSILON = 1e-42;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test the snapshot method
|
* Test the snapshot method
|
||||||
|
@ -306,19 +309,56 @@ public class TestMutableMetrics {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests that when using {@link MutableStat#add(long, long)}, even with a high
|
* Tests that when using {@link MutableStat#add(long, long)}, even with a high
|
||||||
* sample count, the mean does not lose accuracy.
|
* sample count, the mean does not lose accuracy. This also validates that
|
||||||
|
* the std dev is correct, assuming samples of equal value.
|
||||||
*/
|
*/
|
||||||
@Test public void testMutableStatWithBulkAdd() {
|
@Test
|
||||||
|
public void testMutableStatWithBulkAdd() {
|
||||||
|
List<Long> samples = new ArrayList<>();
|
||||||
|
for (int i = 0; i < 1000; i++) {
|
||||||
|
samples.add(1000L);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 1000; i++) {
|
||||||
|
samples.add(2000L);
|
||||||
|
}
|
||||||
|
Stats stats = Stats.of(samples);
|
||||||
|
|
||||||
|
for (int bulkSize : new int[] {1, 10, 100, 1000}) {
|
||||||
|
MetricsRecordBuilder rb = mockMetricsRecordBuilder();
|
||||||
|
MetricsRegistry registry = new MetricsRegistry("test");
|
||||||
|
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
|
||||||
|
|
||||||
|
for (int i = 0; i < samples.size(); i += bulkSize) {
|
||||||
|
stat.add(bulkSize, samples
|
||||||
|
.subList(i, i + bulkSize)
|
||||||
|
.stream()
|
||||||
|
.mapToLong(Long::longValue)
|
||||||
|
.sum()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
registry.snapshot(rb, false);
|
||||||
|
|
||||||
|
assertCounter("TestNumOps", 2000L, rb);
|
||||||
|
assertGauge("TestAvgVal", stats.mean(), rb);
|
||||||
|
assertGauge("TestStdevVal", stats.sampleStandardDeviation(), rb);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLargeMutableStatAdd() {
|
||||||
MetricsRecordBuilder rb = mockMetricsRecordBuilder();
|
MetricsRecordBuilder rb = mockMetricsRecordBuilder();
|
||||||
MetricsRegistry registry = new MetricsRegistry("test");
|
MetricsRegistry registry = new MetricsRegistry("test");
|
||||||
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", false);
|
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
|
||||||
|
|
||||||
stat.add(1000, 1000);
|
long sample = 1000000000000009L;
|
||||||
stat.add(1000, 2000);
|
for (int i = 0; i < 100; i++) {
|
||||||
|
stat.add(1, sample);
|
||||||
|
}
|
||||||
registry.snapshot(rb, false);
|
registry.snapshot(rb, false);
|
||||||
|
|
||||||
assertCounter("TestNumOps", 2000L, rb);
|
assertCounter("TestNumOps", 100L, rb);
|
||||||
assertGauge("TestAvgVal", 1.5, rb);
|
assertGauge("TestAvgVal", (double) sample, rb);
|
||||||
|
assertGauge("TestStdevVal", 0.0, rb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Loading…
Reference in New Issue