Added constructors taking sample data as arguments to enumerated real and integer distributions. JIRA: MATH-1287.
This commit is contained in:
parent
8aecb842d3
commit
430c7f4567
|
@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces!
|
||||||
</properties>
|
</properties>
|
||||||
<body>
|
<body>
|
||||||
<release version="3.6" date="XXXX-XX-XX" description="">
|
<release version="3.6" date="XXXX-XX-XX" description="">
|
||||||
|
<action dev="psteitz" type="update" issue="MATH-1287">
|
||||||
|
Added constructors taking sample data as arguments to enumerated real and integer distributions.
|
||||||
|
</action>
|
||||||
<action dev="oertl" type="fix" issue="MATH-1269">
|
<action dev="oertl" type="fix" issue="MATH-1269">
|
||||||
Fixed FastMath.exp that potentially returned NaN for non-NaN argument.
|
Fixed FastMath.exp that potentially returned NaN for non-NaN argument.
|
||||||
</action>
|
</action>
|
||||||
|
|
|
@ -17,7 +17,10 @@
|
||||||
package org.apache.commons.math3.distribution;
|
package org.apache.commons.math3.distribution;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||||
|
@ -94,6 +97,62 @@ public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
|
||||||
throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
|
throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
|
||||||
NotFiniteNumberException, NotANumberException {
|
NotFiniteNumberException, NotANumberException {
|
||||||
super(rng);
|
super(rng);
|
||||||
|
innerDistribution = new EnumeratedDistribution<Integer>(
|
||||||
|
rng, createDistribution(singletons, probabilities));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a discrete integer-valued distribution from the input data. Values are assigned
|
||||||
|
* mass based on their frequency.
|
||||||
|
*
|
||||||
|
* @param rng random number generator used for sampling
|
||||||
|
* @param data input dataset
|
||||||
|
*/
|
||||||
|
public EnumeratedIntegerDistribution(final RandomGenerator rng, final int[] data) {
|
||||||
|
super(rng);
|
||||||
|
final Map<Integer, Integer> dataMap = new HashMap<Integer, Integer>();
|
||||||
|
|
||||||
|
for (int value : data) {
|
||||||
|
Integer count = dataMap.get(value);
|
||||||
|
if (count == null) {
|
||||||
|
count = new Integer(1);
|
||||||
|
} else {
|
||||||
|
count = new Integer(count.intValue() + 1);
|
||||||
|
}
|
||||||
|
dataMap.put(value, count);
|
||||||
|
}
|
||||||
|
final int massPoints = dataMap.size();
|
||||||
|
final double denom = data.length;
|
||||||
|
final int[] values = new int[massPoints];
|
||||||
|
final double[] probabilities = new double[massPoints];
|
||||||
|
int index = 0;
|
||||||
|
for (Entry<Integer, Integer> entry : dataMap.entrySet()) {
|
||||||
|
values[index] = entry.getKey();
|
||||||
|
probabilities[index] = entry.getValue().intValue() / denom;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
innerDistribution = new EnumeratedDistribution<Integer>(rng, createDistribution(values, probabilities));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a discrete integer-valued distribution from the input data. Values are assigned
|
||||||
|
* mass based on their frequency. For example, [0,1,1,2] as input creates a distribution
|
||||||
|
* with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively,
|
||||||
|
*
|
||||||
|
* @param data input dataset
|
||||||
|
*/
|
||||||
|
public EnumeratedIntegerDistribution(final int[] data) {
|
||||||
|
this(new Well19937c(), data);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create the list of Pairs representing the distribution from singletons and probabilities.
|
||||||
|
*
|
||||||
|
* @param singletons values
|
||||||
|
* @param probabilities probabilities
|
||||||
|
* @return list of value/probability pairs
|
||||||
|
*/
|
||||||
|
private List<Pair<Integer, Double>> createDistribution(int[] singletons, double[] probabilities) {
|
||||||
if (singletons.length != probabilities.length) {
|
if (singletons.length != probabilities.length) {
|
||||||
throw new DimensionMismatchException(probabilities.length, singletons.length);
|
throw new DimensionMismatchException(probabilities.length, singletons.length);
|
||||||
}
|
}
|
||||||
|
@ -103,8 +162,8 @@ public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
|
||||||
for (int i = 0; i < singletons.length; i++) {
|
for (int i = 0; i < singletons.length; i++) {
|
||||||
samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i]));
|
samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i]));
|
||||||
}
|
}
|
||||||
|
return samples;
|
||||||
|
|
||||||
innerDistribution = new EnumeratedDistribution<Integer>(rng, samples);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -17,7 +17,10 @@
|
||||||
package org.apache.commons.math3.distribution;
|
package org.apache.commons.math3.distribution;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||||
|
@ -51,7 +54,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
|
||||||
protected final EnumeratedDistribution<Double> innerDistribution;
|
protected final EnumeratedDistribution<Double> innerDistribution;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a discrete distribution using the given probability mass function
|
* Create a discrete real-valued distribution using the given probability mass function
|
||||||
* enumeration.
|
* enumeration.
|
||||||
* <p>
|
* <p>
|
||||||
* <b>Note:</b> this constructor will implicitly create an instance of
|
* <b>Note:</b> this constructor will implicitly create an instance of
|
||||||
|
@ -77,7 +80,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a discrete distribution using the given random number generator
|
* Create a discrete real-valued distribution using the given random number generator
|
||||||
* and probability mass function enumeration.
|
* and probability mass function enumeration.
|
||||||
*
|
*
|
||||||
* @param rng random number generator.
|
* @param rng random number generator.
|
||||||
|
@ -95,17 +98,73 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
|
||||||
throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
|
throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
|
||||||
NotFiniteNumberException, NotANumberException {
|
NotFiniteNumberException, NotANumberException {
|
||||||
super(rng);
|
super(rng);
|
||||||
|
|
||||||
|
innerDistribution = new EnumeratedDistribution<Double>(
|
||||||
|
rng, createDistribution(singletons, probabilities));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a discrete real-valued distribution from the input data. Values are assigned
|
||||||
|
* mass based on their frequency.
|
||||||
|
*
|
||||||
|
* @param rng random number generator used for sampling
|
||||||
|
* @param data input dataset
|
||||||
|
*/
|
||||||
|
public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) {
|
||||||
|
super(rng);
|
||||||
|
final Map<Double, Integer> dataMap = new HashMap<Double, Integer>();
|
||||||
|
|
||||||
|
for (double value : data) {
|
||||||
|
Integer count = dataMap.get(value);
|
||||||
|
if (count == null) {
|
||||||
|
count = new Integer(1);
|
||||||
|
} else {
|
||||||
|
count = new Integer(count.intValue() + 1);
|
||||||
|
}
|
||||||
|
dataMap.put(value, count);
|
||||||
|
}
|
||||||
|
final int massPoints = dataMap.size();
|
||||||
|
final double denom = data.length;
|
||||||
|
final double[] values = new double[massPoints];
|
||||||
|
final double[] probabilities = new double[massPoints];
|
||||||
|
int index = 0;
|
||||||
|
for (Entry<Double, Integer> entry : dataMap.entrySet()) {
|
||||||
|
values[index] = entry.getKey();
|
||||||
|
probabilities[index] = entry.getValue().intValue() / denom;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
innerDistribution = new EnumeratedDistribution<Double>(rng, createDistribution(values, probabilities));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a discrete real-valued distribution from the input data. Values are assigned
|
||||||
|
* mass based on their frequency. For example, [0,1,1,2] as input creates a distribution
|
||||||
|
* with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively,
|
||||||
|
*
|
||||||
|
* @param data input dataset
|
||||||
|
*/
|
||||||
|
public EnumeratedRealDistribution(final double[] data) {
|
||||||
|
this(new Well19937c(), data);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Create the list of Pairs representing the distribution from singletons and probabilities.
|
||||||
|
*
|
||||||
|
* @param singletons values
|
||||||
|
* @param probabilities probabilities
|
||||||
|
* @return list of value/probability pairs
|
||||||
|
*/
|
||||||
|
private List<Pair<Double, Double>> createDistribution(double[] singletons, double[] probabilities) {
|
||||||
if (singletons.length != probabilities.length) {
|
if (singletons.length != probabilities.length) {
|
||||||
throw new DimensionMismatchException(probabilities.length, singletons.length);
|
throw new DimensionMismatchException(probabilities.length, singletons.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Pair<Double, Double>> samples = new ArrayList<Pair<Double, Double>>(singletons.length);
|
final List<Pair<Double, Double>> samples = new ArrayList<Pair<Double, Double>>(singletons.length);
|
||||||
|
|
||||||
for (int i = 0; i < singletons.length; i++) {
|
for (int i = 0; i < singletons.length; i++) {
|
||||||
samples.add(new Pair<Double, Double>(singletons[i], probabilities[i]));
|
samples.add(new Pair<Double, Double>(singletons[i], probabilities[i]));
|
||||||
}
|
}
|
||||||
|
return samples;
|
||||||
|
|
||||||
innerDistribution = new EnumeratedDistribution<Double>(rng, samples);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -168,4 +168,12 @@ public class EnumeratedIntegerDistributionTest {
|
||||||
Assert.assertEquals(testDistribution.getNumericalVariance(),
|
Assert.assertEquals(testDistribution.getNumericalVariance(),
|
||||||
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
|
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromIntegers() {
|
||||||
|
final int[] data = new int[] {0, 1, 1, 2, 2, 2};
|
||||||
|
EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(data);
|
||||||
|
Assert.assertEquals(0.5, distribution.probability(2), 0);
|
||||||
|
Assert.assertEquals(0.5, distribution.cumulativeProbability(1), 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -256,4 +256,12 @@ public class EnumeratedRealDistributionTest {
|
||||||
assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0);
|
assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0);
|
||||||
assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0);
|
assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromDoubles() {
|
||||||
|
final double[] data = new double[] {0, 1, 1, 2, 2, 2};
|
||||||
|
EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(data);
|
||||||
|
assertEquals(0.5, distribution.probability(2), 0);
|
||||||
|
assertEquals(0.5, distribution.cumulativeProbability(1), 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue