[MATH-1152] Improved performance of EnumeratedDistribution#sample(). Thanks to Andras Sereny.
This commit is contained in:
parent
97d32b14e6
commit
97accb47de
|
@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not
|
|||
2. A few methods in the FastMath class are in fact slower that their
|
||||
counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901).
|
||||
">
|
||||
<action dev="tn" type="fix" issue="MATH-1152" due-to="Andras Sereny">
|
||||
Improved performance of "EnumeratedDistribution#sample()" by caching
|
||||
the cumulative probabilities and using binary rather than a linear search.
|
||||
</action>
|
||||
<action dev="tn" type="fix" issue="MATH-1148" due-to="Guillaume Marceau">
|
||||
"MonotoneChain" did not take the tolerance factor into account when
|
||||
sorting the input points. In case of collinear points this could result
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.commons.math3.distribution;
|
|||
import java.io.Serializable;
|
||||
import java.lang.reflect.Array;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||
|
@ -64,6 +65,7 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
* List of random variable values.
|
||||
*/
|
||||
private final List<T> singletons;
|
||||
|
||||
/**
|
||||
* Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
|
||||
* probability[i] is the probability that a random variable following this distribution takes
|
||||
|
@ -71,6 +73,11 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
*/
|
||||
private final double[] probabilities;
|
||||
|
||||
/**
|
||||
* Cumulative probabilities, cached to speed up sampling.
|
||||
*/
|
||||
private final double[] cumulativeProbabilities;
|
||||
|
||||
/**
|
||||
* Create an enumerated distribution using the given probability mass function
|
||||
* enumeration.
|
||||
|
@ -123,6 +130,13 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
}
|
||||
|
||||
probabilities = MathArrays.normalizeArray(probs, 1.0);
|
||||
|
||||
cumulativeProbabilities = new double[probabilities.length];
|
||||
double sum = 0;
|
||||
for (int i = 0; i < probabilities.length; i++) {
|
||||
sum += probabilities[i];
|
||||
cumulativeProbabilities[i] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -186,18 +200,21 @@ public class EnumeratedDistribution<T> implements Serializable {
|
|||
*/
|
||||
public T sample() {
|
||||
final double randomValue = random.nextDouble();
|
||||
double sum = 0;
|
||||
|
||||
for (int i = 0; i < probabilities.length; i++) {
|
||||
sum += probabilities[i];
|
||||
if (randomValue < sum) {
|
||||
return singletons.get(i);
|
||||
int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
|
||||
if (index < 0) {
|
||||
index = -index-1;
|
||||
}
|
||||
|
||||
if (index >= 0 && index < probabilities.length) {
|
||||
if (randomValue < cumulativeProbabilities[index]) {
|
||||
return singletons.get(index);
|
||||
}
|
||||
}
|
||||
|
||||
/* This should never happen, but it ensures we will return a correct
|
||||
* object in case the loop above has some floating point inequality
|
||||
* problem on the final iteration. */
|
||||
* object in case there is some floating point inequality problem
|
||||
* wrt the cumulative probabilities. */
|
||||
return singletons.get(singletons.size() - 1);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue