[MATH-1152] Improved performance of EnumeratedDistribution#sample(). Thanks to Andras Sereny.

This commit is contained in:
Thomas Neidhart 2014-09-30 21:16:07 +02:00
parent 97d32b14e6
commit 97accb47de
2 changed files with 28 additions and 7 deletions

View File

@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not
2. A few methods in the FastMath class are in fact slower that their 2. A few methods in the FastMath class are in fact slower that their
counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901). counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901).
"> ">
<action dev="tn" type="fix" issue="MATH-1152" due-to="Andras Sereny">
Improved performance of "EnumeratedDistribution#sample()" by caching
the cumulative probabilities and using binary rather than a linear search.
</action>
<action dev="tn" type="fix" issue="MATH-1148" due-to="Guillaume Marceau"> <action dev="tn" type="fix" issue="MATH-1148" due-to="Guillaume Marceau">
"MonotoneChain" did not take the tolerance factor into account when "MonotoneChain" did not take the tolerance factor into account when
sorting the input points. In case of collinear points this could result sorting the input points. In case of collinear points this could result

View File

@ -19,6 +19,7 @@ package org.apache.commons.math3.distribution;
import java.io.Serializable; import java.io.Serializable;
import java.lang.reflect.Array; import java.lang.reflect.Array;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import org.apache.commons.math3.exception.MathArithmeticException; import org.apache.commons.math3.exception.MathArithmeticException;
@ -64,6 +65,7 @@ public class EnumeratedDistribution<T> implements Serializable {
* List of random variable values. * List of random variable values.
*/ */
private final List<T> singletons; private final List<T> singletons;
/** /**
* Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1, * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
* probability[i] is the probability that a random variable following this distribution takes * probability[i] is the probability that a random variable following this distribution takes
@ -71,6 +73,11 @@ public class EnumeratedDistribution<T> implements Serializable {
*/ */
private final double[] probabilities; private final double[] probabilities;
/**
* Cumulative probabilities, cached to speed up sampling.
*/
private final double[] cumulativeProbabilities;
/** /**
* Create an enumerated distribution using the given probability mass function * Create an enumerated distribution using the given probability mass function
* enumeration. * enumeration.
@ -123,6 +130,13 @@ public class EnumeratedDistribution<T> implements Serializable {
} }
probabilities = MathArrays.normalizeArray(probs, 1.0); probabilities = MathArrays.normalizeArray(probs, 1.0);
cumulativeProbabilities = new double[probabilities.length];
double sum = 0;
for (int i = 0; i < probabilities.length; i++) {
sum += probabilities[i];
cumulativeProbabilities[i] = sum;
}
} }
/** /**
@ -186,18 +200,21 @@ public class EnumeratedDistribution<T> implements Serializable {
*/ */
public T sample() { public T sample() {
final double randomValue = random.nextDouble(); final double randomValue = random.nextDouble();
double sum = 0;
for (int i = 0; i < probabilities.length; i++) { int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
sum += probabilities[i]; if (index < 0) {
if (randomValue < sum) { index = -index-1;
return singletons.get(i); }
if (index >= 0 && index < probabilities.length) {
if (randomValue < cumulativeProbabilities[index]) {
return singletons.get(index);
} }
} }
/* This should never happen, but it ensures we will return a correct /* This should never happen, but it ensures we will return a correct
* object in case the loop above has some floating point inequality * object in case there is some floating point inequality problem
* problem on the final iteration. */ * wrt the cumulative probabilities. */
return singletons.get(singletons.size() - 1); return singletons.get(singletons.size() - 1);
} }