MATH-1158.
Sampler functionality defined in "EnumeratedDistribution". Method "createSampler" overridden in "EnumeratedRealDistribution".
This commit is contained in:
parent
a6eda3d8ef
commit
a5035d0e1c
|
@ -31,6 +31,7 @@ import org.apache.commons.math4.exception.NullArgumentException;
|
|||
import org.apache.commons.math4.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math4.random.RandomGenerator;
|
||||
import org.apache.commons.math4.random.Well19937c;
|
||||
import org.apache.commons.math4.rng.UniformRandomProvider;
|
||||
import org.apache.commons.math4.util.MathArrays;
|
||||
import org.apache.commons.math4.util.Pair;
|
||||
|
||||
|
@ -59,6 +60,7 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
/**
|
||||
* RNG instance used to generate samples from the distribution.
|
||||
*/
|
||||
@Deprecated
|
||||
protected final RandomGenerator random;
|
||||
|
||||
/**
|
||||
|
@ -113,6 +115,7 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
* @throws NotANumberException if any of the probabilities are NaN.
|
||||
* @throws MathArithmeticException all of the probabilities are 0.
|
||||
*/
|
||||
@Deprecated
|
||||
public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf)
|
||||
throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
|
||||
random = rng;
|
||||
|
@ -151,6 +154,7 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
*
|
||||
* @param seed the new seed
|
||||
*/
|
||||
@Deprecated
|
||||
public void reseedRandomGenerator(long seed) {
|
||||
random.setSeed(seed);
|
||||
}
|
||||
|
@ -205,6 +209,7 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
*
|
||||
* @return a random value.
|
||||
*/
|
||||
@Deprecated
|
||||
public T sample() {
|
||||
final double randomValue = random.nextDouble();
|
||||
|
||||
|
@ -233,6 +238,7 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
* @throws NotStrictlyPositiveException if {@code sampleSize} is not
|
||||
* positive.
|
||||
*/
|
||||
@Deprecated
|
||||
public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
|
||||
if (sampleSize <= 0) {
|
||||
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
|
||||
|
@ -262,6 +268,7 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
* @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
|
||||
* @throws NullArgumentException if {@code array} is null
|
||||
*/
|
||||
@Deprecated
|
||||
public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
|
||||
if (sampleSize <= 0) {
|
||||
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
|
||||
|
@ -288,4 +295,113 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link Sampler}.
|
||||
*
|
||||
* @param rng Random number generator.
|
||||
*/
|
||||
public Sampler createSampler(final UniformRandomProvider rng) {
|
||||
return new Sampler(rng);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sampler functionality.
|
||||
*/
|
||||
public class Sampler {
|
||||
/** RNG. */
|
||||
private final UniformRandomProvider random;
|
||||
|
||||
/**
|
||||
* @param rng Random number generator.
|
||||
*/
|
||||
Sampler(UniformRandomProvider rng) {
|
||||
random = rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a random value sampled from this distribution.
|
||||
*
|
||||
* @return a random value.
|
||||
*/
|
||||
public T sample() {
|
||||
final double randomValue = random.nextDouble();
|
||||
|
||||
int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
|
||||
if (index < 0) {
|
||||
index = -index - 1;
|
||||
}
|
||||
|
||||
if (index >= 0 &&
|
||||
index < probabilities.length &&
|
||||
randomValue < cumulativeProbabilities[index]) {
|
||||
return singletons.get(index);
|
||||
}
|
||||
|
||||
// This should never happen, but it ensures we will return a correct
|
||||
// object in case there is some floating point inequality problem
|
||||
// wrt the cumulative probabilities.
|
||||
return singletons.get(singletons.size() - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a random sample from the distribution.
|
||||
*
|
||||
* @param sampleSize the number of random values to generate.
|
||||
* @return an array representing the random sample.
|
||||
* @throws NotStrictlyPositiveException if {@code sampleSize} is not
|
||||
* positive.
|
||||
*/
|
||||
public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
|
||||
if (sampleSize <= 0) {
|
||||
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
|
||||
sampleSize);
|
||||
}
|
||||
|
||||
final Object[] out = new Object[sampleSize];
|
||||
|
||||
for (int i = 0; i < sampleSize; i++) {
|
||||
out[i] = sample();
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a random sample from the distribution.
|
||||
* <p>
|
||||
* If the requested samples fit in the specified array, it is returned
|
||||
* therein. Otherwise, a new array is allocated with the runtime type of
|
||||
* the specified array and the size of this collection.
|
||||
*
|
||||
* @param sampleSize the number of random values to generate.
|
||||
* @param array the array to populate.
|
||||
* @return an array representing the random sample.
|
||||
* @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
|
||||
* @throws NullArgumentException if {@code array} is null
|
||||
*/
|
||||
public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
|
||||
if (sampleSize <= 0) {
|
||||
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
|
||||
}
|
||||
|
||||
if (array == null) {
|
||||
throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
|
||||
}
|
||||
|
||||
T[] out;
|
||||
if (array.length < sampleSize) {
|
||||
@SuppressWarnings("unchecked") // safe as both are of type T
|
||||
final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
|
||||
out = unchecked;
|
||||
} else {
|
||||
out = array;
|
||||
}
|
||||
|
||||
for (int i = 0; i < sampleSize; i++) {
|
||||
out[i] = sample();
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotPositiveException;
|
|||
import org.apache.commons.math4.exception.OutOfRangeException;
|
||||
import org.apache.commons.math4.random.RandomGenerator;
|
||||
import org.apache.commons.math4.random.Well19937c;
|
||||
import org.apache.commons.math4.rng.UniformRandomProvider;
|
||||
import org.apache.commons.math4.util.Pair;
|
||||
|
||||
/**
|
||||
|
@ -93,6 +94,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
|
|||
* @throws NotANumberException if any of the probabilities are NaN.
|
||||
* @throws MathArithmeticException all of the probabilities are 0.
|
||||
*/
|
||||
@Deprecated
|
||||
public EnumeratedRealDistribution(final RandomGenerator rng,
|
||||
final double[] singletons, final double[] probabilities)
|
||||
throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
|
||||
|
@ -111,6 +113,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
|
|||
* @param data input dataset
|
||||
* @since 3.6
|
||||
*/
|
||||
@Deprecated
|
||||
public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) {
|
||||
super(rng);
|
||||
final Map<Double, Integer> dataMap = new HashMap<Double, Integer>();
|
||||
|
@ -319,7 +322,24 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
|
|||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public double sample() {
|
||||
return innerDistribution.sample();
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public RealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
|
||||
return new RealDistribution.Sampler() {
|
||||
/** Delegate. */
|
||||
private final EnumeratedDistribution<Double>.Sampler inner =
|
||||
innerDistribution.createSampler(rng);
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public double sample() {
|
||||
return inner.sample();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotFiniteNumberException;
|
|||
import org.apache.commons.math4.exception.NotPositiveException;
|
||||
import org.apache.commons.math4.util.FastMath;
|
||||
import org.apache.commons.math4.util.Pair;
|
||||
import org.apache.commons.math4.rng.RandomSource;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -175,8 +176,9 @@ public class EnumeratedRealDistributionTest {
|
|||
@Test
|
||||
public void testSample() {
|
||||
final int n = 1000000;
|
||||
testDistribution.reseedRandomGenerator(-334759360); // fixed seed
|
||||
final double[] samples = testDistribution.sample(n);
|
||||
final RealDistribution.Sampler sampler =
|
||||
testDistribution.createSampler(RandomSource.create(RandomSource.WELL_1024_A, -123456789));
|
||||
final double[] samples = AbstractRealDistribution.sample(n, sampler);
|
||||
Assert.assertEquals(n, samples.length);
|
||||
double sum = 0;
|
||||
double sumOfSquares = 0;
|
||||
|
|
Loading…
Reference in New Issue