Use code moved to "Commons RNG" (v1.1 snapshot).

This commit is contained in:
Gilles 2018-01-25 17:24:55 +01:00
parent c4218b8385
commit c3ff46e303
1 changed files with 5 additions and 20 deletions

View File

@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.exception.NullArgumentException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler;
import org.apache.commons.math4.util.MathArrays;
import org.apache.commons.math4.util.Pair;
@ -172,14 +173,14 @@ public class EnumeratedDistribution<T> implements Serializable {
* Sampler functionality.
*/
public class Sampler {
/** RNG. */
private final UniformRandomProvider random;
/** Underlying sampler. */
private final DiscreteProbabilityCollectionSampler<T> sampler;
/**
* @param rng Random number generator.
*/
Sampler(UniformRandomProvider rng) {
random = rng;
sampler = new DiscreteProbabilityCollectionSampler<T>(rng, singletons, probabilities);
}
/**
@ -188,23 +189,7 @@ public class EnumeratedDistribution<T> implements Serializable {
* @return a random value.
*/
public T sample() {
final double randomValue = random.nextDouble();
int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
if (index < 0) {
index = -index - 1;
}
if (index >= 0 &&
index < probabilities.length &&
randomValue < cumulativeProbabilities[index]) {
return singletons.get(index);
}
// This should never happen, but it ensures we will return a correct
// object in case there is some floating point inequality problem
// wrt the cumulative probabilities.
return singletons.get(singletons.size() - 1);
return sampler.sample();
}
/**