From 430c7f45675ce3fa78c9fed1a41c2442633469e6 Mon Sep 17 00:00:00 2001 From: Phil Steitz Date: Mon, 9 Nov 2015 20:48:21 -0700 Subject: [PATCH] Added constructors taking sample data as arguments to enumerated real and integer distributions. JIRA: MATH-1287. --- src/changes/changes.xml | 3 + .../EnumeratedIntegerDistribution.java | 61 ++++++++++++++++- .../EnumeratedRealDistribution.java | 67 +++++++++++++++++-- .../EnumeratedIntegerDistributionTest.java | 8 +++ .../EnumeratedRealDistributionTest.java | 8 +++ 5 files changed, 142 insertions(+), 5 deletions(-) diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 5bdaf76c8..202f368bf 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces! + + Added constructors taking sample data as arguments to enumerated real and integer distributions. + Fixed FastMath.exp that potentially returned NaN for non-NaN argument. diff --git a/src/main/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistribution.java b/src/main/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistribution.java index 9a0778771..9e29b66cf 100644 --- a/src/main/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistribution.java +++ b/src/main/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistribution.java @@ -17,7 +17,10 @@ package org.apache.commons.math3.distribution; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.MathArithmeticException; @@ -94,6 +97,62 @@ public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution { throws DimensionMismatchException, NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException { super(rng); + innerDistribution = new 
EnumeratedDistribution( + rng, createDistribution(singletons, probabilities)); + } + + /** + * Create a discrete integer-valued distribution from the input data. Values are assigned + * mass based on their frequency. + * + * @param rng random number generator used for sampling + * @param data input dataset + */ + public EnumeratedIntegerDistribution(final RandomGenerator rng, final int[] data) { + super(rng); + final Map dataMap = new HashMap(); + + for (int value : data) { + Integer count = dataMap.get(value); + if (count == null) { + count = new Integer(1); + } else { + count = new Integer(count.intValue() + 1); + } + dataMap.put(value, count); + } + final int massPoints = dataMap.size(); + final double denom = data.length; + final int[] values = new int[massPoints]; + final double[] probabilities = new double[massPoints]; + int index = 0; + for (Entry entry : dataMap.entrySet()) { + values[index] = entry.getKey(); + probabilities[index] = entry.getValue().intValue() / denom; + index++; + } + innerDistribution = new EnumeratedDistribution(rng, createDistribution(values, probabilities)); + } + + /** + * Create a discrete integer-valued distribution from the input data. Values are assigned + * mass based on their frequency. For example, [0,1,1,2] as input creates a distribution + * with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively. + * + * @param data input dataset + */ + public EnumeratedIntegerDistribution(final int[] data) { + this(new Well19937c(), data); + } + + /** + * Create the list of Pairs representing the distribution from singletons and probabilities. 
+ * + * @param singletons values + * @param probabilities probabilities + * @return list of value/probability pairs + */ + private List> createDistribution(int[] singletons, double[] probabilities) { if (singletons.length != probabilities.length) { throw new DimensionMismatchException(probabilities.length, singletons.length); } @@ -103,8 +162,8 @@ public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution { for (int i = 0; i < singletons.length; i++) { samples.add(new Pair(singletons[i], probabilities[i])); } + return samples; - innerDistribution = new EnumeratedDistribution(rng, samples); } /** diff --git a/src/main/java/org/apache/commons/math3/distribution/EnumeratedRealDistribution.java b/src/main/java/org/apache/commons/math3/distribution/EnumeratedRealDistribution.java index 07b96bc02..2edb37509 100644 --- a/src/main/java/org/apache/commons/math3/distribution/EnumeratedRealDistribution.java +++ b/src/main/java/org/apache/commons/math3/distribution/EnumeratedRealDistribution.java @@ -17,7 +17,10 @@ package org.apache.commons.math3.distribution; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.MathArithmeticException; @@ -51,7 +54,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution { protected final EnumeratedDistribution innerDistribution; /** - * Create a discrete distribution using the given probability mass function + * Create a discrete real-valued distribution using the given probability mass function * enumeration. *

* Note: this constructor will implicitly create an instance of @@ -77,7 +80,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution { } /** - * Create a discrete distribution using the given random number generator + * Create a discrete real-valued distribution using the given random number generator * and probability mass function enumeration. * * @param rng random number generator. @@ -95,17 +98,73 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution { throws DimensionMismatchException, NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException { super(rng); + + innerDistribution = new EnumeratedDistribution( + rng, createDistribution(singletons, probabilities)); + } + + /** + * Create a discrete real-valued distribution from the input data. Values are assigned + * mass based on their frequency. + * + * @param rng random number generator used for sampling + * @param data input dataset + */ + public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) { + super(rng); + final Map dataMap = new HashMap(); + + for (double value : data) { + Integer count = dataMap.get(value); + if (count == null) { + count = new Integer(1); + } else { + count = new Integer(count.intValue() + 1); + } + dataMap.put(value, count); + } + final int massPoints = dataMap.size(); + final double denom = data.length; + final double[] values = new double[massPoints]; + final double[] probabilities = new double[massPoints]; + int index = 0; + for (Entry entry : dataMap.entrySet()) { + values[index] = entry.getKey(); + probabilities[index] = entry.getValue().intValue() / denom; + index++; + } + innerDistribution = new EnumeratedDistribution(rng, createDistribution(values, probabilities)); + } + + /** + * Create a discrete real-valued distribution from the input data. Values are assigned + * mass based on their frequency. 
For example, [0,1,1,2] as input creates a distribution + * with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively. + * + * @param data input dataset + */ + public EnumeratedRealDistribution(final double[] data) { + this(new Well19937c(), data); + } + /** + * Create the list of Pairs representing the distribution from singletons and probabilities. + * + * @param singletons values + * @param probabilities probabilities + * @return list of value/probability pairs + */ + private List> createDistribution(double[] singletons, double[] probabilities) { if (singletons.length != probabilities.length) { throw new DimensionMismatchException(probabilities.length, singletons.length); } - List> samples = new ArrayList>(singletons.length); + final List> samples = new ArrayList>(singletons.length); for (int i = 0; i < singletons.length; i++) { samples.add(new Pair(singletons[i], probabilities[i])); } + return samples; - innerDistribution = new EnumeratedDistribution(rng, samples); } /** diff --git a/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java index dd3d06918..694887532 100644 --- a/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java +++ b/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java @@ -168,4 +168,12 @@ public class EnumeratedIntegerDistributionTest { Assert.assertEquals(testDistribution.getNumericalVariance(), sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2); } + + @Test + public void testCreateFromIntegers() { + final int[] data = new int[] {0, 1, 1, 2, 2, 2}; + EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(data); + Assert.assertEquals(0.5, distribution.probability(2), 0); + Assert.assertEquals(0.5, distribution.cumulativeProbability(1), 0); + } } diff --git 
a/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java index 961b134a1..4103c0b82 100644 --- a/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java +++ b/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java @@ -256,4 +256,12 @@ public class EnumeratedRealDistributionTest { assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0); assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0); } + + @Test + public void testCreateFromDoubles() { + final double[] data = new double[] {0, 1, 1, 2, 2, 2}; + EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(data); + assertEquals(0.5, distribution.probability(2), 0); + assertEquals(0.5, distribution.cumulativeProbability(1), 0); + } }