innerDistribution;
/**
- * Create a discrete distribution using the given probability mass function
+ * Create a discrete real-valued distribution using the given probability mass function
* enumeration.
*
* Note: this constructor will implicitly create an instance of
@@ -77,7 +80,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
}
/**
- * Create a discrete distribution using the given random number generator
+ * Create a discrete real-valued distribution using the given random number generator
* and probability mass function enumeration.
*
* @param rng random number generator.
@@ -95,17 +98,73 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
NotFiniteNumberException, NotANumberException {
super(rng);
+
+ innerDistribution = new EnumeratedDistribution(
+ rng, createDistribution(singletons, probabilities));
+ }
+
+ /**
+ * Create a discrete real-valued distribution from the input data. Values are assigned
+ * mass based on their frequency.
+ *
+ * @param rng random number generator used for sampling
+ * @param data input dataset
+ */
+ public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) {
+ super(rng);
+ final Map dataMap = new HashMap();
+
+ for (double value : data) {
+ Integer count = dataMap.get(value);
+ if (count == null) {
+ count = new Integer(1);
+ } else {
+ count = new Integer(count.intValue() + 1);
+ }
+ dataMap.put(value, count);
+ }
+ final int massPoints = dataMap.size();
+ final double denom = data.length;
+ final double[] values = new double[massPoints];
+ final double[] probabilities = new double[massPoints];
+ int index = 0;
+ for (Entry entry : dataMap.entrySet()) {
+ values[index] = entry.getKey();
+ probabilities[index] = entry.getValue().intValue() / denom;
+ index++;
+ }
+ innerDistribution = new EnumeratedDistribution(rng, createDistribution(values, probabilities));
+ }
+
+ /**
+ * Create a discrete real-valued distribution from the input data. Values are assigned
+ * mass based on their frequency. For example, [0,1,1,2] as input creates a distribution
+ * with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively,
+ *
+ * @param data input dataset
+ */
+ public EnumeratedRealDistribution(final double[] data) {
+ this(new Well19937c(), data);
+ }
+ /**
+ * Create the list of Pairs representing the distribution from singletons and probabilities.
+ *
+ * @param singletons values
+ * @param probabilities probabilities
+ * @return list of value/probability pairs
+ */
+ private List> createDistribution(double[] singletons, double[] probabilities) {
if (singletons.length != probabilities.length) {
throw new DimensionMismatchException(probabilities.length, singletons.length);
}
- List> samples = new ArrayList>(singletons.length);
+ final List> samples = new ArrayList>(singletons.length);
for (int i = 0; i < singletons.length; i++) {
samples.add(new Pair(singletons[i], probabilities[i]));
}
+ return samples;
- innerDistribution = new EnumeratedDistribution(rng, samples);
}
/**
diff --git a/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java
index dd3d06918..694887532 100644
--- a/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java
+++ b/src/test/java/org/apache/commons/math3/distribution/EnumeratedIntegerDistributionTest.java
@@ -168,4 +168,12 @@ public class EnumeratedIntegerDistributionTest {
Assert.assertEquals(testDistribution.getNumericalVariance(),
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
}
+
+ @Test
+ public void testCreateFromIntegers() {
+ final int[] data = new int[] {0, 1, 1, 2, 2, 2};
+ EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(data);
+ Assert.assertEquals(0.5, distribution.probability(2), 0);
+ Assert.assertEquals(0.5, distribution.cumulativeProbability(1), 0);
+ }
}
diff --git a/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java
index 961b134a1..4103c0b82 100644
--- a/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java
+++ b/src/test/java/org/apache/commons/math3/distribution/EnumeratedRealDistributionTest.java
@@ -256,4 +256,12 @@ public class EnumeratedRealDistributionTest {
assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0);
assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0);
}
+
+ @Test
+ public void testCreateFromDoubles() {
+ final double[] data = new double[] {0, 1, 1, 2, 2, 2};
+ EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(data);
+ assertEquals(0.5, distribution.probability(2), 0);
+ assertEquals(0.5, distribution.cumulativeProbability(1), 0);
+ }
}