diff --git a/src/java/org/apache/commons/math/stat/descriptive/moment/Mean.java b/src/java/org/apache/commons/math/stat/descriptive/moment/Mean.java index 13964f87f..758e643be 100644 --- a/src/java/org/apache/commons/math/stat/descriptive/moment/Mean.java +++ b/src/java/org/apache/commons/math/stat/descriptive/moment/Mean.java @@ -22,24 +22,32 @@ import org.apache.commons.math.stat.descriptive.AbstractStorelessUnivariateStati import org.apache.commons.math.stat.descriptive.summary.Sum; /** - * Returns the arithmetic mean of the available values. Uses the definitional - * formula: + *

Computes the arithmetic mean of a set of values. Uses the definitional + * formula:

*

* mean = sum(x_i) / n - *

- * where n is the number of observations. - *

- * The value of the statistic is computed using the following recursive - * updating algorithm: - *

+ *

+ *

where n is the number of observations. + *

+ *

When {@link #increment(double)} is used to add data incrementally from a + * stream of (unstored) values, the value of the statistic that + * {@link #getResult()} returns is computed using the following recursive + * updating algorithm:

*
    *
  1. Initialize m = the first value
  2. *
  3. For each additional value, update using
    * m = m + (new value - m) / (number of observations)
  4. *
+ *

If {@link #evaluate(double[])} is used to compute the mean of an array + * of stored values, a two-pass, corrected algorithm is used, starting with + * the definitional formula computed using the array of stored values and then + * correcting this by adding the mean deviation of the data values from the + * arithmetic mean. See, e.g. "Comparison of Several Algorithms for Computing + * Sample Means and Variances," Robert F. Ling, Journal of the American + * Statistical Association, Vol. 69, No. 348 (Dec., 1974), pp. 859-866.

*

* Returns Double.NaN if the dataset is empty. - *

+ *

* Note that this implementation is not synchronized. If * multiple threads access an instance of this class concurrently, and at least * one of the threads invokes the increment() or @@ -131,7 +139,17 @@ public class Mean extends AbstractStorelessUnivariateStatistic public double evaluate(final double[] values,final int begin, final int length) { if (test(values, begin, length)) { Sum sum = new Sum(); - return sum.evaluate(values, begin, length) / ((double) length); + double sampleSize = (double) length; + + // Compute initial estimate using definitional formula + double xbar = sum.evaluate(values, begin, length) / sampleSize; + + // Compute correction factor in second pass + double correction = 0; + for (int i = begin; i < begin + length; i++) { + correction += (values[i] - xbar); + } + return xbar + (correction/sampleSize); } return Double.NaN; } diff --git a/src/test/org/apache/commons/math/stat/CertifiedDataTest.java b/src/test/org/apache/commons/math/stat/CertifiedDataTest.java index 3c23439c0..7ceecc156 100644 --- a/src/test/org/apache/commons/math/stat/CertifiedDataTest.java +++ b/src/test/org/apache/commons/math/stat/CertifiedDataTest.java @@ -61,57 +61,59 @@ public class CertifiedDataTest extends TestCase { } /** - * Test StorelessDescriptiveStatistics + * Test SummaryStatistics - implementations that do not store the data + * and use single pass algorithms to compute statistics */ - public void testUnivariateImpl() throws Exception { + public void testSummaryStatistics() throws Exception { SummaryStatistics u = SummaryStatistics.newInstance(SummaryStatisticsImpl.class); loadStats("data/PiDigits.txt", u); - assertEquals("PiDigits: std", std, u.getStandardDeviation(), .0000000000001); - assertEquals("PiDigits: mean", mean, u.getMean(), .0000000000001); + assertEquals("PiDigits: std", std, u.getStandardDeviation(), 1E-13); + assertEquals("PiDigits: mean", mean, u.getMean(), 1E-13); loadStats("data/Mavro.txt", u); - assertEquals("Mavro: std", std, u.getStandardDeviation(), .00000000000001); - assertEquals("Mavro: mean", mean, u.getMean(), .00000000000001); + assertEquals("Mavro: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("Mavro: mean", mean, u.getMean(), 1E-14); - //loadStats("data/Michelso.txt"); - //assertEquals("Michelso: std", std, u.getStandardDeviation(), .00000000000001); - //assertEquals("Michelso: mean", mean, u.getMean(), .00000000000001); + loadStats("data/Michelso.txt", u); + assertEquals("Michelso: std", std, u.getStandardDeviation(), 1E-13); + assertEquals("Michelso: mean", mean, u.getMean(), 1E-13); loadStats("data/NumAcc1.txt", u); - assertEquals("NumAcc1: std", std, u.getStandardDeviation(), .00000000000001); - assertEquals("NumAcc1: mean", mean, u.getMean(), .00000000000001); + assertEquals("NumAcc1: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("NumAcc1: mean", mean, u.getMean(), 1E-14); - //loadStats("data/NumAcc2.txt"); - //assertEquals("NumAcc2: std", std, u.getStandardDeviation(), .000000001); - //assertEquals("NumAcc2: mean", mean, u.getMean(), .00000000000001); + loadStats("data/NumAcc2.txt", u); + assertEquals("NumAcc2: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("NumAcc2: mean", mean, u.getMean(), 1E-14); } /** - * Test StorelessDescriptiveStatistics + * Test DescriptiveStatistics - implementations that store full array of + * values and execute multi-pass algorithms */ - public void testStoredUnivariateImpl() throws Exception { + public void testDescriptiveStatistics() throws Exception { DescriptiveStatistics u = DescriptiveStatistics.newInstance(); loadStats("data/PiDigits.txt", u); - assertEquals("PiDigits: std", std, u.getStandardDeviation(), .0000000000001); - assertEquals("PiDigits: mean", mean, u.getMean(), .0000000000001); + assertEquals("PiDigits: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("PiDigits: mean", mean, u.getMean(), 1E-14); loadStats("data/Mavro.txt", u); - assertEquals("Mavro: std", std, u.getStandardDeviation(), .00000000000001); - assertEquals("Mavro: mean", mean, u.getMean(), .00000000000001); + assertEquals("Mavro: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("Mavro: mean", mean, u.getMean(), 1E-14); - //loadStats("data/Michelso.txt"); - //assertEquals("Michelso: std", std, u.getStandardDeviation(), .00000000000001); - //assertEquals("Michelso: mean", mean, u.getMean(), .00000000000001); + loadStats("data/Michelso.txt", u); + assertEquals("Michelso: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("Michelso: mean", mean, u.getMean(), 1E-14); loadStats("data/NumAcc1.txt", u); - assertEquals("NumAcc1: std", std, u.getStandardDeviation(), .00000000000001); - assertEquals("NumAcc1: mean", mean, u.getMean(), .00000000000001); + assertEquals("NumAcc1: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("NumAcc1: mean", mean, u.getMean(), 1E-14); - //loadStats("data/NumAcc2.txt"); - //assertEquals("NumAcc2: std", std, u.getStandardDeviation(), .000000001); - //assertEquals("NumAcc2: mean", mean, u.getMean(), .00000000000001); + loadStats("data/NumAcc2.txt", u); + assertEquals("NumAcc2: std", std, u.getStandardDeviation(), 1E-14); + assertEquals("NumAcc2: mean", mean, u.getMean(), 1E-14); } /** diff --git a/xdocs/changes.xml b/xdocs/changes.xml index 3b9fc458c..577591ac7 100644 --- a/xdocs/changes.xml +++ b/xdocs/changes.xml @@ -111,7 +111,12 @@ Commons Math Release Notes and SummaryStatistics concrete classes. Pushed implementations up from DescriptiveStatisticsImpl, SummaryStatisticsImpl. Made implementations of statistics configurable via setters. - + + + Changed Mean.evaluate() to use a two-pass algorithm, improving accuracy + by exploiting the the fact that this method has access to the full + array of data values. +