[MATH-1220] Improve performance of ZipfDistribution.sample. Thanks to Otmar Ertl.
This commit is contained in:
parent
5597ed7ea3
commit
002276ea31
3
pom.xml
3
pom.xml
|
@ -206,6 +206,9 @@
|
|||
<contributor>
|
||||
<name>Ole Ersoy</name>
|
||||
</contributor>
|
||||
<contributor>
|
||||
<name>Otmar Ertl</name>
|
||||
</contributor>
|
||||
<contributor>
|
||||
<name>Ajo Fod</name>
|
||||
</contributor>
|
||||
|
|
|
@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
|
|||
</release>
|
||||
|
||||
<release version="4.0" date="XXXX-XX-XX" description="">
|
||||
<action dev="tn" type="fix" issue="MATH-1220" due-to="Otmar Ertl"> <!-- backported to 3.6 -->
|
||||
Improve performance of "ZipfDistribution#sample()" by using a rejection algorithm.
|
||||
</action>
|
||||
<action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev"> <!-- backported to 3.6 -->
|
||||
Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm.
|
||||
</action>
|
||||
|
|
|
@ -43,6 +43,8 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
|
|||
private double numericalVariance = Double.NaN;
|
||||
/** Whether or not the numerical variance has been calculated */
|
||||
private boolean numericalVarianceIsCalculated = false;
|
||||
/** The sampler to be used for the sample() method */
|
||||
private transient ZipfRejectionSampler sampler;
|
||||
|
||||
/**
|
||||
* Create a new Zipf distribution with the given number of elements and
|
||||
|
@ -265,5 +267,152 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
|
|||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
* <p>
|
||||
* An instrumental distribution g(k) is used to generate random values by
|
||||
* rejection sampling. g(k) is defined as g(1):= 1 and g(k) := I(-s,k-1/2,k+1/2)
|
||||
* for k larger than 1, where s denotes the exponent of the Zipf distribution
|
||||
* and I(r,a,b) is the integral of x^r for x from a to b.
|
||||
* <p>
|
||||
* Since 1^x^s is a convex function, Jensens's inequality gives
|
||||
* I(-s,k-1/2,k+1/2) >= 1/k^s for all positive k and non-negative s.
|
||||
* In order to limit the rejection rate for large exponents s,
|
||||
* the instrumental distribution weight is differently defined for value 1.
|
||||
*/
|
||||
@Override
|
||||
public int sample() {
|
||||
if (sampler == null) {
|
||||
sampler = new ZipfRejectionSampler(numberOfElements, exponent);
|
||||
}
|
||||
return sampler.sample(random);
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class implementing a rejection sampling method for a discrete,
|
||||
* bounded Zipf distribution.
|
||||
*
|
||||
* @since 3.6
|
||||
*/
|
||||
static final class ZipfRejectionSampler {
|
||||
|
||||
/** Number of elements. */
|
||||
private final int numberOfElements;
|
||||
/** Exponent parameter of the distribution. */
|
||||
private final double exponent;
|
||||
/** Cached tail weight of instrumental distribution used for rejection sampling */
|
||||
private double instrumentalDistributionTailWeight = Double.NaN;
|
||||
|
||||
ZipfRejectionSampler(final int numberOfElements, final double exponent) {
|
||||
this.numberOfElements = numberOfElements;
|
||||
this.exponent = exponent;
|
||||
}
|
||||
|
||||
int sample(final RandomGenerator random) {
|
||||
if (Double.isNaN(instrumentalDistributionTailWeight)) {
|
||||
instrumentalDistributionTailWeight = integratePowerFunction(-exponent, 1.5, numberOfElements+0.5);
|
||||
}
|
||||
|
||||
while(true) {
|
||||
final double randomValue = random.nextDouble()*(instrumentalDistributionTailWeight + 1.);
|
||||
if (randomValue < instrumentalDistributionTailWeight) {
|
||||
final double q = randomValue / instrumentalDistributionTailWeight;
|
||||
final int sample = sampleFromInstrumentalDistributionTail(q);
|
||||
if (random.nextDouble() < acceptanceRateForTailSample(sample)) {
|
||||
return sample;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a sample from the instrumental distribution tail for a given
|
||||
* uniformly distributed random value.
|
||||
*
|
||||
* @param q a uniformly distributed random value taken from [0,1]
|
||||
* @return a sample in the range [2, {@link #numberOfElements}]
|
||||
*/
|
||||
int sampleFromInstrumentalDistributionTail(double q) {
|
||||
final double a = 1.5;
|
||||
final double b = numberOfElements + 0.5;
|
||||
final double logBdviA = FastMath.log(b / a);
|
||||
|
||||
final int result = (int) (a * FastMath.exp(logBdviA * helper1(q, logBdviA * (1. - exponent))) + 0.5);
|
||||
if (result < 2) {
|
||||
return 2;
|
||||
}
|
||||
if (result > numberOfElements) {
|
||||
return numberOfElements;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function that calculates log((1-q)+q*exp(x))/x.
|
||||
* <p>
|
||||
* A Taylor series expansion is used, if x is close to 0.
|
||||
*
|
||||
* @param q a value in the range [0,1]
|
||||
* @param
|
||||
* @return log((1-q)+q*exp(x))/x
|
||||
*/
|
||||
static double helper1(final double q, final double x) {
|
||||
if (Math.abs(x) > 1e-8) {
|
||||
return FastMath.log((1.-q)+q*FastMath.exp(x))/x;
|
||||
}
|
||||
else {
|
||||
return q*(1.+(1./2.)*x*(1.-q)*(1+(1./3.)*x*((1.-2.*q) + (1./4.)*x*(6*q*q*(q-1)+1))));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to calculate (exp(x)-1)/x.
|
||||
* <p>
|
||||
* A Taylor series expansion is used, if x is close to 0.
|
||||
*
|
||||
* @return (exp(x)-1)/x if x is non-zero, 1 if x=0
|
||||
*/
|
||||
static double helper2(final double x) {
|
||||
if (FastMath.abs(x)>1e-8) {
|
||||
return FastMath.expm1(x)/x;
|
||||
}
|
||||
else {
|
||||
return 1.+x*(1./2.)*(1.+x*(1./3.)*(1.+x*(1./4.)));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Integrates the power function x^r from x=a to b.
|
||||
*
|
||||
* @param r the exponent
|
||||
* @param a the integral lower bound
|
||||
* @param b the integral upper bound
|
||||
* @return the calculated integral value
|
||||
*/
|
||||
static double integratePowerFunction(final double r, final double a, final double b) {
|
||||
final double logA = FastMath.log(a);
|
||||
final double logBdivA = FastMath.log(b/a);
|
||||
return FastMath.exp((1.+r)*logA)*helper2((1.+r)*logBdivA)*logBdivA;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the acceptance rate for a sample taken from the tail of the instrumental distribution.
|
||||
* <p>
|
||||
* The acceptance rate is given by the ratio k^(-s)/I(-s,k-0.5, k+0.5)
|
||||
* where I(r,a,b) is the integral of x^r for x from a to b.
|
||||
*
|
||||
* @param k the value which has been sampled using the instrumental distribution
|
||||
* @return the acceptance rate
|
||||
*/
|
||||
double acceptanceRateForTailSample(int k) {
|
||||
final double a = FastMath.log1p(1./(2.*k-1.));
|
||||
final double b = FastMath.log1p(2./(2.*k-1.));
|
||||
return FastMath.exp((1.-exponent)*a)/(k*b*helper2((1.-exponent)*b));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,18 +17,28 @@
|
|||
|
||||
package org.apache.commons.math4.distribution;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.apache.commons.math4.TestUtils;
|
||||
import org.apache.commons.math4.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math4.analysis.integration.SimpsonIntegrator;
|
||||
import org.apache.commons.math4.distribution.IntegerDistribution;
|
||||
import org.apache.commons.math4.distribution.ZipfDistribution;
|
||||
import org.apache.commons.math4.distribution.ZipfDistribution.ZipfRejectionSampler;
|
||||
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math4.random.AbstractRandomGenerator;
|
||||
import org.apache.commons.math4.random.RandomGenerator;
|
||||
import org.apache.commons.math4.random.Well1024a;
|
||||
import org.apache.commons.math4.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Test cases for {@link ZipfDistribution}.
|
||||
* Extends IntegerDistributionAbstractTest. See class javadoc for
|
||||
* IntegerDistributionAbstractTest for details.
|
||||
*
|
||||
* Extends IntegerDistributionAbstractTest.
|
||||
* See class javadoc for IntegerDistributionAbstractTest for details.
|
||||
*/
|
||||
public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
|
||||
|
||||
|
@ -120,4 +130,210 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
|
|||
Assert.assertEquals(dist.getNumericalMean(), FastMath.sqrt(2), tol);
|
||||
Assert.assertEquals(dist.getNumericalVariance(), 0.24264068711928521, tol);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Test sampling for various number of points and exponents.
|
||||
*/
|
||||
@Test
|
||||
public void testSamplingExtended() {
|
||||
int sampleSize = 1000;
|
||||
|
||||
int[] numPointsValues = {
|
||||
2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100
|
||||
};
|
||||
double[] exponentValues = {
|
||||
1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1,
|
||||
1. - 1e-9, 1.0, 1. + 1e-9, 1.1, 1.2, 1.3, 1.5, 1.6, 1.7, 1.8, 2.0,
|
||||
2.5, 3.0, 4., 5., 6., 7., 8., 9., 10., 20., 30.
|
||||
};
|
||||
|
||||
for (int numPoints : numPointsValues) {
|
||||
for (double exponent : exponentValues) {
|
||||
double weightSum = 0.;
|
||||
double[] weights = new double[numPoints];
|
||||
for (int i = numPoints; i>=1; i-=1) {
|
||||
weights[i-1] = Math.pow(i, -exponent);
|
||||
weightSum += weights[i-1];
|
||||
}
|
||||
|
||||
ZipfDistribution distribution = new ZipfDistribution(numPoints, exponent);
|
||||
distribution.reseedRandomGenerator(6); // use fixed seed, the test is expected to fail for more than 50% of all seeds because each test case can fail with probability 0.001, the chance that all test cases do not fail is 0.999^(32*22) = 0.49442874426
|
||||
|
||||
double[] expectedCounts = new double[numPoints];
|
||||
long[] observedCounts = new long[numPoints];
|
||||
for (int i = 0; i < numPoints; i++) {
|
||||
expectedCounts[i] = sampleSize * (weights[i]/weightSum);
|
||||
}
|
||||
int[] sample = distribution.sample(sampleSize);
|
||||
for (int s : sample) {
|
||||
observedCounts[s-1]++;
|
||||
}
|
||||
TestUtils.assertChiSquareAccept(expectedCounts, observedCounts, 0.001);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplerIntegratePowerFunction() {
|
||||
final double tol = 1e-6;
|
||||
final double[] exponents = {
|
||||
-1e-5, -1e-4, -1e-3, -1e-2, -1e-1, -1e0, -1e1
|
||||
};
|
||||
final double[] limits = {
|
||||
0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.0, 6.5, 7.0,
|
||||
7.5, 8.0, 8.5, 9.0, 9.5, 10.0
|
||||
};
|
||||
|
||||
for (final double exponent : exponents) {
|
||||
for (int lowerLimitIndex = 0; lowerLimitIndex < limits.length; ++lowerLimitIndex) {
|
||||
final double lowerLimit = limits[lowerLimitIndex];
|
||||
for (int upperLimitIndex = lowerLimitIndex+1; upperLimitIndex < limits.length; ++upperLimitIndex) {
|
||||
final double upperLimit = limits[upperLimitIndex];
|
||||
final double result1 = new SimpsonIntegrator().integrate(10000, new UnivariateFunction() {
|
||||
@Override
|
||||
public double value(double x) {
|
||||
return Math.pow(x, exponent);
|
||||
}
|
||||
}, lowerLimit, upperLimit);
|
||||
|
||||
final double result2 =
|
||||
ZipfRejectionSampler.integratePowerFunction(exponent, lowerLimit, upperLimit);
|
||||
assertEquals(result1, result2, (result1+result2)*tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplerAcceptanceRate() {
|
||||
final double tol = 1e-12;
|
||||
final double[] exponents = {
|
||||
1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2e0, 5e0, 1e1, 1e2, 1e3
|
||||
};
|
||||
final int[] values = {
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
|
||||
};
|
||||
final int numberOfElements = 1000;
|
||||
for (final double exponent : exponents) {
|
||||
ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent);
|
||||
for (final int value : values) {
|
||||
double expected = FastMath.pow(value, -exponent);
|
||||
double result = sampler.acceptanceRateForTailSample(value) *
|
||||
ZipfRejectionSampler.integratePowerFunction(-exponent, value - 0.5, value + 0.5);
|
||||
TestUtils.assertRelativelyEquals(expected, result, tol);
|
||||
assertTrue(result <= 1.); // test Jensen's inequality
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplerInverseInstrumentalDistribution() {
|
||||
final double tol = 1e-14;
|
||||
final double[] exponentValues = {
|
||||
1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2E0, 3e0, 4e0, 5e0, 6., 7., 8., 9., 10., 50.
|
||||
};
|
||||
final double[] qValues = {
|
||||
0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0
|
||||
};
|
||||
final int[] numberOfElementsValues = {
|
||||
2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 100
|
||||
};
|
||||
|
||||
for (final double exponent : exponentValues) {
|
||||
for (final int numberOfElements : numberOfElementsValues) {
|
||||
final ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent);
|
||||
for (final double q : qValues) {
|
||||
int result = sampler.sampleFromInstrumentalDistributionTail(q);
|
||||
double total =
|
||||
ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, numberOfElements + 0.5);
|
||||
double lowerBound =
|
||||
ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, result - 0.5) / total;
|
||||
double upperBound =
|
||||
ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, result + 0.5) / total;
|
||||
assertTrue(lowerBound <= q*(1.+tol));
|
||||
assertTrue(upperBound >= q*(1.-tol));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplerHelper1() {
|
||||
final double tol = 1e-14;
|
||||
final double[] qValues = {
|
||||
0., 1e-12, 1e-11, 1e-10, 1e-9, 9e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4,
|
||||
1e-3, 1e-2, 1e-1, 1e0
|
||||
};
|
||||
final double[] xValues = {
|
||||
-Double.MAX_VALUE, -1e10, -1e9, -1e8, -1e7, -1e6, -1e5, -1e4, -1e3,
|
||||
-1e2, -1e1, -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7,
|
||||
-1e-8, -1e-9, -1e-10, -Double.MIN_VALUE, 0.0, Double.MIN_VALUE,
|
||||
1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0,
|
||||
1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, Double.MAX_VALUE
|
||||
};
|
||||
|
||||
for (final double q : qValues) {
|
||||
for(final double x : xValues) {
|
||||
double calculated = ZipfRejectionSampler.helper1(q, x);
|
||||
TestUtils.assertRelativelyEquals((1.-q)+q*Math.exp(x), FastMath.exp(calculated*x), tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplerHelper2() {
|
||||
final double tol = 1e-12;
|
||||
final double[] testValues = {
|
||||
-1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7, -1e-8,
|
||||
-1e-9, -1e-10, -1e-11, 0., 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6,
|
||||
1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0
|
||||
};
|
||||
for (double testValue : testValues) {
|
||||
final double expected = FastMath.expm1(testValue);
|
||||
TestUtils.assertRelativelyEquals(expected, ZipfRejectionSampler.helper2(testValue)*testValue, tol);
|
||||
}
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testSamplerPerformance() {
|
||||
int[] numPointsValues = {1, 2, 5, 10, 100, 1000, 10000};
|
||||
double[] exponentValues = {1e-3, 1e-2, 1e-1, 1., 2., 5., 10.};
|
||||
int numGeneratedSamples = 1000000;
|
||||
|
||||
long sum = 0;
|
||||
|
||||
for (int numPoints : numPointsValues) {
|
||||
for (double exponent : exponentValues) {
|
||||
long start = System.currentTimeMillis();
|
||||
final int[] randomNumberCounter = new int[1];
|
||||
|
||||
RandomGenerator randomGenerator = new AbstractRandomGenerator() {
|
||||
|
||||
private final RandomGenerator r = new Well1024a(0L);
|
||||
|
||||
@Override
|
||||
public void setSeed(long seed) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public double nextDouble() {
|
||||
randomNumberCounter[0]+=1;
|
||||
return r.nextDouble();
|
||||
}
|
||||
};
|
||||
|
||||
final ZipfDistribution distribution = new ZipfDistribution(randomGenerator, numPoints, exponent);
|
||||
for (int i = 0; i < numGeneratedSamples; ++i) {
|
||||
sum += distribution.sample();
|
||||
}
|
||||
|
||||
long end = System.currentTimeMillis();
|
||||
System.out.println("n = " + numPoints + ", exponent = " + exponent + ", avg number consumed random values = " + (double)(randomNumberCounter[0])/numGeneratedSamples + ", measured time = " + (end-start)/1000. + "s");
|
||||
}
|
||||
}
|
||||
System.out.println(sum);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue