[MATH-837] Support aggregation of any kind of StatisticalSummary in AggregateSummaryStatistics.

This commit is contained in:
Thomas Neidhart 2015-10-19 21:41:16 +02:00
parent 1b5925b563
commit e14d9ce8e3
3 changed files with 48 additions and 12 deletions

View File

@ -54,6 +54,10 @@ If the output is not quite correct, check for invisible trailing spaces!
</release>
<release version="4.0" date="XXXX-XX-XX" description="">
<action dev="tn" type="add" issue="MATH-837"> <!-- backported to 3.6 -->
"AggregateSummaryStatistics" can now aggregate any kind of
"StatisticalSummary".
</action>
<action dev="erans" type="fix" issue="MATH-1279"> <!-- backported to 3.6 -->
Check precondition (class "o.a.c.m.random.EmpiricalDistribution").
</action>

View File

@ -309,20 +309,21 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
* @param statistics collection of StatisticalSummary instances to aggregate
* @return summary statistics for the combined dataset
*/
public static StatisticalSummaryValues aggregate(Collection<SummaryStatistics> statistics) {
public static StatisticalSummaryValues aggregate(Collection<? extends StatisticalSummary> statistics) {
if (statistics == null) {
return null;
}
Iterator<SummaryStatistics> iterator = statistics.iterator();
Iterator<? extends StatisticalSummary> iterator = statistics.iterator();
if (!iterator.hasNext()) {
return null;
}
SummaryStatistics current = iterator.next();
StatisticalSummary current = iterator.next();
long n = current.getN();
double min = current.getMin();
double sum = current.getSum();
double max = current.getMax();
double m2 = current.getSecondMoment();
double var = current.getVariance();
double m2 = var * (n - 1d);
double mean = current.getMean();
while (iterator.hasNext()) {
current = iterator.next();
@ -338,7 +339,8 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
n += curN;
final double meanDiff = current.getMean() - mean;
mean = sum / n;
m2 = m2 + current.getSecondMoment() + meanDiff * meanDiff * oldN * curN / n;
final double curM2 = current.getVariance() * (curN - 1d);
m2 = m2 + curM2 + meanDiff * meanDiff * oldN * curN / n;
}
final double variance;
if (n == 0) {

View File

@ -25,10 +25,6 @@ import org.apache.commons.math4.distribution.IntegerDistribution;
import org.apache.commons.math4.distribution.RealDistribution;
import org.apache.commons.math4.distribution.UniformIntegerDistribution;
import org.apache.commons.math4.distribution.UniformRealDistribution;
import org.apache.commons.math4.stat.descriptive.AggregateSummaryStatistics;
import org.apache.commons.math4.stat.descriptive.StatisticalSummary;
import org.apache.commons.math4.stat.descriptive.StatisticalSummaryValues;
import org.apache.commons.math4.stat.descriptive.SummaryStatistics;
import org.apache.commons.math4.util.Precision;
import org.junit.Assert;
import org.junit.Test;
@ -36,7 +32,6 @@ import org.junit.Test;
/**
* Test cases for {@link AggregateSummaryStatistics}
*
*/
public class AggregateSummaryStatisticsTest {
@ -132,7 +127,6 @@ public class AggregateSummaryStatisticsTest {
* partition and comparing the result of aggregate(...) applied to the collection
* of per-partition SummaryStatistics with a single SummaryStatistics computed
* over the full sample.
*
*/
@Test
public void testAggregate() {
@ -166,6 +160,42 @@ public class AggregateSummaryStatisticsTest {
assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12);
}
/**
 * Counterpart of {@link #testAggregate()} that passes
 * {@link StatisticalSummary} values (obtained through getSummary())
 * to aggregate(...) and compares the result with statistics
 * computed directly over the full sample.
 */
@Test
public void testAggregateStatisticalSummary() {
    // Draw a random sample and split it into a random partition
    final double[] fullSample = generateSample();
    final double[][] partitions = generatePartition(fullSample);

    // Reference statistics accumulated over the whole sample
    final SummaryStatistics expected = new SummaryStatistics();
    for (final double value : fullSample) {
        expected.addValue(value);
    }

    // Accumulate each partition separately and collect its summary
    final Collection<StatisticalSummary> summaries = new ArrayList<StatisticalSummary>();
    for (final double[] partition : partitions) {
        final SummaryStatistics partitionStats = new SummaryStatistics();
        for (final double value : partition) {
            partitionStats.addValue(value);
        }
        summaries.add(partitionStats.getSummary());
    }

    // The aggregate of the partition summaries must match the reference
    final StatisticalSummary aggregated = AggregateSummaryStatistics.aggregate(summaries);
    assertEquals(expected.getSummary(), aggregated, 10E-12);
}
@Test
public void testAggregateDegenerate() {
@ -269,7 +299,7 @@ public class AggregateSummaryStatisticsTest {
final double[][] out = new double[5][];
int cur = 0; // beginning of current partition segment
int offset = 0; // end of current partition segment
int sampleCount = 0; // number of segments defined
int sampleCount = 0; // number of segments defined
for (int i = 0; i < 5; i++) {
if (cur == length || offset == length) {
break;