From 1b238344cd2c311fc68d998043602320c5f11970 Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Sun, 10 Oct 2010 14:50:56 +0000 Subject: [PATCH] Improved Percentile performance by using a selection algorithm instead of a complete sort, and by allowing caching data array and pivots when several different percentiles are desired JIRA: MATH-417 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1006301 13f79535-47bb-0310-9956-ffa450edef68 --- findbugs-exclude-filter.xml | 5 + .../AbstractUnivariateStatistic.java | 60 ++++- .../stat/descriptive/moment/FirstMoment.java | 1 + .../descriptive/moment/GeometricMean.java | 1 + .../stat/descriptive/moment/Kurtosis.java | 1 + .../math/stat/descriptive/moment/Mean.java | 1 + .../stat/descriptive/moment/SemiVariance.java | 1 + .../stat/descriptive/moment/Skewness.java | 1 + .../descriptive/moment/StandardDeviation.java | 1 + .../stat/descriptive/moment/Variance.java | 1 + .../math/stat/descriptive/rank/Max.java | 1 + .../math/stat/descriptive/rank/Min.java | 1 + .../stat/descriptive/rank/Percentile.java | 235 +++++++++++++++++- .../stat/descriptive/summary/Product.java | 1 + .../math/stat/descriptive/summary/Sum.java | 1 + .../stat/descriptive/summary/SumOfLogs.java | 1 + .../descriptive/summary/SumOfSquares.java | 1 + src/site/xdoc/changes.xml | 5 + 18 files changed, 307 insertions(+), 12 deletions(-) diff --git a/findbugs-exclude-filter.xml b/findbugs-exclude-filter.xml index a4b7d57d4..0fe0ed059 100644 --- a/findbugs-exclude-filter.xml +++ b/findbugs-exclude-filter.xml @@ -92,6 +92,11 @@ + + + + + diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/AbstractUnivariateStatistic.java b/src/main/java/org/apache/commons/math/stat/descriptive/AbstractUnivariateStatistic.java index f8f62c01d..90a14b6b1 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/AbstractUnivariateStatistic.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/AbstractUnivariateStatistic.java @@ 
-17,10 +17,10 @@ package org.apache.commons.math.stat.descriptive; import org.apache.commons.math.MathRuntimeException; -import org.apache.commons.math.exception.util.LocalizedFormats; -import org.apache.commons.math.exception.NullArgumentException; -import org.apache.commons.math.exception.NotPositiveException; import org.apache.commons.math.exception.DimensionMismatchException; +import org.apache.commons.math.exception.NotPositiveException; +import org.apache.commons.math.exception.NullArgumentException; +import org.apache.commons.math.exception.util.LocalizedFormats; /** * Abstract base class for all implementations of the @@ -38,6 +38,60 @@ import org.apache.commons.math.exception.DimensionMismatchException; public abstract class AbstractUnivariateStatistic implements UnivariateStatistic { + /** Stored data. */ + private double[] storedData; + + /** + * Set the data array. + *

+ * The stored value is a copy of the parameter array, not the array itself. + *

+ * @param values data array to store (may be null to remove stored data) + * @see #evaluate() + */ + public void setData(final double[] values) { + storedData = (values == null) ? null : values.clone(); + } + + /** + * Get a copy of the stored data array. + * @return copy of the stored data array (may be null) + */ + public double[] getData() { + return (storedData == null) ? null : storedData.clone(); + } + + /** + * Get a reference to the stored data array. + * @return reference to the stored data array (may be null) + */ + protected double[] getDataRef() { + return storedData; + } + + /** + * Set the data array. + * @param values data array to store + * @param begin the index of the first element to include + * @param length the number of elements to include + * @see #evaluate() + */ + public void setData(final double[] values, final int begin, final int length) { + storedData = new double[length]; + System.arraycopy(values, begin, storedData, 0, length); + } + + /** + * Returns the result of evaluating the statistic over the stored data. + *

+ * The stored array is the one which was set by previous calls to one of the {@code setData} methods. + *

+ * @return the value of the statistic applied to the stored data + */ + public double evaluate() { + return evaluate(storedData); + } + /** * {@inheritDoc} */ diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/FirstMoment.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/FirstMoment.java index 195d0005b..4880f0530 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/FirstMoment.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/moment/FirstMoment.java @@ -151,6 +151,7 @@ public class FirstMoment extends AbstractStorelessUnivariateStatistic * @throws NullPointerException if either source or dest is null */ public static void copy(FirstMoment source, FirstMoment dest) { + dest.setData(source.getDataRef()); dest.n = source.n; dest.m1 = source.m1; dest.dev = source.dev; diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/GeometricMean.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/GeometricMean.java index fdeea6bb3..dec8c2679 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/GeometricMean.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/moment/GeometricMean.java @@ -186,6 +186,7 @@ public class GeometricMean extends AbstractStorelessUnivariateStatistic implemen * @throws NullPointerException if either source or dest is null */ public static void copy(GeometricMean source, GeometricMean dest) { + dest.setData(source.getDataRef()); dest.sumOfLogs = source.sumOfLogs.copy(); } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Kurtosis.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Kurtosis.java index a993aa8ed..adc130763 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Kurtosis.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Kurtosis.java @@ -214,6 +214,7 @@ public class Kurtosis extends 
AbstractStorelessUnivariateStatistic implements S * @throws NullPointerException if either source or dest is null */ public static void copy(Kurtosis source, Kurtosis dest) { + dest.setData(source.getDataRef()); dest.moment = source.moment.copy(); dest.incMoment = source.incMoment; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Mean.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Mean.java index 24d7d0a1d..880bc126c 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Mean.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Mean.java @@ -265,6 +265,7 @@ public class Mean extends AbstractStorelessUnivariateStatistic * @throws NullPointerException if either source or dest is null */ public static void copy(Mean source, Mean dest) { + dest.setData(source.getDataRef()); dest.incMoment = source.incMoment; dest.moment = source.moment.copy(); } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/SemiVariance.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/SemiVariance.java index 09494aa2b..3e140b97a 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/SemiVariance.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/moment/SemiVariance.java @@ -159,6 +159,7 @@ public class SemiVariance extends AbstractUnivariateStatistic implements Seriali * @throws NullPointerException if either source or dest is null */ public static void copy(final SemiVariance source, SemiVariance dest) { + dest.setData(source.getDataRef()); dest.biasCorrected = source.biasCorrected; dest.varianceDirection = source.varianceDirection; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Skewness.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Skewness.java index ac61ff0f1..192a0631b 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Skewness.java +++ 
b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Skewness.java @@ -206,6 +206,7 @@ public class Skewness extends AbstractStorelessUnivariateStatistic implements Se * @throws NullPointerException if either source or dest is null */ public static void copy(Skewness source, Skewness dest) { + dest.setData(source.getDataRef()); dest.moment = new ThirdMoment(source.moment.copy()); dest.incMoment = source.incMoment; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/StandardDeviation.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/StandardDeviation.java index 2a811425c..2c2ef8d99 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/StandardDeviation.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/moment/StandardDeviation.java @@ -264,6 +264,7 @@ public class StandardDeviation extends AbstractStorelessUnivariateStatistic * @throws NullPointerException if either source or dest is null */ public static void copy(StandardDeviation source, StandardDeviation dest) { + dest.setData(source.getDataRef()); dest.variance = source.variance.copy(); } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Variance.java b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Variance.java index 146623891..cabd8373a 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/moment/Variance.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/moment/Variance.java @@ -602,6 +602,7 @@ public class Variance extends AbstractStorelessUnivariateStatistic implements Se dest == null) { throw new NullArgumentException(); } + dest.setData(source.getDataRef()); dest.moment = source.moment.copy(); dest.isBiasCorrected = source.isBiasCorrected; dest.incMoment = source.incMoment; diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/rank/Max.java b/src/main/java/org/apache/commons/math/stat/descriptive/rank/Max.java index 
653b13c8e..ce8b9d6f9 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/rank/Max.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/rank/Max.java @@ -156,6 +156,7 @@ public class Max extends AbstractStorelessUnivariateStatistic implements Seriali * @throws NullPointerException if either source or dest is null */ public static void copy(Max source, Max dest) { + dest.setData(source.getDataRef()); dest.n = source.n; dest.value = source.value; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/rank/Min.java b/src/main/java/org/apache/commons/math/stat/descriptive/rank/Min.java index c01984bc4..4fe1b4dfc 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/rank/Min.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/rank/Min.java @@ -156,6 +156,7 @@ public class Min extends AbstractStorelessUnivariateStatistic implements Seriali * @throws NullPointerException if either source or dest is null */ public static void copy(Min source, Min dest) { + dest.setData(source.getDataRef()); dest.n = source.n; dest.value = source.value; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/rank/Percentile.java b/src/main/java/org/apache/commons/math/stat/descriptive/rank/Percentile.java index b85924b4a..4dd8d6d48 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/rank/Percentile.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/rank/Percentile.java @@ -47,8 +47,8 @@ import org.apache.commons.math.util.FastMath; * *

*

- * To compute percentiles, the data must be (totally) ordered. Input arrays - * are copied and then sorted using {@link java.util.Arrays#sort(double[])}. + * To compute percentiles, the data must be at least partially ordered. Input + * arrays are copied and recursively partitioned using an ordering definition. * The ordering used by Arrays.sort(double[]) is the one determined * by {@link java.lang.Double#compareTo(Double)}. This ordering makes * Double.NaN larger than any other value (including @@ -60,6 +60,18 @@ import org.apache.commons.math.util.FastMath; * elements, arrays containing NaN or infinite values will often * result in NaN or infinite values returned.

*

+ * Since 2.2, Percentile implementation uses only selection instead of complete + * sorting and caches selection algorithm state between calls to the various + * {@code evaluate} methods when several percentiles are to be computed on the same data. + * This greatly improves efficiency, both for single percentile and multiple + * percentiles computations. However, it also induces a need to be sure the data + * at one call to {@code evaluate} is the same as the data with the cached algorithm + * state from the previous calls. Percentile does this by checking the array reference + * itself and a checksum of its content by default. If the user already knows he calls + * {@code evaluate} on an immutable array, he can save the checking time by calling the + * {@code evaluate} methods that do not take a data array argument (they operate on the array stored by {@code setData}). + *

+ *

* Note that this implementation is not synchronized. If * multiple threads access an instance of this class concurrently, and at least * one of the threads invokes the increment() or @@ -72,10 +84,19 @@ public class Percentile extends AbstractUnivariateStatistic implements Serializa /** Serializable version identifier */ private static final long serialVersionUID = -8091216485095130416L; + /** Minimum size under which we use a simple insertion sort rather than Hoare's select. */ + private static final int MIN_SELECT_SIZE = 15; + + /** Maximum number of partitioning pivots cached (each level double the number of pivots). */ + private static final int MAX_CACHED_LEVELS = 10; + /** Determines what percentile is computed when evaluate() is activated * with no quantile argument */ private double quantile = 0.0; + /** Cached pivots. */ + private int[] cachedPivots; + /** * Constructs a Percentile with a default quantile * value of 50.0. @@ -92,6 +113,7 @@ public class Percentile extends AbstractUnivariateStatistic implements Serializa */ public Percentile(final double p) { setQuantile(p); + cachedPivots = null; } /** @@ -104,6 +126,42 @@ public class Percentile extends AbstractUnivariateStatistic implements Serializa copy(original, this); } + /** {@inheritDoc} */ + @Override + public void setData(final double[] values) { + if (values == null) { + cachedPivots = null; + } else { + cachedPivots = new int[(0x1 << MAX_CACHED_LEVELS) - 1]; + Arrays.fill(cachedPivots, -1); + } + super.setData(values); + } + + /** {@inheritDoc} */ + @Override + public void setData(final double[] values, final int begin, final int length) { + if (values == null) { + cachedPivots = null; + } else { + cachedPivots = new int[(0x1 << MAX_CACHED_LEVELS) - 1]; + Arrays.fill(cachedPivots, -1); + } + super.setData(values, begin, length); + } + + /** + * Returns the result of evaluating the statistic over the stored data. + *

+ * The stored array is the one which was set by previous calls to + *

+ * @param p the percentile value to compute + * @return the value of the statistic applied to the stored data + */ + public double evaluate(final double p) { + return evaluate(getDataRef(), p); + } + /** * Returns an estimate of the pth percentile of the values * in the values array. @@ -214,21 +272,176 @@ public class Percentile extends AbstractUnivariateStatistic implements Serializa double fpos = FastMath.floor(pos); int intPos = (int) fpos; double dif = pos - fpos; - double[] sorted = new double[length]; - System.arraycopy(values, begin, sorted, 0, length); - Arrays.sort(sorted); + double[] work; + int[] pivotsHeap; + if (values == getDataRef()) { + work = getDataRef(); + pivotsHeap = cachedPivots; + } else { + work = new double[length]; + System.arraycopy(values, begin, work, 0, length); + pivotsHeap = new int[(0x1 << MAX_CACHED_LEVELS) - 1]; + Arrays.fill(pivotsHeap, -1); + } if (pos < 1) { - return sorted[0]; + return select(work, pivotsHeap, 0); } if (pos >= n) { - return sorted[length - 1]; + return select(work, pivotsHeap, length - 1); } - double lower = sorted[intPos - 1]; - double upper = sorted[intPos]; + double lower = select(work, pivotsHeap, intPos - 1); + double upper = select(work, pivotsHeap, intPos); return lower + dif * (upper - lower); } + /** + * Select the kth smallest element from work array + * @param work work array (will be reorganized during the call) + * @param pivotsHeap set of pivot index corresponding to elements that + * are already at their sorted location, stored as an implicit heap + * (i.e. 
a sorted binary tree stored in a flat array, where the + * children of a node at index n are at indices 2n+1 for the left + * child and 2n+2 for the right child, with 0-based indices) + * @param k index of the desired element + * @return kth smallest element + */ + private double select(final double[] work, final int[] pivotsHeap, final int k) { + + int begin = 0; + int end = work.length; + int node = 0; + + while (end - begin > MIN_SELECT_SIZE) { + + final int pivot; + if ((node < pivotsHeap.length) && (pivotsHeap[node] >= 0)) { + // the pivot has already been found in a previous call + // and the array has already been partitioned around it + pivot = pivotsHeap[node]; + } else { + // select a pivot and partition work array around it + pivot = partition(work, begin, end, medianOf3(work, begin, end)); + if (node < pivotsHeap.length) { + pivotsHeap[node] = pivot; + } + } + + if (k == pivot) { + // the pivot was exactly the element we wanted + return work[k]; + } else if (k < pivot) { + // the element is in the left partition + end = pivot; + node = Math.min(2 * node + 1, pivotsHeap.length); // the min is here to avoid integer overflow + } else { + // the element is in the right partition + begin = pivot + 1; + node = Math.min(2 * node + 2, pivotsHeap.length); // the min is here to avoid integer overflow + } + + } + + // the element is somewhere in the small sub-array + // sort the sub-array using insertion sort + insertionSort(work, begin, end); + return work[k]; + + } + + /** Select a pivot index as the median of three + * @param work data array + * @param begin index of the first element of the slice + * @param end index after the last element of the slice + * @return the index of the median element chosen between the + * first, the middle and the last element of the array slice + */ + int medianOf3(final double[] work, final int begin, final int end) { + + final int inclusiveEnd = end - 1; + final int middle = begin + (inclusiveEnd - begin) / 2; + final double 
wBegin = work[begin]; + final double wMiddle = work[middle]; + final double wEnd = work[inclusiveEnd]; + + if (wBegin < wMiddle) { + if (wMiddle < wEnd) { + return middle; + } else { + return (wBegin < wEnd) ? inclusiveEnd : begin; + } + } else { + if (wBegin < wEnd) { + return begin; + } else { + return (wMiddle < wEnd) ? inclusiveEnd : middle; + } + } + + } + + /** + * Partition an array slice around a pivot + *

+ * Partitioning exchanges array elements such that all elements + * smaller than pivot are before it and all elements larger than + * pivot are after it + *

+ * @param work data array + * @param begin index of the first element of the slice + * @param end index after the last element of the slice + * @param pivot initial index of the pivot + * @return index of the pivot after partition + */ + private int partition(final double[] work, final int begin, final int end, final int pivot) { + + final double value = work[pivot]; + work[pivot] = work[begin]; + + int i = begin + 1; + int j = end - 1; + while (i < j) { + while ((i < j) && (work[j] >= value)) { + --j; + } + while ((i < j) && (work[i] <= value)) { + ++i; + } + + if (i < j) { + final double tmp = work[i]; + work[i++] = work[j]; + work[j--] = tmp; + } + } + + if ((i >= end) || (work[i] > value)) { + --i; + } + work[begin] = work[i]; + work[i] = value; + return i; + + } + + /** + * Sort in place a (small) array slice using insertion sort + * @param work array to sort + * @param begin index of the first element of the slice to sort + * @param end index after the last element of the slice to sort + */ + private void insertionSort(final double[] work, final int begin, final int end) { + for (int j = begin + 1; j < end; j++) { + final double saved = work[j]; + int i = j - 1; + while ((i >= begin) && (saved < work[i])) { + work[i + 1] = work[i]; + i--; + } + work[i + 1] = saved; + } + } + /** * Returns the value of the quantile field (determines what percentile is * computed when evaluate() is called with no quantile argument). 
@@ -274,6 +487,10 @@ public class Percentile extends AbstractUnivariateStatistic implements Serializa * @throws NullPointerException if either source or dest is null */ public static void copy(Percentile source, Percentile dest) { + dest.setData(source.getDataRef()); + if (source.cachedPivots != null) { + System.arraycopy(source.cachedPivots, 0, dest.cachedPivots, 0, source.cachedPivots.length); + } dest.quantile = source.quantile; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/summary/Product.java b/src/main/java/org/apache/commons/math/stat/descriptive/summary/Product.java index f9796b458..abe27d4de 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/summary/Product.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/summary/Product.java @@ -213,6 +213,7 @@ public class Product extends AbstractStorelessUnivariateStatistic implements Ser * @throws NullPointerException if either source or dest is null */ public static void copy(Product source, Product dest) { + dest.setData(source.getDataRef()); dest.n = source.n; dest.value = source.value; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/summary/Sum.java b/src/main/java/org/apache/commons/math/stat/descriptive/summary/Sum.java index b1d9059d8..997cc3aec 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/summary/Sum.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/summary/Sum.java @@ -209,6 +209,7 @@ public class Sum extends AbstractStorelessUnivariateStatistic implements Seriali * @throws NullPointerException if either source or dest is null */ public static void copy(Sum source, Sum dest) { + dest.setData(source.getDataRef()); dest.n = source.n; dest.value = source.value; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfLogs.java b/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfLogs.java index a4ce08ef1..27264a3a4 100644 --- 
a/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfLogs.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfLogs.java @@ -155,6 +155,7 @@ public class SumOfLogs extends AbstractStorelessUnivariateStatistic implements S * @throws NullPointerException if either source or dest is null */ public static void copy(SumOfLogs source, SumOfLogs dest) { + dest.setData(source.getDataRef()); dest.n = source.n; dest.value = source.value; } diff --git a/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfSquares.java b/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfSquares.java index 36a216817..ac2317703 100644 --- a/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfSquares.java +++ b/src/main/java/org/apache/commons/math/stat/descriptive/summary/SumOfSquares.java @@ -143,6 +143,7 @@ public class SumOfSquares extends AbstractStorelessUnivariateStatistic implement * @throws NullPointerException if either source or dest is null */ public static void copy(SumOfSquares source, SumOfSquares dest) { + dest.setData(source.getDataRef()); dest.n = source.n; dest.value = source.value; } diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index 1d9989434..aaf51d40a 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -85,6 +85,11 @@ The type attribute can be add,update,fix,remove. + + Improved Percentile performance by using a selection algorithm instead of a + complete sort, and by allowing caching data array and pivots when several + different percentiles are desired + Fixed an error preventing zero length vectors to be built by some constructors