diff --git a/src/java/org/apache/commons/math/stat/descriptive/AggregateSummaryStatistics.java b/src/java/org/apache/commons/math/stat/descriptive/AggregateSummaryStatistics.java index 4f7a33755..7dd269fec 100644 --- a/src/java/org/apache/commons/math/stat/descriptive/AggregateSummaryStatistics.java +++ b/src/java/org/apache/commons/math/stat/descriptive/AggregateSummaryStatistics.java @@ -238,7 +238,7 @@ public class AggregateSummaryStatistics implements StatisticalSummary, } sum += current.getSum(); final double oldN = n; - final long curN = current.getN(); + final double curN = current.getN(); n += curN; final double meanDiff = current.getMean() - mean; mean = sum / n; diff --git a/src/test/org/apache/commons/math/stat/descriptive/AggregateSummaryStatisticsTest.java b/src/test/org/apache/commons/math/stat/descriptive/AggregateSummaryStatisticsTest.java index c6a9fccf7..50ee38586 100644 --- a/src/test/org/apache/commons/math/stat/descriptive/AggregateSummaryStatisticsTest.java +++ b/src/test/org/apache/commons/math/stat/descriptive/AggregateSummaryStatisticsTest.java @@ -77,6 +77,15 @@ public class AggregateSummaryStatisticsTest extends TestCase { assertEquals("Wrong aggregate sum", 42.0, aggregate.getSum()); } + /** + * Test aggregate function by randomly generating a dataset of 10-100 values + * from [-100, 100], dividing it into 2-5 partitions, computing stats for each + * partition and comparing the result of aggregate(...) applied to the collection + * of per-partition SummaryStatistics with a single SummaryStatistics computed + * over the full sample. 
+ * + * @throws Exception + */ public void testAggregate() throws Exception { // Generate a random sample and random partition @@ -108,6 +117,78 @@ public class AggregateSummaryStatisticsTest extends TestCase { assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12); } + + /** + * Tests aggregate(...) on a degenerate partition in which every sub-sample + * holds exactly one value, comparing the result with a single SummaryStatistics + * computed over the full sample. + * + * @throws Exception + */ + public void testAggregateDegenerate() throws Exception { + double[] totalSample = {1, 2, 3, 4, 5}; + double[][] subSamples = {{1}, {2}, {3}, {4}, {5}}; + + // Compute combined stats directly + SummaryStatistics totalStats = new SummaryStatistics(); + for (int i = 0; i < totalSample.length; i++) { + totalStats.addValue(totalSample[i]); + } + + // Now compute subsample stats individually and aggregate + SummaryStatistics[] subSampleStats = new SummaryStatistics[subSamples.length]; + for (int i = 0; i < subSamples.length; i++) { + subSampleStats[i] = new SummaryStatistics(); + } + Collection aggregate = new ArrayList(); + for (int i = 0; i < subSamples.length; i++) { + for (int j = 0; j < subSamples[i].length; j++) { + subSampleStats[i].addValue(subSamples[i][j]); + } + aggregate.add(subSampleStats[i]); + } + + // Compare values + StatisticalSummaryValues aggregatedStats = AggregateSummaryStatistics.aggregate(aggregate); + assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12); + } + + /** + * Tests aggregate(...) on a sample containing Double.POSITIVE_INFINITY and + * Double.NaN, comparing the result with a single SummaryStatistics computed + * over the full sample. + * + * @throws Exception + */ + public void testAggregateSpecialValues() throws Exception { + double[] totalSample = {Double.POSITIVE_INFINITY, 2, 3, Double.NaN, 5}; + double[][] subSamples = {{Double.POSITIVE_INFINITY, 2}, {3}, {Double.NaN}, {5}}; + + // Compute combined stats directly + SummaryStatistics totalStats = new SummaryStatistics(); + for (int i = 0; i < totalSample.length; i++) { + totalStats.addValue(totalSample[i]); + } + + // Now compute subsample stats individually and aggregate + SummaryStatistics[] subSampleStats = new SummaryStatistics[subSamples.length]; + for (int i = 0; i < subSamples.length; i++) { + subSampleStats[i] = new SummaryStatistics(); + } + Collection aggregate = new ArrayList(); + for (int i = 0; i < subSamples.length; i++) { + for (int j = 0; j < subSamples[i].length; j++) { + 
subSampleStats[i].addValue(subSamples[i][j]); + } + aggregate.add(subSampleStats[i]); + } + + // Compare values + StatisticalSummaryValues aggregatedStats = AggregateSummaryStatistics.aggregate(aggregate); + assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12); + + } + /** * Verifies that two StatisticalSummaryValues report the same values up * to delta, with NaNs, infinities returned in the same spots. For max, min, n, values