Unit tests for GammaDistribution, based on reference data generated with

Maxima. Solves MATH-753.


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1339014 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2012-05-16 05:36:40 +00:00
parent 4b8ab11cd5
commit 282f7175a4
1 changed files with 159 additions and 0 deletions

View File

@ -17,7 +17,15 @@
package org.apache.commons.math3.distribution;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
@ -167,4 +175,155 @@ public class GammaDistributionTest extends RealDistributionAbstractTest {
Assert.assertEquals(dist.getNumericalMean(), 1.1d * 4.2d, tol);
Assert.assertEquals(dist.getNumericalVariance(), 1.1d * 4.2d * 4.2d, tol);
}
public static double density(final double x, final double shape,
final double scale) {
/*
* This is a copy of
* double GammaDistribution.density(double)
* prior to r1338548.
*/
if (x < 0) {
return 0;
}
return FastMath.pow(x / scale, shape - 1) / scale *
FastMath.exp(-x / scale) / FastMath.exp(Gamma.logGamma(shape));
}
/*
* MATH-753: large values of x or shape parameter cause density(double) to
* overflow. Reference data is generated with the Maxima script
* gamma-distribution.mac, which can be found in
* src/test/resources/org/apache/commons/math3/distribution.
*/
private void doTestMath753(final double shape,
final double meanNoOF, final double sdNoOF,
final double meanOF, final double sdOF,
final String resourceName) throws IOException {
final GammaDistribution distribution = new GammaDistribution(shape, 1.0);
final SummaryStatistics statOld = new SummaryStatistics();
final SummaryStatistics statNewNoOF = new SummaryStatistics();
final SummaryStatistics statNewOF = new SummaryStatistics();
final InputStream resourceAsStream;
resourceAsStream = this.getClass().getResourceAsStream(resourceName);
Assert.assertNotNull("Could not find resource " + resourceName,
resourceAsStream);
final BufferedReader in;
in = new BufferedReader(new InputStreamReader(resourceAsStream));
try {
for (String line = in.readLine(); line != null; line = in
.readLine()) {
final String[] tokens = line.split(", ");
Assert.assertTrue("expected two floating-point values",
tokens.length == 2);
final double x = Double.parseDouble(tokens[0]);
final String msg = "x = " + x + ", shape = " + shape +
", scale = 1.0";
final double expected = Double.parseDouble(tokens[1]);
final double ulp = FastMath.ulp(expected);
final double actualOld = density(x, shape, 1.0);
final double actualNew = distribution.density(x);
final double errOld, errNew;
errOld = FastMath.abs((actualOld - expected) / ulp);
errNew = FastMath.abs((actualNew - expected) / ulp);
if (Double.isNaN(actualOld) || Double.isInfinite(actualOld)) {
Assert.assertFalse(msg, Double.isNaN(actualNew));
Assert.assertFalse(msg, Double.isInfinite(actualNew));
statNewOF.addValue(errNew);
} else {
statOld.addValue(errOld);
statNewNoOF.addValue(errNew);
}
}
if (statOld.getN() != 0) {
/*
* If no overflow occurs, check that new implementation is
* better than old one.
*/
final StringBuilder sb = new StringBuilder("shape = ");
sb.append(shape);
sb.append(", scale = 1.0\n");
sb.append("Old implementation\n");
sb.append("------------------\n");
sb.append(statOld.toString());
sb.append("New implementation\n");
sb.append("------------------\n");
sb.append(statNewNoOF.toString());
final String msg = sb.toString();
final double oldMin = statOld.getMin();
final double newMin = statNewNoOF.getMin();
Assert.assertTrue(msg, newMin <= oldMin);
final double oldMax = statOld.getMax();
final double newMax = statNewNoOF.getMax();
Assert.assertTrue(msg, newMax <= oldMax);
final double oldMean = statOld.getMean();
final double newMean = statNewNoOF.getMean();
Assert.assertTrue(msg, newMean <= oldMean);
final double oldSd = statOld.getStandardDeviation();
final double newSd = statNewNoOF.getStandardDeviation();
Assert.assertTrue(msg, newSd <= oldSd);
Assert.assertTrue(msg, newMean <= meanNoOF);
Assert.assertTrue(msg, newSd <= sdNoOF);
}
if (statNewOF.getN() != 0) {
final double newMean = statNewOF.getMean();
final double newSd = statNewOF.getStandardDeviation();
final StringBuilder sb = new StringBuilder("shape = ");
sb.append(shape);
sb.append(", scale = 1.0");
sb.append(", max. mean error (ulps) = ");
sb.append(meanOF);
sb.append(", actual mean error (ulps) = ");
sb.append(newMean);
sb.append(", max. sd of error (ulps) = ");
sb.append(sdOF);
sb.append(", actual sd of error (ulps) = ");
sb.append(newSd);
final String msg = sb.toString();
Assert.assertTrue(msg, newMean <= meanOF);
Assert.assertTrue(msg, newSd <= sdOF);
}
} catch (IOException e) {
Assert.fail(e.getMessage());
} finally {
in.close();
}
}
@Test
public void testMath753Shape1() throws IOException {
doTestMath753(1.0, 1.5, 0.5, 0.0, 0.0, "gamma-distribution-shape-1.csv");
}
@Test
public void testMath753Shape10() throws IOException {
doTestMath753(10.0, 1.0, 1.0, 0.0, 0.0, "gamma-distribution-shape-10.csv");
}
@Test
public void testMath753Shape100() throws IOException {
doTestMath753(100.0, 1.5, 1.0, 0.0, 0.0, "gamma-distribution-shape-100.csv");
}
@Test
public void testMath753Shape142() throws IOException {
doTestMath753(142.0, 0.5, 1.5, 40.0, 40.0, "gamma-distribution-shape-142.csv");
}
@Test
public void testMath753Shape1000() throws IOException {
doTestMath753(1000.0, 1.0, 1.0, 160.0, 220.0, "gamma-distribution-shape-1000.csv");
}
}