[MATH-837] Support aggregation of any kind of StatisticalSummary in AggregateSummaryStatistics.
This commit is contained in:
parent
1b5925b563
commit
e14d9ce8e3
|
@ -54,6 +54,10 @@ If the output is not quite correct, check for invisible trailing spaces!
|
|||
</release>
|
||||
|
||||
<release version="4.0" date="XXXX-XX-XX" description="">
|
||||
<action dev="tn" type="add" issue="MATH-837"> <!-- backported to 3.6 -->
|
||||
"AggregateSummaryStatistics" can now aggregate any kind of
|
||||
"StatisticalSummary".
|
||||
</action>
|
||||
<action dev="erans" type="fix" issue="MATH-1279"> <!-- backported to 3.6 -->
|
||||
Check precondition (class "o.a.c.m.random.EmpiricalDistribution").
|
||||
</action>
|
||||
|
|
|
@ -309,20 +309,21 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
|
|||
* @param statistics collection of SummaryStatistics to aggregate
|
||||
* @return summary statistics for the combined dataset
|
||||
*/
|
||||
public static StatisticalSummaryValues aggregate(Collection<SummaryStatistics> statistics) {
|
||||
public static StatisticalSummaryValues aggregate(Collection<? extends StatisticalSummary> statistics) {
|
||||
if (statistics == null) {
|
||||
return null;
|
||||
}
|
||||
Iterator<SummaryStatistics> iterator = statistics.iterator();
|
||||
Iterator<? extends StatisticalSummary> iterator = statistics.iterator();
|
||||
if (!iterator.hasNext()) {
|
||||
return null;
|
||||
}
|
||||
SummaryStatistics current = iterator.next();
|
||||
StatisticalSummary current = iterator.next();
|
||||
long n = current.getN();
|
||||
double min = current.getMin();
|
||||
double sum = current.getSum();
|
||||
double max = current.getMax();
|
||||
double m2 = current.getSecondMoment();
|
||||
double var = current.getVariance();
|
||||
double m2 = var * (n - 1d);
|
||||
double mean = current.getMean();
|
||||
while (iterator.hasNext()) {
|
||||
current = iterator.next();
|
||||
|
@ -338,7 +339,8 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
|
|||
n += curN;
|
||||
final double meanDiff = current.getMean() - mean;
|
||||
mean = sum / n;
|
||||
m2 = m2 + current.getSecondMoment() + meanDiff * meanDiff * oldN * curN / n;
|
||||
final double curM2 = current.getVariance() * (curN - 1d);
|
||||
m2 = m2 + curM2 + meanDiff * meanDiff * oldN * curN / n;
|
||||
}
|
||||
final double variance;
|
||||
if (n == 0) {
|
||||
|
|
|
@ -25,10 +25,6 @@ import org.apache.commons.math4.distribution.IntegerDistribution;
|
|||
import org.apache.commons.math4.distribution.RealDistribution;
|
||||
import org.apache.commons.math4.distribution.UniformIntegerDistribution;
|
||||
import org.apache.commons.math4.distribution.UniformRealDistribution;
|
||||
import org.apache.commons.math4.stat.descriptive.AggregateSummaryStatistics;
|
||||
import org.apache.commons.math4.stat.descriptive.StatisticalSummary;
|
||||
import org.apache.commons.math4.stat.descriptive.StatisticalSummaryValues;
|
||||
import org.apache.commons.math4.stat.descriptive.SummaryStatistics;
|
||||
import org.apache.commons.math4.util.Precision;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
@ -36,7 +32,6 @@ import org.junit.Test;
|
|||
|
||||
/**
|
||||
* Test cases for {@link AggregateSummaryStatistics}
|
||||
*
|
||||
*/
|
||||
public class AggregateSummaryStatisticsTest {
|
||||
|
||||
|
@ -132,7 +127,6 @@ public class AggregateSummaryStatisticsTest {
|
|||
* partition and comparing the result of aggregate(...) applied to the collection
|
||||
* of per-partition SummaryStatistics with a single SummaryStatistics computed
|
||||
* over the full sample.
|
||||
*
|
||||
*/
|
||||
@Test
|
||||
public void testAggregate() {
|
||||
|
@ -166,6 +160,42 @@ public class AggregateSummaryStatisticsTest {
|
|||
assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12);
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to {@link #testAggregate()} but operating on
|
||||
* {@link StatisticalSummary} instead.
|
||||
*/
|
||||
@Test
|
||||
public void testAggregateStatisticalSummary() {
|
||||
|
||||
// Generate a random sample and random partition
|
||||
double[] totalSample = generateSample();
|
||||
double[][] subSamples = generatePartition(totalSample);
|
||||
int nSamples = subSamples.length;
|
||||
|
||||
// Compute combined stats directly
|
||||
SummaryStatistics totalStats = new SummaryStatistics();
|
||||
for (int i = 0; i < totalSample.length; i++) {
|
||||
totalStats.addValue(totalSample[i]);
|
||||
}
|
||||
|
||||
// Now compute subsample stats individually and aggregate
|
||||
SummaryStatistics[] subSampleStats = new SummaryStatistics[nSamples];
|
||||
for (int i = 0; i < nSamples; i++) {
|
||||
subSampleStats[i] = new SummaryStatistics();
|
||||
}
|
||||
Collection<StatisticalSummary> aggregate = new ArrayList<StatisticalSummary>();
|
||||
for (int i = 0; i < nSamples; i++) {
|
||||
for (int j = 0; j < subSamples[i].length; j++) {
|
||||
subSampleStats[i].addValue(subSamples[i][j]);
|
||||
}
|
||||
aggregate.add(subSampleStats[i].getSummary());
|
||||
}
|
||||
|
||||
// Compare values
|
||||
StatisticalSummary aggregatedStats = AggregateSummaryStatistics.aggregate(aggregate);
|
||||
assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testAggregateDegenerate() {
|
||||
|
@ -269,7 +299,7 @@ public class AggregateSummaryStatisticsTest {
|
|||
final double[][] out = new double[5][];
|
||||
int cur = 0; // beginning of current partition segment
|
||||
int offset = 0; // end of current partition segment
|
||||
int sampleCount = 0; // number of segments defined
|
||||
int sampleCount = 0; // number of segments defined
|
||||
for (int i = 0; i < 5; i++) {
|
||||
if (cur == length || offset == length) {
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue