MATH-1158.

Sampler functionality defined in "EnumeratedDistribution".
Method "createSampler" overridden in "EnumeratedRealDistribution".
This commit is contained in:
Gilles 2016-03-11 04:48:18 +01:00
parent a6eda3d8ef
commit a5035d0e1c
3 changed files with 140 additions and 2 deletions

View File

@ -31,6 +31,7 @@ import org.apache.commons.math4.exception.NullArgumentException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.random.Well19937c;
import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.util.MathArrays;
import org.apache.commons.math4.util.Pair;
@ -59,6 +60,7 @@ public class EnumeratedDistribution<T> implements Serializable {
/**
* RNG instance used to generate samples from the distribution.
* <p>
* The deprecated {@code sample()} method of this class draws from this
* generator.
*
* @deprecated Samplers obtained from
* {@link #createSampler(UniformRandomProvider)} hold their own generator
* and do not read this field.
*/
@Deprecated
protected final RandomGenerator random;
/**
@ -113,6 +115,7 @@ public class EnumeratedDistribution<T> implements Serializable {
* @throws NotANumberException if any of the probabilities are NaN.
* @throws MathArithmeticException all of the probabilities are 0.
*/
@Deprecated
public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf)
throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
random = rng;
@ -151,6 +154,7 @@ public class EnumeratedDistribution<T> implements Serializable {
*
* @param seed the new seed
*/
@Deprecated
public void reseedRandomGenerator(long seed) {
    // Only the generator stored in this distribution is reseeded here.
    this.random.setSeed(seed);
}
@ -205,6 +209,7 @@ public class EnumeratedDistribution<T> implements Serializable {
*
* @return a random value.
*/
@Deprecated
public T sample() {
final double randomValue = random.nextDouble();
@ -233,6 +238,7 @@ public class EnumeratedDistribution<T> implements Serializable {
* @throws NotStrictlyPositiveException if {@code sampleSize} is not
* positive.
*/
@Deprecated
public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
if (sampleSize <= 0) {
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
@ -262,6 +268,7 @@ public class EnumeratedDistribution<T> implements Serializable {
* @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
* @throws NullArgumentException if {@code array} is null
*/
@Deprecated
public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
if (sampleSize <= 0) {
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
@ -288,4 +295,113 @@ public class EnumeratedDistribution<T> implements Serializable {
}
/**
 * Creates a {@link Sampler} that draws values from this distribution.
 *
 * @param rng Random number generator the new sampler will use.
 * @return a new sampler instance wrapping {@code rng}.
 */
public Sampler createSampler(final UniformRandomProvider rng) {
    final Sampler sampler = new Sampler(rng);
    return sampler;
}
/**
 * Sampler functionality.
 * <p>
 * A sampler draws values of the enclosing distribution using the random
 * number generator it was constructed with.
 */
public class Sampler {
    /** RNG. */
    private final UniformRandomProvider random;

    /**
     * @param rng Random number generator.
     */
    Sampler(UniformRandomProvider rng) {
        random = rng;
    }

    /**
     * Generates a random value sampled from this distribution.
     * <p>
     * A uniform draw is located among the cumulative probabilities and the
     * value of the first bin whose cumulative bound is strictly larger than
     * the draw is returned.
     *
     * @return a random value.
     */
    public T sample() {
        final double randomValue = random.nextDouble();

        int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
        if (index < 0) {
            // Key not found: decode the insertion point, i.e. the first
            // entry that is larger than the drawn value.
            index = -index - 1;
        }

        // When the draw is exactly equal to a cumulative probability, the
        // search reports a match (an unspecified one if several entries are
        // equal). The draw then belongs to the next bin with a strictly
        // larger bound, so skip every entry that does not exceed it.
        while (index < probabilities.length &&
               cumulativeProbabilities[index] <= randomValue) {
            ++index;
        }

        if (index < probabilities.length) {
            return singletons.get(index);
        }

        // Reached only if the draw is not smaller than the last cumulative
        // probability (floating point inequality problem): fall back to the
        // last value so that a valid object is always returned.
        return singletons.get(singletons.size() - 1);
    }

    /**
     * Generates a random sample from the distribution.
     *
     * @param sampleSize the number of random values to generate.
     * @return an array representing the random sample.
     * @throws NotStrictlyPositiveException if {@code sampleSize} is not
     * positive.
     */
    public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
        if (sampleSize <= 0) {
            throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
                                                   sampleSize);
        }

        final Object[] out = new Object[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            out[i] = sample();
        }

        return out;
    }

    /**
     * Generates a random sample from the distribution.
     * <p>
     * If the requested samples fit in the specified array, it is returned
     * therein (only its first {@code sampleSize} entries are overwritten).
     * Otherwise, a new array is allocated with the runtime type of the
     * specified array and a length of {@code sampleSize}.
     *
     * @param sampleSize the number of random values to generate.
     * @param array the array to populate.
     * @return an array representing the random sample.
     * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
     * @throws NullArgumentException if {@code array} is null
     */
    public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
        if (sampleSize <= 0) {
            throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
        }

        if (array == null) {
            throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
        }

        T[] out;
        if (array.length < sampleSize) {
            @SuppressWarnings("unchecked") // safe as both are of type T
            final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
            out = unchecked;
        } else {
            out = array;
        }

        for (int i = 0; i < sampleSize; i++) {
            out[i] = sample();
        }

        return out;
    }
}
}

View File

@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotPositiveException;
import org.apache.commons.math4.exception.OutOfRangeException;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.random.Well19937c;
import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.util.Pair;
/**
@ -93,6 +94,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
* @throws NotANumberException if any of the probabilities are NaN.
* @throws MathArithmeticException all of the probabilities are 0.
*/
@Deprecated
public EnumeratedRealDistribution(final RandomGenerator rng,
final double[] singletons, final double[] probabilities)
throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
@ -111,6 +113,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
* @param data input dataset
* @since 3.6
*/
@Deprecated
public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) {
super(rng);
final Map<Double, Integer> dataMap = new HashMap<Double, Integer>();
@ -319,7 +322,24 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
* {@inheritDoc}
*/
@Override
@Deprecated
public double sample() {
    // The delegate yields a boxed Double; it is unboxed on assignment.
    final double drawn = innerDistribution.sample();
    return drawn;
}
/** {@inheritDoc} */
@Override
public RealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
    // One delegate is created per call; the returned sampler forwards
    // every request to it.
    final EnumeratedDistribution<Double>.Sampler delegate =
        innerDistribution.createSampler(rng);

    return new RealDistribution.Sampler() {
        /** {@inheritDoc} */
        @Override
        public double sample() {
            return delegate.sample();
        }
    };
}
}

View File

@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotFiniteNumberException;
import org.apache.commons.math4.exception.NotPositiveException;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.Pair;
import org.apache.commons.math4.rng.RandomSource;
import org.junit.Assert;
import org.junit.Test;
@ -175,8 +176,9 @@ public class EnumeratedRealDistributionTest {
@Test
public void testSample() {
final int n = 1000000;
testDistribution.reseedRandomGenerator(-334759360); // fixed seed
final double[] samples = testDistribution.sample(n);
final RealDistribution.Sampler sampler =
testDistribution.createSampler(RandomSource.create(RandomSource.WELL_1024_A, -123456789));
final double[] samples = AbstractRealDistribution.sample(n, sampler);
Assert.assertEquals(n, samples.length);
double sum = 0;
double sumOfSquares = 0;