Added static direct aggregation method to AggregateSummaryStatistics. JIRA: MATH-224.
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@791687 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
4ce932e3cd
commit
71f41f90fe
|
@ -18,6 +18,8 @@
|
|||
package org.apache.commons.math.stat.descriptive;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
|
@ -199,6 +201,60 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
|
|||
return contributingStatistics;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes aggregate summary statistics. This method can be used to combine statistics
|
||||
* computed over partitions or subsamples - i.e., the StatisticalSummaryValues returned
|
||||
* should contain the same values that would have been obtained by computing a single
|
||||
* StatisticalSummary over the combined dataset.
|
||||
* <p>
|
||||
* Returns null if the collection is empty or null.
|
||||
* </p>
|
||||
*
|
||||
* @param statistics collection of SummaryStatistics to aggregate
|
||||
* @return summary statistics for the combined dataset
|
||||
*/
|
||||
public static StatisticalSummaryValues aggregate(Collection<SummaryStatistics> statistics) {
|
||||
if (statistics == null) {
|
||||
return null;
|
||||
}
|
||||
Iterator<SummaryStatistics> iterator = statistics.iterator();
|
||||
if (!iterator.hasNext()) {
|
||||
return null;
|
||||
}
|
||||
SummaryStatistics current = iterator.next();
|
||||
long n = current.getN();
|
||||
double min = current.getMin();
|
||||
double sum = current.getSum();
|
||||
double max = current.getMax();
|
||||
double m2 = current.getSecondMoment();
|
||||
double mean = current.getMean();
|
||||
while (iterator.hasNext()) {
|
||||
current = iterator.next();
|
||||
if (current.getMin() < min || Double.isNaN(min)) {
|
||||
min = current.getMin();
|
||||
}
|
||||
if (current.getMax() > max || Double.isNaN(max)) {
|
||||
max = current.getMax();
|
||||
}
|
||||
sum += current.getSum();
|
||||
final double oldN = n;
|
||||
final long curN = current.getN();
|
||||
n += curN;
|
||||
final double meanDiff = current.getMean() - mean;
|
||||
mean = sum / n;
|
||||
m2 = m2 + current.getSecondMoment() + meanDiff * meanDiff * oldN * curN / n;
|
||||
}
|
||||
final double variance;
|
||||
if (n == 0) {
|
||||
variance = Double.NaN;
|
||||
} else if (n == 1) {
|
||||
variance = 0d;
|
||||
} else {
|
||||
variance = m2 / (n - 1);
|
||||
}
|
||||
return new StatisticalSummaryValues(mean, variance, n, max, min, sum);
|
||||
}
|
||||
|
||||
/**
|
||||
* A SummaryStatistics that also forwards all values added to it to a second
|
||||
* {@code SummaryStatistics} for aggregation.
|
||||
|
@ -271,6 +327,5 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
|
|||
public int hashCode() {
|
||||
return 123 + super.hashCode() + aggregateStatistics.hashCode();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,13 @@ import junit.framework.Test;
|
|||
import junit.framework.TestCase;
|
||||
import junit.framework.TestSuite;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import org.apache.commons.math.random.RandomData;
|
||||
import org.apache.commons.math.random.RandomDataImpl;
|
||||
import org.apache.commons.math.TestUtils;
|
||||
|
||||
|
||||
/**
|
||||
* Test cases for {@link AggregateSummaryStatistics}
|
||||
|
@ -28,6 +35,18 @@ import junit.framework.TestSuite;
|
|||
*/
|
||||
public class AggregateSummaryStatisticsTest extends TestCase {
|
||||
|
||||
/**
|
||||
* Creates and returns a {@code Test} representing all the test cases in this
|
||||
* class
|
||||
*
|
||||
* @return a {@code Test} representing all the test cases in this class
|
||||
*/
|
||||
public static Test suite() {
|
||||
TestSuite suite = new TestSuite(AggregateSummaryStatisticsTest.class);
|
||||
suite.setName("AggregateSummaryStatistics tests");
|
||||
return suite;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests the standard aggregation behavior
|
||||
*/
|
||||
|
@ -57,17 +76,108 @@ public class AggregateSummaryStatisticsTest extends TestCase {
|
|||
assertEquals("Wrong number of aggregate values", 8, aggregate.getN());
|
||||
assertEquals("Wrong aggregate sum", 42.0, aggregate.getSum());
|
||||
}
|
||||
|
||||
|
||||
public void testAggregate() throws Exception {
|
||||
|
||||
// Generate a random sample and random partition
|
||||
double[] totalSample = generateSample();
|
||||
double[][] subSamples = generatePartition(totalSample);
|
||||
int nSamples = subSamples.length;
|
||||
|
||||
// Compute combined stats directly
|
||||
SummaryStatistics totalStats = new SummaryStatistics();
|
||||
for (int i = 0; i < totalSample.length; i++) {
|
||||
totalStats.addValue(totalSample[i]);
|
||||
}
|
||||
|
||||
// Now compute subsample stats individually and aggregate
|
||||
SummaryStatistics[] subSampleStats = new SummaryStatistics[nSamples];
|
||||
for (int i = 0; i < nSamples; i++) {
|
||||
subSampleStats[i] = new SummaryStatistics();
|
||||
}
|
||||
Collection<SummaryStatistics> aggregate = new ArrayList<SummaryStatistics>();
|
||||
for (int i = 0; i < nSamples; i++) {
|
||||
for (int j = 0; j < subSamples[i].length; j++) {
|
||||
subSampleStats[i].addValue(subSamples[i][j]);
|
||||
}
|
||||
aggregate.add(subSampleStats[i]);
|
||||
}
|
||||
|
||||
// Compare values
|
||||
StatisticalSummaryValues aggregatedStats = AggregateSummaryStatistics.aggregate(aggregate);
|
||||
assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and returns a {@code Test} representing all the test cases in this
|
||||
* class
|
||||
*
|
||||
* @return a {@code Test} representing all the test cases in this class
|
||||
* Verifies that two StatisticalSummaryValues report the same values up
|
||||
* to delta, with NaNs, infinities returned in the same spots. For max, min, n, values
|
||||
* have to agree exactly, delta is used only for sum, mean, variance, std dev.
|
||||
*/
|
||||
public static Test suite() {
|
||||
TestSuite suite = new TestSuite(AggregateSummaryStatisticsTest.class);
|
||||
suite.setName("AggregateSummaryStatistics tests");
|
||||
return suite;
|
||||
protected static void assertEquals(StatisticalSummary expected, StatisticalSummaryValues observed, double delta) {
|
||||
TestUtils.assertEquals(expected.getMax(), observed.getMax(), 0);
|
||||
TestUtils.assertEquals(expected.getMin(), observed.getMin(), 0);
|
||||
assertEquals(expected.getN(), observed.getN());
|
||||
TestUtils.assertEquals(expected.getSum(), observed.getSum(), delta);
|
||||
TestUtils.assertEquals(expected.getMean(), observed.getMean(), delta);
|
||||
TestUtils.assertEquals(expected.getStandardDeviation(), observed.getStandardDeviation(), delta);
|
||||
TestUtils.assertEquals(expected.getVariance(), observed.getVariance(), delta);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generates a random sample of double values.
|
||||
* Sample size is random, between 10 and 100 and values are
|
||||
* uniformly distributed over [-100, 100].
|
||||
*
|
||||
* @return array of random double values
|
||||
*/
|
||||
private double[] generateSample() {
|
||||
final RandomData randomData = new RandomDataImpl();
|
||||
final int sampleSize = randomData.nextInt(10,100);
|
||||
double[] out = new double[sampleSize];
|
||||
for (int i = 0; i < out.length; i++) {
|
||||
out[i] = randomData.nextUniform(-100, 100);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a partition of <sample> into up to 5 sequentially selected
|
||||
* subsamples with randomly selected partition points.
|
||||
*
|
||||
* @param sample array to partition
|
||||
* @return rectangular array with rows = subsamples
|
||||
*/
|
||||
private double[][] generatePartition(double[] sample) {
|
||||
final int length = sample.length;
|
||||
final double[][] out = new double[5][];
|
||||
final RandomData randomData = new RandomDataImpl();
|
||||
int cur = 0;
|
||||
int offset = 0;
|
||||
int sampleCount = 0;
|
||||
for (int i = 0; i < 5; i++) {
|
||||
if (cur == length || offset == length) {
|
||||
break;
|
||||
}
|
||||
final int next = (i == 4 || cur == length - 1) ? length - 1 : randomData.nextInt(cur, length - 1);
|
||||
final int subLength = next - cur + 1;
|
||||
out[i] = new double[subLength];
|
||||
System.arraycopy(sample, offset, out[i], 0, subLength);
|
||||
cur = next + 1;
|
||||
sampleCount++;
|
||||
offset += subLength;
|
||||
}
|
||||
if (sampleCount < 5) {
|
||||
double[][] out2 = new double[sampleCount][];
|
||||
for (int j = 0; j < sampleCount; j++) {
|
||||
final int curSize = out[j].length;
|
||||
out2[j] = new double[curSize];
|
||||
System.arraycopy(out[j], 0, out2[j], 0, curSize);
|
||||
}
|
||||
return out2;
|
||||
} else {
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue