Made EmpiricalDistribution smoothing kernel pluggable.
To enable subclasses to provide the enclosing distribution's underlying RandomGenerator to distribution constructors, two more changes were required: * In EmpiricalDistribution, the RandomDataGenerator field (randomData) was changed from private to protected. * The private getRan() method in RandomDataGenerator returning the underlying RandomGenerator was renamed getRandomGenerator and made public. JIRA: MATH-671 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1457372 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
7dd09bfb64
commit
5d5f253240
|
@ -55,6 +55,9 @@ This is a minor release: It combines bug fixes and new features.
|
|||
Changes to existing features were made in a backwards-compatible
|
||||
way such as to allow drop-in replacement of the v3.1[.1] JAR file.
|
||||
">
|
||||
<action dev="psteitz" type="update" issue="MATH-671">
|
||||
Made EmpiricalDistribution smoothing kernel pluggable.
|
||||
</action>
|
||||
<action dev="psteitz" type="add" issue="MATH-946" due-to="Jared Becksfort">
|
||||
Added array-scaling methods to MathArrays.
|
||||
</action>
|
||||
|
|
|
@ -110,6 +110,9 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
|
|||
/** Serializable version identifier */
|
||||
private static final long serialVersionUID = 5729073523949762654L;
|
||||
|
||||
/** RandomDataGenerator instance to use in repeated calls to getNext() */
|
||||
protected final RandomDataGenerator randomData;
|
||||
|
||||
/** List of SummaryStatistics objects characterizing the bins */
|
||||
private final List<SummaryStatistics> binStats;
|
||||
|
||||
|
@ -134,9 +137,6 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
|
|||
/** upper bounds of subintervals in (0,1) "belonging" to the bins */
|
||||
private double[] upperBounds = null;
|
||||
|
||||
/** RandomDataGenerator instance to use in repeated calls to getNext() */
|
||||
private final RandomDataGenerator randomData;
|
||||
|
||||
/**
|
||||
* Creates a new EmpiricalDistribution with the default bin count.
|
||||
*/
|
||||
|
@ -487,8 +487,7 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
|
|||
SummaryStatistics stats = binStats.get(i);
|
||||
if (stats.getN() > 0) {
|
||||
if (stats.getStandardDeviation() > 0) { // more than one obs
|
||||
return randomData.nextGaussian(stats.getMean(),
|
||||
stats.getStandardDeviation());
|
||||
return getKernel(stats).sample();
|
||||
} else {
|
||||
return stats.getMean(); // only one obs in bin
|
||||
}
|
||||
|
@ -842,9 +841,10 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
|
|||
* @param bStats summary statistics for the bin
|
||||
* @return within-bin kernel parameterized by bStats
|
||||
*/
|
||||
private RealDistribution getKernel(SummaryStatistics bStats) {
|
||||
// For now, hard-code Gaussian (only kernel supported)
|
||||
return new NormalDistribution(
|
||||
bStats.getMean(), bStats.getStandardDeviation());
|
||||
protected RealDistribution getKernel(SummaryStatistics bStats) {
|
||||
// Default to Gaussian
|
||||
return new NormalDistribution(randomData.getRandomGenerator(),
|
||||
bStats.getMean(), bStats.getStandardDeviation(),
|
||||
NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -163,7 +163,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
}
|
||||
|
||||
// Get a random number generator
|
||||
RandomGenerator ran = getRan();
|
||||
RandomGenerator ran = getRandomGenerator();
|
||||
|
||||
// Initialize output buffer
|
||||
StringBuilder outBuffer = new StringBuilder();
|
||||
|
@ -202,7 +202,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
if (max <= 0) {
|
||||
// the range is too wide to fit in a positive int (larger than 2^31); as it covers
|
||||
// more than half the integer range, we use directly a simple rejection method
|
||||
final RandomGenerator rng = getRan();
|
||||
final RandomGenerator rng = getRandomGenerator();
|
||||
while (true) {
|
||||
final int r = rng.nextInt();
|
||||
if (r >= lower && r <= upper) {
|
||||
|
@ -211,7 +211,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
}
|
||||
} else {
|
||||
// we can shift the range and generate directly a positive int
|
||||
return lower + getRan().nextInt(max);
|
||||
return lower + getRandomGenerator().nextInt(max);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -225,7 +225,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
if (max <= 0) {
|
||||
// the range is too wide to fit in a positive long (larger than 2^63); as it covers
|
||||
// more than half the long range, we use directly a simple rejection method
|
||||
final RandomGenerator rng = getRan();
|
||||
final RandomGenerator rng = getRandomGenerator();
|
||||
while (true) {
|
||||
final long r = rng.nextLong();
|
||||
if (r >= lower && r <= upper) {
|
||||
|
@ -234,10 +234,10 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
}
|
||||
} else if (max < Integer.MAX_VALUE){
|
||||
// we can shift the range and generate directly a positive int
|
||||
return lower + getRan().nextInt((int) max);
|
||||
return lower + getRandomGenerator().nextInt((int) max);
|
||||
} else {
|
||||
// we can shift the range and generate directly a positive long
|
||||
return lower + nextLong(getRan(), max);
|
||||
return lower + nextLong(getRandomGenerator(), max);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -433,7 +433,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @throws NotStrictlyPositiveException if {@code mean <= 0}
|
||||
*/
|
||||
public long nextPoisson(double mean) throws NotStrictlyPositiveException {
|
||||
return new PoissonDistribution(getRan(), mean,
|
||||
return new PoissonDistribution(getRandomGenerator(), mean,
|
||||
PoissonDistribution.DEFAULT_EPSILON,
|
||||
PoissonDistribution.DEFAULT_MAX_ITERATIONS).sample();
|
||||
}
|
||||
|
@ -443,7 +443,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
if (sigma <= 0) {
|
||||
throw new NotStrictlyPositiveException(LocalizedFormats.STANDARD_DEVIATION, sigma);
|
||||
}
|
||||
return sigma * getRan().nextGaussian() + mu;
|
||||
return sigma * getRandomGenerator().nextGaussian() + mu;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -458,7 +458,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* </p>
|
||||
*/
|
||||
public double nextExponential(double mean) throws NotStrictlyPositiveException {
|
||||
return new ExponentialDistribution(getRan(), mean,
|
||||
return new ExponentialDistribution(getRandomGenerator(), mean,
|
||||
ExponentialDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -485,7 +485,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* {@code scale <= 0}.
|
||||
*/
|
||||
public double nextGamma(double shape, double scale) throws NotStrictlyPositiveException {
|
||||
return new GammaDistribution(getRan(),shape, scale,
|
||||
return new GammaDistribution(getRandomGenerator(),shape, scale,
|
||||
GammaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -502,7 +502,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @throws NotPositiveException if {@code numberOfSuccesses < 0}.
|
||||
*/
|
||||
public int nextHypergeometric(int populationSize, int numberOfSuccesses, int sampleSize) throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
|
||||
return new HypergeometricDistribution(getRan(),populationSize,
|
||||
return new HypergeometricDistribution(getRandomGenerator(),populationSize,
|
||||
numberOfSuccesses, sampleSize).sample();
|
||||
}
|
||||
|
||||
|
@ -517,7 +517,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* range {@code [0, 1]}.
|
||||
*/
|
||||
public int nextPascal(int r, double p) throws NotStrictlyPositiveException, OutOfRangeException {
|
||||
return new PascalDistribution(getRan(), r, p).sample();
|
||||
return new PascalDistribution(getRandomGenerator(), r, p).sample();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -528,7 +528,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @throws NotStrictlyPositiveException if {@code df <= 0}
|
||||
*/
|
||||
public double nextT(double df) throws NotStrictlyPositiveException {
|
||||
return new TDistribution(getRan(), df,
|
||||
return new TDistribution(getRandomGenerator(), df,
|
||||
TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -542,7 +542,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* {@code scale <= 0}.
|
||||
*/
|
||||
public double nextWeibull(double shape, double scale) throws NotStrictlyPositiveException {
|
||||
return new WeibullDistribution(getRan(), shape, scale,
|
||||
return new WeibullDistribution(getRandomGenerator(), shape, scale,
|
||||
WeibullDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -556,7 +556,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* or {@code exponent <= 0}.
|
||||
*/
|
||||
public int nextZipf(int numberOfElements, double exponent) throws NotStrictlyPositiveException {
|
||||
return new ZipfDistribution(getRan(), numberOfElements, exponent).sample();
|
||||
return new ZipfDistribution(getRandomGenerator(), numberOfElements, exponent).sample();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -567,7 +567,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @return random value sampled from the beta(alpha, beta) distribution
|
||||
*/
|
||||
public double nextBeta(double alpha, double beta) {
|
||||
return new BetaDistribution(getRan(), alpha, beta,
|
||||
return new BetaDistribution(getRandomGenerator(), alpha, beta,
|
||||
BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -579,7 +579,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @return random value sampled from the Binomial(numberOfTrials, probabilityOfSuccess) distribution
|
||||
*/
|
||||
public int nextBinomial(int numberOfTrials, double probabilityOfSuccess) {
|
||||
return new BinomialDistribution(getRan(), numberOfTrials, probabilityOfSuccess).sample();
|
||||
return new BinomialDistribution(getRandomGenerator(), numberOfTrials, probabilityOfSuccess).sample();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -590,7 +590,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @return random value sampled from the Cauchy(median, scale) distribution
|
||||
*/
|
||||
public double nextCauchy(double median, double scale) {
|
||||
return new CauchyDistribution(getRan(), median, scale,
|
||||
return new CauchyDistribution(getRandomGenerator(), median, scale,
|
||||
CauchyDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -601,7 +601,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @return random value sampled from the ChiSquare(df) distribution
|
||||
*/
|
||||
public double nextChiSquare(double df) {
|
||||
return new ChiSquaredDistribution(getRan(), df,
|
||||
return new ChiSquaredDistribution(getRandomGenerator(), df,
|
||||
ChiSquaredDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -615,7 +615,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* {@code numeratorDf <= 0} or {@code denominatorDf <= 0}.
|
||||
*/
|
||||
public double nextF(double numeratorDf, double denominatorDf) throws NotStrictlyPositiveException {
|
||||
return new FDistribution(getRan(), numeratorDf, denominatorDf,
|
||||
return new FDistribution(getRandomGenerator(), numeratorDf, denominatorDf,
|
||||
FDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample();
|
||||
}
|
||||
|
||||
|
@ -671,7 +671,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
throw new NotANumberException();
|
||||
}
|
||||
|
||||
final RandomGenerator generator = getRan();
|
||||
final RandomGenerator generator = getRandomGenerator();
|
||||
|
||||
// ensure nextDouble() isn't 0.0
|
||||
double u = generator.nextDouble();
|
||||
|
@ -758,7 +758,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* @param seed the seed value to use
|
||||
*/
|
||||
public void reSeed(long seed) {
|
||||
getRan().setSeed(seed);
|
||||
getRandomGenerator().setSeed(seed);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -789,7 +789,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
* {@code System.currentTimeMillis() + System.identityHashCode(this)}.
|
||||
*/
|
||||
public void reSeed() {
|
||||
getRan().setSeed(System.currentTimeMillis() + System.identityHashCode(this));
|
||||
getRandomGenerator().setSeed(System.currentTimeMillis() + System.identityHashCode(this));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -823,7 +823,7 @@ public class RandomDataGenerator implements RandomData, Serializable {
|
|||
*
|
||||
* @return the Random used to generate random data
|
||||
*/
|
||||
private RandomGenerator getRan() {
|
||||
public RandomGenerator getRandomGenerator() {
|
||||
if (rand == null) {
|
||||
initRan();
|
||||
}
|
||||
|
|
|
@ -22,15 +22,19 @@ import java.io.IOException;
|
|||
import java.io.InputStreamReader;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.apache.commons.math3.TestUtils;
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator;
|
||||
import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator;
|
||||
import org.apache.commons.math3.distribution.AbstractRealDistribution;
|
||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.distribution.RealDistributionAbstractTest;
|
||||
import org.apache.commons.math3.distribution.UniformRealDistribution;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
|
@ -428,4 +432,152 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
|
|||
return new NormalDistribution((upper + lower + 1) / 2d, 3.0276503540974917);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKernelOverrideConstant() {
|
||||
final EmpiricalDistribution dist = new ConstantKernelEmpiricalDistribution(5);
|
||||
final double[] data = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d};
|
||||
dist.load(data);
|
||||
// Bin masses concentrated on 2, 5, 8, 11, 14 <- effectively discrete uniform distribution over these
|
||||
double[] values = {2d, 5d, 8d, 11d, 14d};
|
||||
for (int i = 0; i < 20; i++) {
|
||||
Assert.assertTrue(Arrays.binarySearch(values, dist.sample()) >= 0);
|
||||
}
|
||||
final double tol = 10E-12;
|
||||
Assert.assertEquals(0.0, dist.cumulativeProbability(1), tol);
|
||||
Assert.assertEquals(0.2, dist.cumulativeProbability(2), tol);
|
||||
Assert.assertEquals(0.6, dist.cumulativeProbability(10), tol);
|
||||
Assert.assertEquals(0.8, dist.cumulativeProbability(12), tol);
|
||||
Assert.assertEquals(0.8, dist.cumulativeProbability(13), tol);
|
||||
Assert.assertEquals(1.0, dist.cumulativeProbability(15), tol);
|
||||
|
||||
Assert.assertEquals(2.0, dist.inverseCumulativeProbability(0.1), tol);
|
||||
Assert.assertEquals(2.0, dist.inverseCumulativeProbability(0.2), tol);
|
||||
Assert.assertEquals(5.0, dist.inverseCumulativeProbability(0.3), tol);
|
||||
Assert.assertEquals(5.0, dist.inverseCumulativeProbability(0.4), tol);
|
||||
Assert.assertEquals(8.0, dist.inverseCumulativeProbability(0.5), tol);
|
||||
Assert.assertEquals(8.0, dist.inverseCumulativeProbability(0.6), tol);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKernelOverrideUniform() {
|
||||
final EmpiricalDistribution dist = new UniformKernelEmpiricalDistribution(5);
|
||||
final double[] data = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d};
|
||||
dist.load(data);
|
||||
// Kernels are uniform distributions on [1,3], [4,6], [7,9], [10,12], [13,15]
|
||||
final double bounds[] = {3d, 6d, 9d, 12d};
|
||||
final double tol = 10E-12;
|
||||
for (int i = 0; i < 20; i++) {
|
||||
final double v = dist.sample();
|
||||
// Make sure v is not in the excluded range between bins - that is (bounds[i], bounds[i] + 1)
|
||||
for (int j = 0; j < bounds.length; j++) {
|
||||
Assert.assertFalse(v > bounds[j] + tol && v < bounds[j] + 1 - tol);
|
||||
}
|
||||
}
|
||||
Assert.assertEquals(0.0, dist.cumulativeProbability(1), tol);
|
||||
Assert.assertEquals(0.1, dist.cumulativeProbability(2), tol);
|
||||
Assert.assertEquals(0.6, dist.cumulativeProbability(10), tol);
|
||||
Assert.assertEquals(0.8, dist.cumulativeProbability(12), tol);
|
||||
Assert.assertEquals(0.8, dist.cumulativeProbability(13), tol);
|
||||
Assert.assertEquals(1.0, dist.cumulativeProbability(15), tol);
|
||||
|
||||
Assert.assertEquals(2.0, dist.inverseCumulativeProbability(0.1), tol);
|
||||
Assert.assertEquals(3.0, dist.inverseCumulativeProbability(0.2), tol);
|
||||
Assert.assertEquals(5.0, dist.inverseCumulativeProbability(0.3), tol);
|
||||
Assert.assertEquals(6.0, dist.inverseCumulativeProbability(0.4), tol);
|
||||
Assert.assertEquals(8.0, dist.inverseCumulativeProbability(0.5), tol);
|
||||
Assert.assertEquals(9.0, dist.inverseCumulativeProbability(0.6), tol);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Empirical distribution using a constant smoothing kernel.
|
||||
*/
|
||||
private class ConstantKernelEmpiricalDistribution extends EmpiricalDistribution {
|
||||
private static final long serialVersionUID = 1L;
|
||||
public ConstantKernelEmpiricalDistribution(int i) {
|
||||
super(i);
|
||||
}
|
||||
// Use constant distribution equal to bin mean within bin
|
||||
protected RealDistribution getKernel(SummaryStatistics bStats) {
|
||||
return new ConstantDistribution(bStats.getMean());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Empirical distribution using a uniform smoothing kernel.
|
||||
*/
|
||||
private class UniformKernelEmpiricalDistribution extends EmpiricalDistribution {
|
||||
public UniformKernelEmpiricalDistribution(int i) {
|
||||
super(i);
|
||||
}
|
||||
protected RealDistribution getKernel(SummaryStatistics bStats) {
|
||||
return new UniformRealDistribution(randomData.getRandomGenerator(), bStats.getMin(), bStats.getMax(),
|
||||
UniformRealDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Distribution that takes just one value.
|
||||
*/
|
||||
private class ConstantDistribution extends AbstractRealDistribution {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/** Singleton value in the sample space */
|
||||
private final double c;
|
||||
|
||||
public ConstantDistribution(double c) {
|
||||
this.c = c;
|
||||
}
|
||||
|
||||
public double density(double x) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public double cumulativeProbability(double x) {
|
||||
return x < c ? 0 : 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double inverseCumulativeProbability(double p) {
|
||||
if (p < 0.0 || p > 1.0) {
|
||||
throw new OutOfRangeException(p, 0, 1);
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
public double getNumericalMean() {
|
||||
return c;
|
||||
}
|
||||
|
||||
public double getNumericalVariance() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public double getSupportLowerBound() {
|
||||
return c;
|
||||
}
|
||||
|
||||
public double getSupportUpperBound() {
|
||||
return c;
|
||||
}
|
||||
|
||||
public boolean isSupportLowerBoundInclusive() {
|
||||
return false;
|
||||
}
|
||||
|
||||
public boolean isSupportUpperBoundInclusive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double sample() {
|
||||
return c;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue