[MATH-1220] Improve performance of ZipfDistribution.sample. Thanks to Otmar Ertl.

This commit is contained in:
Thomas Neidhart 2015-05-01 13:50:10 +02:00
parent 5597ed7ea3
commit 002276ea31
4 changed files with 379 additions and 8 deletions

View File

@ -206,6 +206,9 @@
<contributor> <contributor>
<name>Ole Ersoy</name> <name>Ole Ersoy</name>
</contributor> </contributor>
<contributor>
<name>Otmar Ertl</name>
</contributor>
<contributor> <contributor>
<name>Ajo Fod</name> <name>Ajo Fod</name>
</contributor> </contributor>

View File

@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</release> </release>
<release version="4.0" date="XXXX-XX-XX" description=""> <release version="4.0" date="XXXX-XX-XX" description="">
<action dev="tn" type="fix" issue="MATH-1220" due-to="Otmar Ertl"> <!-- backported to 3.6 -->
Improve performance of "ZipfDistribution#sample()" by using a rejection algorithm.
</action>
<action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev"> <!-- backported to 3.6 --> <action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev"> <!-- backported to 3.6 -->
Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm. Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm.
</action> </action>

View File

@ -43,6 +43,8 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
private double numericalVariance = Double.NaN; private double numericalVariance = Double.NaN;
/** Whether or not the numerical variance has been calculated */ /** Whether or not the numerical variance has been calculated */
private boolean numericalVarianceIsCalculated = false; private boolean numericalVarianceIsCalculated = false;
/** The sampler to be used for the sample() method */
private transient ZipfRejectionSampler sampler;
/** /**
* Create a new Zipf distribution with the given number of elements and * Create a new Zipf distribution with the given number of elements and
@ -265,5 +267,152 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
public boolean isSupportConnected() { public boolean isSupportConnected() {
return true; return true;
} }
}
/**
* {@inheritDoc}
* <p>
* An instrumental distribution g(k) is used to generate random values by
* rejection sampling. g(k) is defined as g(1):= 1 and g(k) := I(-s,k-1/2,k+1/2)
* for k larger than 1, where s denotes the exponent of the Zipf distribution
* and I(r,a,b) is the integral of x^r for x from a to b.
* <p>
* Since 1^x^s is a convex function, Jensens's inequality gives
* I(-s,k-1/2,k+1/2) >= 1/k^s for all positive k and non-negative s.
* In order to limit the rejection rate for large exponents s,
* the instrumental distribution weight is differently defined for value 1.
*/
@Override
public int sample() {
if (sampler == null) {
sampler = new ZipfRejectionSampler(numberOfElements, exponent);
}
return sampler.sample(random);
}
/**
* Utility class implementing a rejection sampling method for a discrete,
* bounded Zipf distribution.
*
* @since 3.6
*/
static final class ZipfRejectionSampler {
/** Number of elements. */
private final int numberOfElements;
/** Exponent parameter of the distribution. */
private final double exponent;
/** Cached tail weight of instrumental distribution used for rejection sampling */
private double instrumentalDistributionTailWeight = Double.NaN;
ZipfRejectionSampler(final int numberOfElements, final double exponent) {
this.numberOfElements = numberOfElements;
this.exponent = exponent;
}
int sample(final RandomGenerator random) {
if (Double.isNaN(instrumentalDistributionTailWeight)) {
instrumentalDistributionTailWeight = integratePowerFunction(-exponent, 1.5, numberOfElements+0.5);
}
while(true) {
final double randomValue = random.nextDouble()*(instrumentalDistributionTailWeight + 1.);
if (randomValue < instrumentalDistributionTailWeight) {
final double q = randomValue / instrumentalDistributionTailWeight;
final int sample = sampleFromInstrumentalDistributionTail(q);
if (random.nextDouble() < acceptanceRateForTailSample(sample)) {
return sample;
}
}
else {
return 1;
}
}
}
/**
* Returns a sample from the instrumental distribution tail for a given
* uniformly distributed random value.
*
* @param q a uniformly distributed random value taken from [0,1]
* @return a sample in the range [2, {@link #numberOfElements}]
*/
int sampleFromInstrumentalDistributionTail(double q) {
final double a = 1.5;
final double b = numberOfElements + 0.5;
final double logBdviA = FastMath.log(b / a);
final int result = (int) (a * FastMath.exp(logBdviA * helper1(q, logBdviA * (1. - exponent))) + 0.5);
if (result < 2) {
return 2;
}
if (result > numberOfElements) {
return numberOfElements;
}
return result;
}
/**
* Helper function that calculates log((1-q)+q*exp(x))/x.
* <p>
* A Taylor series expansion is used, if x is close to 0.
*
* @param q a value in the range [0,1]
* @param
* @return log((1-q)+q*exp(x))/x
*/
static double helper1(final double q, final double x) {
if (Math.abs(x) > 1e-8) {
return FastMath.log((1.-q)+q*FastMath.exp(x))/x;
}
else {
return q*(1.+(1./2.)*x*(1.-q)*(1+(1./3.)*x*((1.-2.*q) + (1./4.)*x*(6*q*q*(q-1)+1))));
}
}
/**
* Helper function to calculate (exp(x)-1)/x.
* <p>
* A Taylor series expansion is used, if x is close to 0.
*
* @return (exp(x)-1)/x if x is non-zero, 1 if x=0
*/
static double helper2(final double x) {
if (FastMath.abs(x)>1e-8) {
return FastMath.expm1(x)/x;
}
else {
return 1.+x*(1./2.)*(1.+x*(1./3.)*(1.+x*(1./4.)));
}
}
/**
* Integrates the power function x^r from x=a to b.
*
* @param r the exponent
* @param a the integral lower bound
* @param b the integral upper bound
* @return the calculated integral value
*/
static double integratePowerFunction(final double r, final double a, final double b) {
final double logA = FastMath.log(a);
final double logBdivA = FastMath.log(b/a);
return FastMath.exp((1.+r)*logA)*helper2((1.+r)*logBdivA)*logBdivA;
}
/**
* Calculates the acceptance rate for a sample taken from the tail of the instrumental distribution.
* <p>
* The acceptance rate is given by the ratio k^(-s)/I(-s,k-0.5, k+0.5)
* where I(r,a,b) is the integral of x^r for x from a to b.
*
* @param k the value which has been sampled using the instrumental distribution
* @return the acceptance rate
*/
double acceptanceRateForTailSample(int k) {
final double a = FastMath.log1p(1./(2.*k-1.));
final double b = FastMath.log1p(2./(2.*k-1.));
return FastMath.exp((1.-exponent)*a)/(k*b*helper2((1.-exponent)*b));
}
}
}

View File

@ -17,18 +17,28 @@
package org.apache.commons.math4.distribution; package org.apache.commons.math4.distribution;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.apache.commons.math4.TestUtils;
import org.apache.commons.math4.analysis.UnivariateFunction;
import org.apache.commons.math4.analysis.integration.SimpsonIntegrator;
import org.apache.commons.math4.distribution.IntegerDistribution; import org.apache.commons.math4.distribution.IntegerDistribution;
import org.apache.commons.math4.distribution.ZipfDistribution; import org.apache.commons.math4.distribution.ZipfDistribution;
import org.apache.commons.math4.distribution.ZipfDistribution.ZipfRejectionSampler;
import org.apache.commons.math4.exception.NotStrictlyPositiveException; import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.random.AbstractRandomGenerator;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.random.Well1024a;
import org.apache.commons.math4.util.FastMath; import org.apache.commons.math4.util.FastMath;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
/** /**
* Test cases for {@link ZipfDistribution}. * Test cases for {@link ZipfDistribution}.
* Extends IntegerDistributionAbstractTest. See class javadoc for * Extends IntegerDistributionAbstractTest.
* IntegerDistributionAbstractTest for details. * See class javadoc for IntegerDistributionAbstractTest for details.
*
*/ */
public class ZipfDistributionTest extends IntegerDistributionAbstractTest { public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
@ -38,7 +48,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
public ZipfDistributionTest() { public ZipfDistributionTest() {
setTolerance(1e-12); setTolerance(1e-12);
} }
@Test(expected=NotStrictlyPositiveException.class) @Test(expected=NotStrictlyPositiveException.class)
public void testPreconditions1() { public void testPreconditions1() {
new ZipfDistribution(0, 1); new ZipfDistribution(0, 1);
@ -63,9 +73,9 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
return new int[] {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; return new int[] {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
} }
/** /**
* Creates the default probability density test expected values. * Creates the default probability density test expected values.
* Reference values are from R, version 2.15.3 (VGAM package 0.9-0). * Reference values are from R, version 2.15.3 (VGAM package 0.9-0).
*/ */
@Override @Override
public double[] makeDensityTestValues() { public double[] makeDensityTestValues() {
@ -73,7 +83,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
0.0569028586912, 0.0487738788782, 0.0426771440184, 0.0379352391275, 0.0341417152147, 0}; 0.0569028586912, 0.0487738788782, 0.0426771440184, 0.0379352391275, 0.0341417152147, 0};
} }
/** /**
* Creates the default logarithmic probability density test expected values. * Creates the default logarithmic probability density test expected values.
* Reference values are from R, version 2.14.1. * Reference values are from R, version 2.14.1.
*/ */
@ -120,4 +130,210 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
Assert.assertEquals(dist.getNumericalMean(), FastMath.sqrt(2), tol); Assert.assertEquals(dist.getNumericalMean(), FastMath.sqrt(2), tol);
Assert.assertEquals(dist.getNumericalVariance(), 0.24264068711928521, tol); Assert.assertEquals(dist.getNumericalVariance(), 0.24264068711928521, tol);
} }
/**
* Test sampling for various number of points and exponents.
*/
@Test
public void testSamplingExtended() {
int sampleSize = 1000;
int[] numPointsValues = {
2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100
};
double[] exponentValues = {
1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1,
1. - 1e-9, 1.0, 1. + 1e-9, 1.1, 1.2, 1.3, 1.5, 1.6, 1.7, 1.8, 2.0,
2.5, 3.0, 4., 5., 6., 7., 8., 9., 10., 20., 30.
};
for (int numPoints : numPointsValues) {
for (double exponent : exponentValues) {
double weightSum = 0.;
double[] weights = new double[numPoints];
for (int i = numPoints; i>=1; i-=1) {
weights[i-1] = Math.pow(i, -exponent);
weightSum += weights[i-1];
}
ZipfDistribution distribution = new ZipfDistribution(numPoints, exponent);
distribution.reseedRandomGenerator(6); // use fixed seed, the test is expected to fail for more than 50% of all seeds because each test case can fail with probability 0.001, the chance that all test cases do not fail is 0.999^(32*22) = 0.49442874426
double[] expectedCounts = new double[numPoints];
long[] observedCounts = new long[numPoints];
for (int i = 0; i < numPoints; i++) {
expectedCounts[i] = sampleSize * (weights[i]/weightSum);
}
int[] sample = distribution.sample(sampleSize);
for (int s : sample) {
observedCounts[s-1]++;
}
TestUtils.assertChiSquareAccept(expectedCounts, observedCounts, 0.001);
}
}
}
@Test
public void testSamplerIntegratePowerFunction() {
final double tol = 1e-6;
final double[] exponents = {
-1e-5, -1e-4, -1e-3, -1e-2, -1e-1, -1e0, -1e1
};
final double[] limits = {
0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.0, 6.5, 7.0,
7.5, 8.0, 8.5, 9.0, 9.5, 10.0
};
for (final double exponent : exponents) {
for (int lowerLimitIndex = 0; lowerLimitIndex < limits.length; ++lowerLimitIndex) {
final double lowerLimit = limits[lowerLimitIndex];
for (int upperLimitIndex = lowerLimitIndex+1; upperLimitIndex < limits.length; ++upperLimitIndex) {
final double upperLimit = limits[upperLimitIndex];
final double result1 = new SimpsonIntegrator().integrate(10000, new UnivariateFunction() {
@Override
public double value(double x) {
return Math.pow(x, exponent);
}
}, lowerLimit, upperLimit);
final double result2 =
ZipfRejectionSampler.integratePowerFunction(exponent, lowerLimit, upperLimit);
assertEquals(result1, result2, (result1+result2)*tol);
}
}
}
}
@Test
public void testSamplerAcceptanceRate() {
final double tol = 1e-12;
final double[] exponents = {
1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2e0, 5e0, 1e1, 1e2, 1e3
};
final int[] values = {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
};
final int numberOfElements = 1000;
for (final double exponent : exponents) {
ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent);
for (final int value : values) {
double expected = FastMath.pow(value, -exponent);
double result = sampler.acceptanceRateForTailSample(value) *
ZipfRejectionSampler.integratePowerFunction(-exponent, value - 0.5, value + 0.5);
TestUtils.assertRelativelyEquals(expected, result, tol);
assertTrue(result <= 1.); // test Jensen's inequality
}
}
}
@Test
public void testSamplerInverseInstrumentalDistribution() {
final double tol = 1e-14;
final double[] exponentValues = {
1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2E0, 3e0, 4e0, 5e0, 6., 7., 8., 9., 10., 50.
};
final double[] qValues = {
0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0
};
final int[] numberOfElementsValues = {
2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 100
};
for (final double exponent : exponentValues) {
for (final int numberOfElements : numberOfElementsValues) {
final ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent);
for (final double q : qValues) {
int result = sampler.sampleFromInstrumentalDistributionTail(q);
double total =
ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, numberOfElements + 0.5);
double lowerBound =
ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, result - 0.5) / total;
double upperBound =
ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, result + 0.5) / total;
assertTrue(lowerBound <= q*(1.+tol));
assertTrue(upperBound >= q*(1.-tol));
}
}
}
}
@Test
public void testSamplerHelper1() {
final double tol = 1e-14;
final double[] qValues = {
0., 1e-12, 1e-11, 1e-10, 1e-9, 9e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4,
1e-3, 1e-2, 1e-1, 1e0
};
final double[] xValues = {
-Double.MAX_VALUE, -1e10, -1e9, -1e8, -1e7, -1e6, -1e5, -1e4, -1e3,
-1e2, -1e1, -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7,
-1e-8, -1e-9, -1e-10, -Double.MIN_VALUE, 0.0, Double.MIN_VALUE,
1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0,
1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, Double.MAX_VALUE
};
for (final double q : qValues) {
for(final double x : xValues) {
double calculated = ZipfRejectionSampler.helper1(q, x);
TestUtils.assertRelativelyEquals((1.-q)+q*Math.exp(x), FastMath.exp(calculated*x), tol);
}
}
}
@Test
public void testSamplerHelper2() {
final double tol = 1e-12;
final double[] testValues = {
-1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7, -1e-8,
-1e-9, -1e-10, -1e-11, 0., 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6,
1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0
};
for (double testValue : testValues) {
final double expected = FastMath.expm1(testValue);
TestUtils.assertRelativelyEquals(expected, ZipfRejectionSampler.helper2(testValue)*testValue, tol);
}
}
@Ignore
@Test
public void testSamplerPerformance() {
int[] numPointsValues = {1, 2, 5, 10, 100, 1000, 10000};
double[] exponentValues = {1e-3, 1e-2, 1e-1, 1., 2., 5., 10.};
int numGeneratedSamples = 1000000;
long sum = 0;
for (int numPoints : numPointsValues) {
for (double exponent : exponentValues) {
long start = System.currentTimeMillis();
final int[] randomNumberCounter = new int[1];
RandomGenerator randomGenerator = new AbstractRandomGenerator() {
private final RandomGenerator r = new Well1024a(0L);
@Override
public void setSeed(long seed) {
}
@Override
public double nextDouble() {
randomNumberCounter[0]+=1;
return r.nextDouble();
}
};
final ZipfDistribution distribution = new ZipfDistribution(randomGenerator, numPoints, exponent);
for (int i = 0; i < numGeneratedSamples; ++i) {
sum += distribution.sample();
}
long end = System.currentTimeMillis();
System.out.println("n = " + numPoints + ", exponent = " + exponent + ", avg number consumed random values = " + (double)(randomNumberCounter[0])/numGeneratedSamples + ", measured time = " + (end-start)/1000. + "s");
}
}
System.out.println(sum);
}
} }