From b44d8ee8c88f87a7a3410d7aa3266056034c16f2 Mon Sep 17 00:00:00 2001 From: Mikkel Meyer Andersen Date: Thu, 21 Jul 2011 20:37:35 +0000 Subject: [PATCH] Added fix for MATH-585: Implemented faster generation of random gamma variates using Ahrens and Dieter (1974) and Marsaglia and Tsang (2001). Test case was improved, too. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1149350 13f79535-47bb-0310-9956-ffa450edef68 --- .../commons/math/random/RandomDataImpl.java | 84 ++++++++++++++++++- .../commons/math/random/RandomDataTest.java | 18 +++- 2 files changed, 96 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/commons/math/random/RandomDataImpl.java b/src/main/java/org/apache/commons/math/random/RandomDataImpl.java index 132dbd718..f46aa8a76 100644 --- a/src/main/java/org/apache/commons/math/random/RandomDataImpl.java +++ b/src/main/java/org/apache/commons/math/random/RandomDataImpl.java @@ -31,7 +31,6 @@ import org.apache.commons.math.distribution.CauchyDistributionImpl; import org.apache.commons.math.distribution.ChiSquaredDistributionImpl; import org.apache.commons.math.distribution.ContinuousDistribution; import org.apache.commons.math.distribution.FDistributionImpl; -import org.apache.commons.math.distribution.GammaDistributionImpl; import org.apache.commons.math.distribution.HypergeometricDistributionImpl; import org.apache.commons.math.distribution.IntegerDistribution; import org.apache.commons.math.distribution.PascalDistributionImpl; @@ -676,8 +675,17 @@ public class RandomDataImpl implements RandomData, Serializable { /** * Generates a random value from the {@link GammaDistributionImpl Gamma Distribution}. - * This implementation uses {@link #nextInversionDeviate(ContinuousDistribution) inversion} - * to generate random values. + * + * This implementation uses the following algorithms: + * + * For 0 < shape < 1: + * [1]: Ahrens, J. H. and Dieter, U. (1974). Computer methods for + * sampling from gamma, beta, Poisson and binomial distributions. + * Computing, 12, 223-246. + * + * For shape >= 1: + * [2]: Marsaglia and Tsang (2001). A Simple Method for Generating + * Gamma Variables. ACM Transactions on Mathematical Software, 26, * * @param shape the median of the Gamma distribution * @param scale the scale parameter of the Gamma distribution @@ -686,7 +694,75 @@ public class RandomDataImpl implements RandomData, Serializable { * @since 2.2 */ public double nextGamma(double shape, double scale) throws MathException { - return nextInversionDeviate(new GammaDistributionImpl(shape, scale)); + if (shape < 1) { + /* + final double gamma = this.nextOldGamma(1 + shape, scale); + final double u = this.nextUniform(0, 1); + return gamma * FastMath.pow(u, 1/shape); + */ + + // [1]: p. 228, Algorithm GS + + while (true) { + // Step 1: + final double u = this.nextUniform(0, 1); + final double bGS = 1 + shape/FastMath.E; + final double p = bGS*u; + + if (p <= 1) { + // Step 2: + + final double x = FastMath.pow(p, 1/shape); + final double u2 = this.nextUniform(0.0, 1); + + if (u2 > FastMath.exp(-x)) { + // Reject + continue; + } else { + return scale*x; + } + } else { + // Step 3: + + final double x = -1 * FastMath.log((bGS-p)/shape); + final double u2 = this.nextUniform(0, 1); + + if (u2 > FastMath.pow(x, shape - 1)) { + // Reject + continue; + } else { + return scale*x; + } + } + } + } + + // Now shape >= 1 + + final RandomGenerator generator = this.getRan(); + final double d = shape - 0.333333333333333333; + final double c = 1.0 / (3*FastMath.sqrt(d)); + + while (true) { + final double x = generator.nextGaussian(); + final double v = (1+c*x)*(1+c*x)*(1+c*x); + + if (v <= 0) { + continue; + } + + final double xx = x*x; + final double u = this.nextUniform(0, 1); + + // Squeeze + if (u < 1 - 0.0331*xx*xx) { + return scale*d*v; + } + + if (FastMath.log(u) < 0.5*xx + d*(1 - v + FastMath.log(v))) { + return scale*d*v; + } + } } /** diff --git a/src/test/java/org/apache/commons/math/random/RandomDataTest.java b/src/test/java/org/apache/commons/math/random/RandomDataTest.java index e55ce248d..5d0393df7 100644 --- a/src/test/java/org/apache/commons/math/random/RandomDataTest.java +++ b/src/test/java/org/apache/commons/math/random/RandomDataTest.java @@ -910,14 +910,28 @@ public class RandomDataTest { @Test public void testNextGamma() throws Exception { - double[] quartiles = TestUtils.getDistributionQuartiles(new GammaDistributionImpl(4, 2)); - long[] counts = new long[4]; + double[] quartiles; + long[] counts; + + // Tests shape > 1, one case in the rejection sampling + quartiles = TestUtils.getDistributionQuartiles(new GammaDistributionImpl(4, 2)); + counts = new long[4]; randomData.reSeed(1000); for (int i = 0; i < 1000; i++) { double value = randomData.nextGamma(4, 2); TestUtils.updateCounts(value, counts, quartiles); } TestUtils.assertChiSquareAccept(expected, counts, 0.001); + + // Tests shape <= 1, another case in the rejection sampling + quartiles = TestUtils.getDistributionQuartiles(new GammaDistributionImpl(0.3, 3)); + counts = new long[4]; + randomData.reSeed(1000); + for (int i = 0; i < 1000; i++) { + double value = randomData.nextGamma(0.3, 3); + TestUtils.updateCounts(value, counts, quartiles); + } + TestUtils.assertChiSquareAccept(expected, counts, 0.001); } @Test