Made EmpiricalDistribution smoothing kernel pluggable.

To enable subclasses to provide the enclosing distribution's underlying RandomGenerator
to distribution constructors, two more changes were required:
* In EmpiricalDistribution, the RandomDataGenerator field (randomData) was changed
  from private to protected.
* The private getRan() method in RandomDataGenerator returning the underlying
  RandomGenerator was renamed getRandomGenerator and made public.
JIRA: MATH-671

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1457372 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Phil Steitz 2013-03-17 04:28:04 +00:00
parent 7dd09bfb64
commit 5d5f253240
4 changed files with 188 additions and 33 deletions

View File

@ -55,6 +55,9 @@ This is a minor release: It combines bug fixes and new features.
Changes to existing features were made in a backwards-compatible
way such as to allow drop-in replacement of the v3.1[.1] JAR file.
">
<action dev="psteitz" type="update" issue="MATH-671">
Made EmpiricalDistribution smoothing kernel pluggable.
</action>
<action dev="psteitz" type="add" issue="MATH-946" due-to="Jared Becksfort">
Added array-scaling methods to MathArrays.
</action>

View File

@ -110,6 +110,9 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
/** Serializable version identifier */
private static final long serialVersionUID = 5729073523949762654L;
/** RandomDataGenerator instance to use in repeated calls to getNext() */
protected final RandomDataGenerator randomData;
/** List of SummaryStatistics objects characterizing the bins */
private final List<SummaryStatistics> binStats;
@ -134,9 +137,6 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
/** upper bounds of subintervals in (0,1) "belonging" to the bins */
private double[] upperBounds = null;
/** RandomDataGenerator instance to use in repeated calls to getNext() */
private final RandomDataGenerator randomData;
/**
* Creates a new EmpiricalDistribution with the default bin count.
*/
@ -487,8 +487,7 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
SummaryStatistics stats = binStats.get(i);
if (stats.getN() > 0) {
if (stats.getStandardDeviation() > 0) { // more than one obs
return randomData.nextGaussian(stats.getMean(),
stats.getStandardDeviation());
return getKernel(stats).sample();
} else {
return stats.getMean(); // only one obs in bin
}
@ -842,9 +841,10 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
* @param bStats summary statistics for the bin
* @return within-bin kernel parameterized by bStats
*/
private RealDistribution getKernel(SummaryStatistics bStats) {
// For now, hard-code Gaussian (only kernel supported)
return new NormalDistribution(
bStats.getMean(), bStats.getStandardDeviation());
protected RealDistribution getKernel(SummaryStatistics bStats) {
// Default to Gaussian
return new NormalDistribution(randomData.getRandomGenerator(),
bStats.getMean(), bStats.getStandardDeviation(),
NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
}
}

View File

@ -163,7 +163,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
}
// Get a random number generator
RandomGenerator ran = getRan();
RandomGenerator ran = getRandomGenerator();
// Initialize output buffer
StringBuilder outBuffer = new StringBuilder();
@ -202,7 +202,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
if (max <= 0) {
// the range is too wide to fit in a positive int (larger than 2^31); as it covers
// more than half the integer range, we use directly a simple rejection method
final RandomGenerator rng = getRan();
final RandomGenerator rng = getRandomGenerator();
while (true) {
final int r = rng.nextInt();
if (r >= lower && r <= upper) {
@ -211,7 +211,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
}
} else {
// we can shift the range and generate directly a positive int
return lower + getRan().nextInt(max);
return lower + getRandomGenerator().nextInt(max);
}
}
@ -225,7 +225,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
if (max <= 0) {
// the range is too wide to fit in a positive long (larger than 2^63); as it covers
// more than half the long range, we use directly a simple rejection method
final RandomGenerator rng = getRan();
final RandomGenerator rng = getRandomGenerator();
while (true) {
final long r = rng.nextLong();
if (r >= lower && r <= upper) {
@ -234,10 +234,10 @@ public class RandomDataGenerator implements RandomData, Serializable {
}
} else if (max < Integer.MAX_VALUE){
// we can shift the range and generate directly a positive int
return lower + getRan().nextInt((int) max);
return lower + getRandomGenerator().nextInt((int) max);
} else {
// we can shift the range and generate directly a positive long
return lower + nextLong(getRan(), max);
return lower + nextLong(getRandomGenerator(), max);
}
}
@ -433,7 +433,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @throws NotStrictlyPositiveException if {@code mean <= 0}
*/
public long nextPoisson(double mean) throws NotStrictlyPositiveException {
return new PoissonDistribution(getRan(), mean,
return new PoissonDistribution(getRandomGenerator(), mean,
PoissonDistribution.DEFAULT_EPSILON,
PoissonDistribution.DEFAULT_MAX_ITERATIONS).sample();
}
@ -443,7 +443,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
if (sigma <= 0) {
throw new NotStrictlyPositiveException(LocalizedFormats.STANDARD_DEVIATION, sigma);
}
return sigma * getRan().nextGaussian() + mu;
return sigma * getRandomGenerator().nextGaussian() + mu;
}
/**
@ -458,7 +458,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* </p>
*/
public double nextExponential(double mean) throws NotStrictlyPositiveException {
return new ExponentialDistribution(getRan(), mean,
return new ExponentialDistribution(getRandomGenerator(), mean,
ExponentialDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -485,7 +485,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* {@code scale <= 0}.
*/
public double nextGamma(double shape, double scale) throws NotStrictlyPositiveException {
return new GammaDistribution(getRan(),shape, scale,
return new GammaDistribution(getRandomGenerator(),shape, scale,
GammaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -502,7 +502,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @throws NotPositiveException if {@code numberOfSuccesses < 0}.
*/
public int nextHypergeometric(int populationSize, int numberOfSuccesses, int sampleSize) throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
return new HypergeometricDistribution(getRan(),populationSize,
return new HypergeometricDistribution(getRandomGenerator(),populationSize,
numberOfSuccesses, sampleSize).sample();
}
@ -517,7 +517,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* range {@code [0, 1]}.
*/
public int nextPascal(int r, double p) throws NotStrictlyPositiveException, OutOfRangeException {
return new PascalDistribution(getRan(), r, p).sample();
return new PascalDistribution(getRandomGenerator(), r, p).sample();
}
/**
@ -528,7 +528,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @throws NotStrictlyPositiveException if {@code df <= 0}
*/
public double nextT(double df) throws NotStrictlyPositiveException {
return new TDistribution(getRan(), df,
return new TDistribution(getRandomGenerator(), df,
TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -542,7 +542,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* {@code scale <= 0}.
*/
public double nextWeibull(double shape, double scale) throws NotStrictlyPositiveException {
return new WeibullDistribution(getRan(), shape, scale,
return new WeibullDistribution(getRandomGenerator(), shape, scale,
WeibullDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -556,7 +556,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* or {@code exponent <= 0}.
*/
public int nextZipf(int numberOfElements, double exponent) throws NotStrictlyPositiveException {
return new ZipfDistribution(getRan(), numberOfElements, exponent).sample();
return new ZipfDistribution(getRandomGenerator(), numberOfElements, exponent).sample();
}
/**
@ -567,7 +567,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @return random value sampled from the beta(alpha, beta) distribution
*/
public double nextBeta(double alpha, double beta) {
return new BetaDistribution(getRan(), alpha, beta,
return new BetaDistribution(getRandomGenerator(), alpha, beta,
BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -579,7 +579,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @return random value sampled from the Binomial(numberOfTrials, probabilityOfSuccess) distribution
*/
public int nextBinomial(int numberOfTrials, double probabilityOfSuccess) {
return new BinomialDistribution(getRan(), numberOfTrials, probabilityOfSuccess).sample();
return new BinomialDistribution(getRandomGenerator(), numberOfTrials, probabilityOfSuccess).sample();
}
/**
@ -590,7 +590,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @return random value sampled from the Cauchy(median, scale) distribution
*/
public double nextCauchy(double median, double scale) {
return new CauchyDistribution(getRan(), median, scale,
return new CauchyDistribution(getRandomGenerator(), median, scale,
CauchyDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -601,7 +601,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @return random value sampled from the ChiSquare(df) distribution
*/
public double nextChiSquare(double df) {
return new ChiSquaredDistribution(getRan(), df,
return new ChiSquaredDistribution(getRandomGenerator(), df,
ChiSquaredDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -615,7 +615,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* {@code numeratorDf <= 0} or {@code denominatorDf <= 0}.
*/
public double nextF(double numeratorDf, double denominatorDf) throws NotStrictlyPositiveException {
return new FDistribution(getRan(), numeratorDf, denominatorDf,
return new FDistribution(getRandomGenerator(), numeratorDf, denominatorDf,
FDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
}
@ -671,7 +671,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
throw new NotANumberException();
}
final RandomGenerator generator = getRan();
final RandomGenerator generator = getRandomGenerator();
// ensure nextDouble() isn't 0.0
double u = generator.nextDouble();
@ -758,7 +758,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* @param seed the seed value to use
*/
public void reSeed(long seed) {
getRan().setSeed(seed);
getRandomGenerator().setSeed(seed);
}
/**
@ -789,7 +789,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
* {@code System.currentTimeMillis() + System.identityHashCode(this)}.
*/
public void reSeed() {
getRan().setSeed(System.currentTimeMillis() + System.identityHashCode(this));
getRandomGenerator().setSeed(System.currentTimeMillis() + System.identityHashCode(this));
}
/**
@ -823,7 +823,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
*
* @return the Random used to generate random data
*/
private RandomGenerator getRan() {
public RandomGenerator getRandomGenerator() {
if (rand == null) {
initRan();
}

View File

@ -22,15 +22,19 @@ import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.math3.TestUtils;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator;
import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator;
import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.RealDistributionAbstractTest;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.junit.Assert;
import org.junit.Before;
@ -428,4 +432,152 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
return new NormalDistribution((upper + lower + 1) / 2d, 3.0276503540974917);
}
}
@Test
public void testKernelOverrideConstant() {
    // Fifteen consecutive values loaded into a distribution built with argument 5
    final double[] observations = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d};
    final EmpiricalDistribution dist = new ConstantKernelEmpiricalDistribution(5);
    dist.load(observations);

    // With the constant kernel, all mass of each bin sits on the bin mean:
    // 2, 5, 8, 11, 14 - effectively a discrete uniform distribution over these points.
    final double[] binMeans = {2d, 5d, 8d, 11d, 14d};
    int remaining = 20;
    while (remaining-- > 0) {
        final double drawn = dist.sample();
        // binMeans is sorted, so a non-negative index means the draw is one of the means
        final int position = Arrays.binarySearch(binMeans, drawn);
        Assert.assertTrue(position >= 0);
    }

    final double tol = 10E-12;

    // {x, expected P(X <= x)}
    final double[][] cdfCases = {
        {1, 0.0}, {2, 0.2}, {10, 0.6}, {12, 0.8}, {13, 0.8}, {15, 1.0}
    };
    for (int k = 0; k < cdfCases.length; k++) {
        Assert.assertEquals(cdfCases[k][1], dist.cumulativeProbability(cdfCases[k][0]), tol);
    }

    // {p, expected inverse cumulative probability at p}
    final double[][] quantileCases = {
        {0.1, 2.0}, {0.2, 2.0}, {0.3, 5.0}, {0.4, 5.0}, {0.5, 8.0}, {0.6, 8.0}
    };
    for (int k = 0; k < quantileCases.length; k++) {
        Assert.assertEquals(quantileCases[k][1], dist.inverseCumulativeProbability(quantileCases[k][0]), tol);
    }
}
@Test
public void testKernelOverrideUniform() {
    final double[] observations = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d};
    final EmpiricalDistribution dist = new UniformKernelEmpiricalDistribution(5);
    dist.load(observations);

    // Kernels are uniform distributions on [1,3], [4,6], [7,9], [10,12], [13,15],
    // so no sample may land in a gap (upper, upper + 1) between consecutive bins.
    final double[] gapStarts = {3d, 6d, 9d, 12d};
    final double tol = 10E-12;
    for (int draw = 0; draw < 20; draw++) {
        final double v = dist.sample();
        for (final double upper : gapStarts) {
            final boolean insideGap = v > upper + tol && v < upper + 1 - tol;
            Assert.assertFalse(insideGap);
        }
    }

    // {x, expected P(X <= x)}
    final double[][] cdfCases = {
        {1, 0.0}, {2, 0.1}, {10, 0.6}, {12, 0.8}, {13, 0.8}, {15, 1.0}
    };
    for (int k = 0; k < cdfCases.length; k++) {
        Assert.assertEquals(cdfCases[k][1], dist.cumulativeProbability(cdfCases[k][0]), tol);
    }

    // {p, expected inverse cumulative probability at p}
    final double[][] quantileCases = {
        {0.1, 2.0}, {0.2, 3.0}, {0.3, 5.0}, {0.4, 6.0}, {0.5, 8.0}, {0.6, 9.0}
    };
    for (int k = 0; k < quantileCases.length; k++) {
        Assert.assertEquals(quantileCases[k][1], dist.inverseCumulativeProbability(quantileCases[k][0]), tol);
    }
}
/**
 * EmpiricalDistribution variant whose within-bin kernel is a point mass
 * located at the bin mean.
 */
private class ConstantKernelEmpiricalDistribution extends EmpiricalDistribution {
    private static final long serialVersionUID = 1L;

    /**
     * @param binCount value handed unchanged to the EmpiricalDistribution(int) constructor
     */
    public ConstantKernelEmpiricalDistribution(int binCount) {
        super(binCount);
    }

    /**
     * Replaces the default kernel with a ConstantDistribution.
     *
     * @param bStats summary statistics for the bin
     * @return distribution that always yields the mean reported by bStats
     */
    protected RealDistribution getKernel(SummaryStatistics bStats) {
        final double binMean = bStats.getMean();
        return new ConstantDistribution(binMean);
    }
}
/**
* Empirical distribution using a uniform smoothing kernel.
*/
private class UniformKernelEmpiricalDistribution extends EmpiricalDistribution {
public UniformKernelEmpiricalDistribution(int i) {
super(i);
}
protected RealDistribution getKernel(SummaryStatistics bStats) {
return new UniformRealDistribution(randomData.getRandomGenerator(), bStats.getMin(), bStats.getMax(),
UniformRealDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
}
}
/**
* Distribution that takes just one value.
*/
private class ConstantDistribution extends AbstractRealDistribution {
private static final long serialVersionUID = 1L;
/** Singleton value in the sample space */
private final double c;
public ConstantDistribution(double c) {
this.c = c;
}
public double density(double x) {
return 0;
}
public double cumulativeProbability(double x) {
return x < c ? 0 : 1;
}
@Override
public double inverseCumulativeProbability(double p) {
if (p < 0.0 || p > 1.0) {
throw new OutOfRangeException(p, 0, 1);
}
return c;
}
public double getNumericalMean() {
return c;
}
public double getNumericalVariance() {
return 0;
}
public double getSupportLowerBound() {
return c;
}
public double getSupportUpperBound() {
return c;
}
public boolean isSupportLowerBoundInclusive() {
return false;
}
public boolean isSupportUpperBoundInclusive() {
return true;
}
public boolean isSupportConnected() {
return true;
}
@Override
public double sample() {
return c;
}
}
}