From 97accb47de63ee5063eda23641c6017e29ab81d7 Mon Sep 17 00:00:00 2001 From: Thomas Neidhart Date: Tue, 30 Sep 2014 21:16:07 +0200 Subject: [PATCH] [MATH-1152] Improved performance of EnumeratedDistribution#sample(). Thanks to Andras Sereny. --- src/changes/changes.xml | 4 +++ .../distribution/EnumeratedDistribution.java | 31 ++++++++++++++----- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/changes/changes.xml b/src/changes/changes.xml index c64fb5879..fdb2bd41f 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not 2. A few methods in the FastMath class are in fact slower that their counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901). "> + + Improved performance of "EnumeratedDistribution#sample()" by caching + the cumulative probabilities and using a binary rather than a linear search. + "MonotoneChain" did not take the tolerance factor into account when sorting the input points. In case of collinear points this could result diff --git a/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java index 5117c2af5..e95098c2f 100644 --- a/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java +++ b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java @@ -19,6 +19,7 @@ package org.apache.commons.math3.distribution; import java.io.Serializable; import java.lang.reflect.Array; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.commons.math3.exception.MathArithmeticException; @@ -64,6 +65,7 @@ public class EnumeratedDistribution implements Serializable { * List of random variable values. */ private final List singletons; + /** * Probabilities of respective random variable values. 
For i = 0, ..., singletons.size() - 1, * probability[i] is the probability that a random variable following this distribution takes @@ -71,6 +73,11 @@ public class EnumeratedDistribution implements Serializable { */ private final double[] probabilities; + /** + * Cumulative probabilities, cached to speed up sampling. + */ + private final double[] cumulativeProbabilities; + /** * Create an enumerated distribution using the given probability mass function * enumeration. @@ -123,6 +130,13 @@ public class EnumeratedDistribution implements Serializable { } probabilities = MathArrays.normalizeArray(probs, 1.0); + + cumulativeProbabilities = new double[probabilities.length]; + double sum = 0; + for (int i = 0; i < probabilities.length; i++) { + sum += probabilities[i]; + cumulativeProbabilities[i] = sum; + } } /** @@ -186,18 +200,21 @@ public class EnumeratedDistribution implements Serializable { */ public T sample() { final double randomValue = random.nextDouble(); - double sum = 0; - for (int i = 0; i < probabilities.length; i++) { - sum += probabilities[i]; - if (randomValue < sum) { - return singletons.get(i); + int index = Arrays.binarySearch(cumulativeProbabilities, randomValue); + if (index < 0) { + index = -index-1; + } + + if (index >= 0 && index < probabilities.length) { + if (randomValue < cumulativeProbabilities[index]) { + return singletons.get(index); } } /* This should never happen, but it ensures we will return a correct - * object in case the loop above has some floating point inequality - * problem on the final iteration. */ + * object in case there is some floating point inequality problem + * wrt the cumulative probabilities. */ return singletons.get(singletons.size() - 1); }