Added static direct aggregation method to AggregateSummaryStatistics. JIRA: MATH-224.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@791687 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Phil Steitz 2009-07-07 03:03:56 +00:00
parent 4ce932e3cd
commit 71f41f90fe
2 changed files with 175 additions and 10 deletions

View File

@ -18,6 +18,8 @@
package org.apache.commons.math.stat.descriptive;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
/**
* <p>
@ -199,6 +201,60 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
return contributingStatistics;
}
/**
* Computes aggregate summary statistics. This method can be used to combine statistics
* computed over partitions or subsamples - i.e., the StatisticalSummaryValues returned
* should contain the same values that would have been obtained by computing a single
* StatisticalSummary over the combined dataset.
* <p>
* Returns null if the collection is empty or null.
* </p>
*
* @param statistics collection of SummaryStatistics to aggregate
* @return summary statistics for the combined dataset
*/
public static StatisticalSummaryValues aggregate(Collection<SummaryStatistics> statistics) {
if (statistics == null) {
return null;
}
Iterator<SummaryStatistics> iterator = statistics.iterator();
if (!iterator.hasNext()) {
return null;
}
SummaryStatistics current = iterator.next();
long n = current.getN();
double min = current.getMin();
double sum = current.getSum();
double max = current.getMax();
double m2 = current.getSecondMoment();
double mean = current.getMean();
while (iterator.hasNext()) {
current = iterator.next();
if (current.getMin() < min || Double.isNaN(min)) {
min = current.getMin();
}
if (current.getMax() > max || Double.isNaN(max)) {
max = current.getMax();
}
sum += current.getSum();
final double oldN = n;
final long curN = current.getN();
n += curN;
final double meanDiff = current.getMean() - mean;
mean = sum / n;
m2 = m2 + current.getSecondMoment() + meanDiff * meanDiff * oldN * curN / n;
}
final double variance;
if (n == 0) {
variance = Double.NaN;
} else if (n == 1) {
variance = 0d;
} else {
variance = m2 / (n - 1);
}
return new StatisticalSummaryValues(mean, variance, n, max, min, sum);
}
/**
* A SummaryStatistics that also forwards all values added to it to a second
* {@code SummaryStatistics} for aggregation.
@ -271,6 +327,5 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
public int hashCode() {
return 123 + super.hashCode() + aggregateStatistics.hashCode();
}
}
}

View File

@ -21,6 +21,13 @@ import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import java.util.Collection;
import java.util.ArrayList;
import org.apache.commons.math.random.RandomData;
import org.apache.commons.math.random.RandomDataImpl;
import org.apache.commons.math.TestUtils;
/**
* Test cases for {@link AggregateSummaryStatistics}
@ -28,6 +35,18 @@ import junit.framework.TestSuite;
*/
public class AggregateSummaryStatisticsTest extends TestCase {
/**
* Creates and returns a {@code Test} representing all the test cases in this
* class
*
* @return a {@code Test} representing all the test cases in this class
*/
public static Test suite() {
TestSuite suite = new TestSuite(AggregateSummaryStatisticsTest.class);
suite.setName("AggregateSummaryStatistics tests");
return suite;
}
/**
* Tests the standard aggregation behavior
*/
@ -57,17 +76,108 @@ public class AggregateSummaryStatisticsTest extends TestCase {
assertEquals("Wrong number of aggregate values", 8, aggregate.getN());
assertEquals("Wrong aggregate sum", 42.0, aggregate.getSum());
}
public void testAggregate() throws Exception {
// Generate a random sample and random partition
double[] totalSample = generateSample();
double[][] subSamples = generatePartition(totalSample);
int nSamples = subSamples.length;
// Compute combined stats directly
SummaryStatistics totalStats = new SummaryStatistics();
for (int i = 0; i < totalSample.length; i++) {
totalStats.addValue(totalSample[i]);
}
// Now compute subsample stats individually and aggregate
SummaryStatistics[] subSampleStats = new SummaryStatistics[nSamples];
for (int i = 0; i < nSamples; i++) {
subSampleStats[i] = new SummaryStatistics();
}
Collection<SummaryStatistics> aggregate = new ArrayList<SummaryStatistics>();
for (int i = 0; i < nSamples; i++) {
for (int j = 0; j < subSamples[i].length; j++) {
subSampleStats[i].addValue(subSamples[i][j]);
}
aggregate.add(subSampleStats[i]);
}
// Compare values
StatisticalSummaryValues aggregatedStats = AggregateSummaryStatistics.aggregate(aggregate);
assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12);
}
/**
* Creates and returns a {@code Test} representing all the test cases in this
* class
*
* @return a {@code Test} representing all the test cases in this class
* Verifies that two StatisticalSummaryValues report the same values up
* to delta, with NaNs, infinities returned in the same spots. For max, min, n, values
* have to agree exactly, delta is used only for sum, mean, variance, std dev.
*/
public static Test suite() {
TestSuite suite = new TestSuite(AggregateSummaryStatisticsTest.class);
suite.setName("AggregateSummaryStatistics tests");
return suite;
protected static void assertEquals(StatisticalSummary expected, StatisticalSummaryValues observed, double delta) {
TestUtils.assertEquals(expected.getMax(), observed.getMax(), 0);
TestUtils.assertEquals(expected.getMin(), observed.getMin(), 0);
assertEquals(expected.getN(), observed.getN());
TestUtils.assertEquals(expected.getSum(), observed.getSum(), delta);
TestUtils.assertEquals(expected.getMean(), observed.getMean(), delta);
TestUtils.assertEquals(expected.getStandardDeviation(), observed.getStandardDeviation(), delta);
TestUtils.assertEquals(expected.getVariance(), observed.getVariance(), delta);
}
/**
* Generates a random sample of double values.
* Sample size is random, between 10 and 100 and values are
* uniformly distributed over [-100, 100].
*
* @return array of random double values
*/
private double[] generateSample() {
final RandomData randomData = new RandomDataImpl();
final int sampleSize = randomData.nextInt(10,100);
double[] out = new double[sampleSize];
for (int i = 0; i < out.length; i++) {
out[i] = randomData.nextUniform(-100, 100);
}
return out;
}
/**
* Generates a partition of <sample> into up to 5 sequentially selected
* subsamples with randomly selected partition points.
*
* @param sample array to partition
* @return rectangular array with rows = subsamples
*/
private double[][] generatePartition(double[] sample) {
final int length = sample.length;
final double[][] out = new double[5][];
final RandomData randomData = new RandomDataImpl();
int cur = 0;
int offset = 0;
int sampleCount = 0;
for (int i = 0; i < 5; i++) {
if (cur == length || offset == length) {
break;
}
final int next = (i == 4 || cur == length - 1) ? length - 1 : randomData.nextInt(cur, length - 1);
final int subLength = next - cur + 1;
out[i] = new double[subLength];
System.arraycopy(sample, offset, out[i], 0, subLength);
cur = next + 1;
sampleCount++;
offset += subLength;
}
if (sampleCount < 5) {
double[][] out2 = new double[sampleCount][];
for (int j = 0; j < sampleCount; j++) {
final int curSize = out[j].length;
out2[j] = new double[curSize];
System.arraycopy(out[j], 0, out2[j], 0, curSize);
}
return out2;
} else {
return out;
}
}
}