Added constructors taking sample data as arguments to enumerated real and integer distributions. JIRA: MATH-1287.

This commit is contained in:
Phil Steitz 2015-11-09 20:48:21 -07:00
parent 8aecb842d3
commit 430c7f4567
5 changed files with 142 additions and 5 deletions

View File

@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties> </properties>
<body> <body>
<release version="3.6" date="XXXX-XX-XX" description=""> <release version="3.6" date="XXXX-XX-XX" description="">
<action dev="psteitz" type="update" issue="MATH-1287">
Added constructors taking sample data as arguments to enumerated real and integer distributions.
</action>
<action dev="oertl" type="fix" issue="MATH-1269"> <action dev="oertl" type="fix" issue="MATH-1269">
Fixed FastMath.exp that potentially returned NaN for non-NaN argument. Fixed FastMath.exp that potentially returned NaN for non-NaN argument.
</action> </action>

View File

@ -17,7 +17,10 @@
package org.apache.commons.math3.distribution; package org.apache.commons.math3.distribution;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException; import org.apache.commons.math3.exception.MathArithmeticException;
@ -94,6 +97,62 @@ public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
NotFiniteNumberException, NotANumberException { NotFiniteNumberException, NotANumberException {
super(rng); super(rng);
innerDistribution = new EnumeratedDistribution<Integer>(
rng, createDistribution(singletons, probabilities));
}
/**
* Create a discrete integer-valued distribution from the input data. Values are assigned
* mass based on their frequency.
*
* @param rng random number generator used for sampling
* @param data input dataset
*/
public EnumeratedIntegerDistribution(final RandomGenerator rng, final int[] data) {
super(rng);
final Map<Integer, Integer> dataMap = new HashMap<Integer, Integer>();
for (int value : data) {
Integer count = dataMap.get(value);
if (count == null) {
count = new Integer(1);
} else {
count = new Integer(count.intValue() + 1);
}
dataMap.put(value, count);
}
final int massPoints = dataMap.size();
final double denom = data.length;
final int[] values = new int[massPoints];
final double[] probabilities = new double[massPoints];
int index = 0;
for (Entry<Integer, Integer> entry : dataMap.entrySet()) {
values[index] = entry.getKey();
probabilities[index] = entry.getValue().intValue() / denom;
index++;
}
innerDistribution = new EnumeratedDistribution<Integer>(rng, createDistribution(values, probabilities));
}
/**
* Create a discrete integer-valued distribution from the input data. Values are assigned
* mass based on their frequency. For example, [0,1,1,2] as input creates a distribution
* with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively,
*
* @param data input dataset
*/
public EnumeratedIntegerDistribution(final int[] data) {
this(new Well19937c(), data);
}
/**
* Create the list of Pairs representing the distribution from singletons and probabilities.
*
* @param singletons values
* @param probabilities probabilities
* @return list of value/probability pairs
*/
private List<Pair<Integer, Double>> createDistribution(int[] singletons, double[] probabilities) {
if (singletons.length != probabilities.length) { if (singletons.length != probabilities.length) {
throw new DimensionMismatchException(probabilities.length, singletons.length); throw new DimensionMismatchException(probabilities.length, singletons.length);
} }
@ -103,8 +162,8 @@ public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
for (int i = 0; i < singletons.length; i++) { for (int i = 0; i < singletons.length; i++) {
samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i])); samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i]));
} }
return samples;
innerDistribution = new EnumeratedDistribution<Integer>(rng, samples);
} }
/** /**

View File

@ -17,7 +17,10 @@
package org.apache.commons.math3.distribution; package org.apache.commons.math3.distribution;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException; import org.apache.commons.math3.exception.MathArithmeticException;
@ -51,7 +54,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
protected final EnumeratedDistribution<Double> innerDistribution; protected final EnumeratedDistribution<Double> innerDistribution;
/** /**
* Create a discrete distribution using the given probability mass function * Create a discrete real-valued distribution using the given probability mass function
* enumeration. * enumeration.
* <p> * <p>
* <b>Note:</b> this constructor will implicitly create an instance of * <b>Note:</b> this constructor will implicitly create an instance of
@ -77,7 +80,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
} }
/** /**
* Create a discrete distribution using the given random number generator * Create a discrete real-valued distribution using the given random number generator
* and probability mass function enumeration. * and probability mass function enumeration.
* *
* @param rng random number generator. * @param rng random number generator.
@ -95,17 +98,73 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
NotFiniteNumberException, NotANumberException { NotFiniteNumberException, NotANumberException {
super(rng); super(rng);
innerDistribution = new EnumeratedDistribution<Double>(
rng, createDistribution(singletons, probabilities));
}
/**
* Create a discrete real-valued distribution from the input data. Values are assigned
* mass based on their frequency.
*
* @param rng random number generator used for sampling
* @param data input dataset
*/
public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) {
super(rng);
final Map<Double, Integer> dataMap = new HashMap<Double, Integer>();
for (double value : data) {
Integer count = dataMap.get(value);
if (count == null) {
count = new Integer(1);
} else {
count = new Integer(count.intValue() + 1);
}
dataMap.put(value, count);
}
final int massPoints = dataMap.size();
final double denom = data.length;
final double[] values = new double[massPoints];
final double[] probabilities = new double[massPoints];
int index = 0;
for (Entry<Double, Integer> entry : dataMap.entrySet()) {
values[index] = entry.getKey();
probabilities[index] = entry.getValue().intValue() / denom;
index++;
}
innerDistribution = new EnumeratedDistribution<Double>(rng, createDistribution(values, probabilities));
}
/**
* Create a discrete real-valued distribution from the input data. Values are assigned
* mass based on their frequency. For example, [0,1,1,2] as input creates a distribution
* with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively,
*
* @param data input dataset
*/
public EnumeratedRealDistribution(final double[] data) {
this(new Well19937c(), data);
}
/**
* Create the list of Pairs representing the distribution from singletons and probabilities.
*
* @param singletons values
* @param probabilities probabilities
* @return list of value/probability pairs
*/
private List<Pair<Double, Double>> createDistribution(double[] singletons, double[] probabilities) {
if (singletons.length != probabilities.length) { if (singletons.length != probabilities.length) {
throw new DimensionMismatchException(probabilities.length, singletons.length); throw new DimensionMismatchException(probabilities.length, singletons.length);
} }
List<Pair<Double, Double>> samples = new ArrayList<Pair<Double, Double>>(singletons.length); final List<Pair<Double, Double>> samples = new ArrayList<Pair<Double, Double>>(singletons.length);
for (int i = 0; i < singletons.length; i++) { for (int i = 0; i < singletons.length; i++) {
samples.add(new Pair<Double, Double>(singletons[i], probabilities[i])); samples.add(new Pair<Double, Double>(singletons[i], probabilities[i]));
} }
return samples;
innerDistribution = new EnumeratedDistribution<Double>(rng, samples);
} }
/** /**

View File

@ -168,4 +168,12 @@ public class EnumeratedIntegerDistributionTest {
Assert.assertEquals(testDistribution.getNumericalVariance(), Assert.assertEquals(testDistribution.getNumericalVariance(),
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2); sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
} }
@Test
public void testCreateFromIntegers() {
final int[] data = new int[] {0, 1, 1, 2, 2, 2};
EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(data);
Assert.assertEquals(0.5, distribution.probability(2), 0);
Assert.assertEquals(0.5, distribution.cumulativeProbability(1), 0);
}
} }

View File

@ -256,4 +256,12 @@ public class EnumeratedRealDistributionTest {
assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0); assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0);
assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0); assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0);
} }
@Test
public void testCreateFromDoubles() {
final double[] data = new double[] {0, 1, 1, 2, 2, 2};
EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(data);
assertEquals(0.5, distribution.probability(2), 0);
assertEquals(0.5, distribution.cumulativeProbability(1), 0);
}
} }