diff --git a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java index 2ca612e47..149f4035e 100644 --- a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java @@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotStrictlyPositiveException; import org.apache.commons.math4.exception.NullArgumentException; import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler; import org.apache.commons.math4.util.MathArrays; import org.apache.commons.math4.util.Pair; @@ -172,14 +173,14 @@ public class EnumeratedDistribution implements Serializable { * Sampler functionality. */ public class Sampler { - /** RNG. */ - private final UniformRandomProvider random; + /** Underlying sampler. */ + private final DiscreteProbabilityCollectionSampler sampler; /** * @param rng Random number generator. */ Sampler(UniformRandomProvider rng) { - random = rng; + sampler = new DiscreteProbabilityCollectionSampler(rng, singletons, probabilities); } /** @@ -188,23 +189,7 @@ public class EnumeratedDistribution implements Serializable { * @return a random value. */ public T sample() { - final double randomValue = random.nextDouble(); - - int index = Arrays.binarySearch(cumulativeProbabilities, randomValue); - if (index < 0) { - index = -index - 1; - } - - if (index >= 0 && - index < probabilities.length && - randomValue < cumulativeProbabilities[index]) { - return singletons.get(index); - } - - // This should never happen, but it ensures we will return a correct - // object in case there is some floating point inequality problem - // wrt the cumulative probabilities. - return singletons.get(singletons.size() - 1); + return sampler.sample(); } /**