Added discrete distributions.

Patch contributed by Piotr Wydrych.

JIRA: MATH-941

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1454439 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2013-03-08 15:59:29 +00:00
parent e6a5e6dc9c
commit 91ebbb6294
7 changed files with 1030 additions and 0 deletions

View File

@ -279,6 +279,9 @@
<contributor>
<name>Christian Winter</name>
</contributor>
<contributor>
<name>Piotr Wydrych</name>
</contributor>
<contributor>
<name>Xiaogang Zhang</name>
</contributor>

View File

@ -55,6 +55,9 @@ This is a minor release: It combines bug fixes and new features.
Changes to existing features were made in a backwards-compatible
way such as to allow drop-in replacement of the v3.1[.1] JAR file.
">
<action dev="luc" type="add" issue="MATH-941" due-to="Piotr Wydrych" >
Added discrete distributions.
</action>
<action dev="luc" type="fix" issue="MATH-940" due-to="Piotr Wydrych" >
Fixed abstract test class naming that broke ant builds.
</action>

View File

@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.distribution;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.util.Pair;
/**
* Generic implementation of the discrete distribution.
*
* @param <T> type of the random variable.
* @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">Discrete probability distribution (Wikipedia)</a>
* @see <a href="http://mathworld.wolfram.com/DiscreteDistribution.html">Discrete Distribution (MathWorld)</a>
* @version $Id: DiscreteDistribution.java 169 2013-03-08 09:02:38Z wydrych $
*/
public class DiscreteDistribution<T> {
/**
* RNG instance used to generate samples from the distribution.
*/
protected final RandomGenerator random;
/**
* List of random variable values.
*/
private final List<T> singletons;
/**
* Normalized array of probabilities of respective random variable values.
*/
private final double[] probabilities;
/**
* Create a discrete distribution using the given probability mass function
* definition.
*
* @param samples definition of probability mass function in the format of
* list of pairs.
* @throws NotPositiveException if probability of at least one value is
* negative.
* @throws MathArithmeticException if the probabilities sum to zero.
* @throws MathIllegalArgumentException if probability of at least one value
* is infinite.
*/
public DiscreteDistribution(final List<Pair<T, Double>> samples)
throws NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
this(new Well19937c(), samples);
}
/**
* Create a discrete distribution using the given random number generator
* and probability mass function definition.
*
* @param rng random number generator.
* @param samples definition of probability mass function in the format of
* list of pairs.
* @throws NotPositiveException if probability of at least one value is
* negative.
* @throws MathArithmeticException if the probabilities sum to zero.
* @throws MathIllegalArgumentException if probability of at least one value
* is infinite.
*/
public DiscreteDistribution(final RandomGenerator rng, final List<Pair<T, Double>> samples)
throws NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
random = rng;
singletons = new ArrayList<T>(samples.size());
final double[] probs = new double[samples.size()];
for (int i = 0; i < samples.size(); i++) {
final Pair<T, Double> sample = samples.get(i);
singletons.add(sample.getKey());
if (sample.getValue() < 0) {
throw new NotPositiveException(sample.getValue());
}
probs[i] = sample.getValue();
}
probabilities = MathArrays.normalizeArray(probs, 1.0);
}
/**
* Reseed the random generator used to generate samples.
*
* @param seed the new seed
*/
public void reseedRandomGenerator(long seed) {
random.setSeed(seed);
}
/**
* For a random variable {@code X} whose values are distributed according to
* this distribution, this method returns {@code P(X = x)}. In other words,
* this method represents the probability mass function (PMF) for the
* distribution.
*
* @param x the point at which the PMF is evaluated
* @return the value of the probability mass function at {@code x}
*/
double probability(final T x) {
double probability = 0;
for (int i = 0; i < probabilities.length; i++) {
if ((x == null && singletons.get(i) == null) ||
(x != null && x.equals(singletons.get(i)))) {
probability += probabilities[i];
}
}
return probability;
}
/**
* Return the definition of probability mass function in the format of list
* of pairs.
*
* @return definition of probability mass function.
*/
public List<Pair<T, Double>> getSamples() {
final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length);
for (int i = 0; i < probabilities.length; i++) {
samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i]));
}
return samples;
}
/**
* Generate a random value sampled from this distribution.
*
* @return a random value.
*/
public T sample() {
final double randomValue = random.nextDouble();
double sum = 0;
for (int i = 0; i < probabilities.length; i++) {
sum += probabilities[i];
if (randomValue < sum) {
return singletons.get(i);
}
}
/* This should never happen, but it ensures we will return a correct
* object in case the loop above has some floating point inequality
* problem on the final iteration. */
return singletons.get(singletons.size() - 1);
}
/**
* Generate a random sample from the distribution.
*
* @param sampleSize the number of random values to generate.
* @return an array representing the random sample.
* @throws NotStrictlyPositiveException if {@code sampleSize} is not
* positive.
*/
public T[] sample(int sampleSize) throws NotStrictlyPositiveException {
if (sampleSize <= 0) {
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
sampleSize);
}
@SuppressWarnings("unchecked")
final T[]out = (T[]) Array.newInstance(singletons.get(0).getClass(), sampleSize);
for (int i = 0; i < sampleSize; i++) {
out[i] = sample();
}
return out;
}
}

View File

@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.distribution;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;
/**
* Implementation of the integer-valued discrete distribution.
*
* Note: values with zero-probability are allowed but they do not extend the
* support.
*
* @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">Discrete probability distribution (Wikipedia)</a>
* @see <a href="http://mathworld.wolfram.com/DiscreteDistribution.html">Discrete Distribution (MathWorld)</a>
* @version $Id: DiscreteIntegerDistribution.java 169 2013-03-08 09:02:38Z wydrych $
*/
public class DiscreteIntegerDistribution extends AbstractIntegerDistribution {
/** Serializable UID. */
private static final long serialVersionUID = 20130308L;
/**
* {@link DiscreteDistribution} instance (using the {@link Integer} wrapper)
* used to generate samples.
*/
protected final DiscreteDistribution<Integer> innerDistribution;
/**
* Create a discrete distribution using the given probability mass function
* definition.
*
* @param singletons array of random variable values.
* @param probabilities array of probabilities.
* @throws DimensionMismatchException if
* {@code singletons.length != probabilities.length}
* @throws NotPositiveException if probability of at least one value is
* negative.
* @throws MathArithmeticException if the probabilities sum to zero.
* @throws MathIllegalArgumentException if probability of at least one value
* is infinite.
*/
public DiscreteIntegerDistribution(final int[] singletons, final double[] probabilities)
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
this(new Well19937c(), singletons, probabilities);
}
/**
* Create a discrete distribution using the given random number generator
* and probability mass function definition.
*
* @param rng random number generator.
* @param singletons array of random variable values.
* @param probabilities array of probabilities.
* @throws DimensionMismatchException if
* {@code singletons.length != probabilities.length}
* @throws NotPositiveException if probability of at least one value is
* negative.
* @throws MathArithmeticException if the probabilities sum to zero.
* @throws MathIllegalArgumentException if probability of at least one value
* is infinite.
*/
public DiscreteIntegerDistribution(final RandomGenerator rng,
final int[] singletons, final double[] probabilities)
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
super(rng);
if (singletons.length != probabilities.length) {
throw new DimensionMismatchException(probabilities.length, singletons.length);
}
final List<Pair<Integer, Double>> samples = new ArrayList<Pair<Integer, Double>>(singletons.length);
for (int i = 0; i < singletons.length; i++) {
samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i]));
}
innerDistribution = new DiscreteDistribution<Integer>(rng, samples);
}
/**
* {@inheritDoc}
*/
public double probability(final int x) {
return innerDistribution.probability(x);
}
/**
* {@inheritDoc}
*/
public double cumulativeProbability(final int x) {
double probability = 0;
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
if (sample.getKey() <= x) {
probability += sample.getValue();
}
}
return probability;
}
/**
* {@inheritDoc}
*
* @return {@code sum(singletons[i] * probabilities[i])}
*/
public double getNumericalMean() {
double mean = 0;
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
mean += sample.getValue() * sample.getKey();
}
return mean;
}
/**
* {@inheritDoc}
*
* @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
*/
public double getNumericalVariance() {
double mean = 0;
double meanOfSquares = 0;
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
mean += sample.getValue() * sample.getKey();
meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
}
return meanOfSquares - mean * mean;
}
/**
* {@inheritDoc}
*
* Returns the lowest value with non-zero probability.
*
* @return the lowest value with non-zero probability.
*/
public int getSupportLowerBound() {
int min = Integer.MAX_VALUE;
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
if (sample.getKey() < min && sample.getValue() > 0) {
min = sample.getKey();
}
}
return min;
}
/**
* {@inheritDoc}
*
* Returns the highest value with non-zero probability.
*
* @return the highest value with non-zero probability.
*/
public int getSupportUpperBound() {
int max = Integer.MIN_VALUE;
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
if (sample.getKey() > max && sample.getValue() > 0) {
max = sample.getKey();
}
}
return max;
}
/**
* {@inheritDoc}
*
* The support of this distribution is connected.
*
* @return {@code true}
*/
public boolean isSupportConnected() {
return true;
}
/**
* {@inheritDoc}
*/
@Override
public int sample() {
return innerDistribution.sample();
}
}

View File

@ -0,0 +1,245 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.distribution;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;
/**
* Implementation of the discrete distribution on the reals.
*
* Note: values with zero-probability are allowed but they do not extend the
* support.
*
* @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">Discrete probability distribution (Wikipedia)</a>
* @see <a href="http://mathworld.wolfram.com/DiscreteDistribution.html">Discrete Distribution (MathWorld)</a>
* @version $Id: DiscreteRealDistribution.java 169 2013-03-08 09:02:38Z wydrych $
*/
public class DiscreteRealDistribution extends AbstractRealDistribution {
/** Serializable UID. */
private static final long serialVersionUID = 20130308L;
/**
* {@link DiscreteDistribution} instance (using the {@link Double} wrapper)
* used to generate samples.
*/
protected final DiscreteDistribution<Double> innerDistribution;
/**
* Create a discrete distribution using the given probability mass function
* definition.
*
* @param singletons array of random variable values.
* @param probabilities array of probabilities.
* @throws DimensionMismatchException if
* {@code singletons.length != probabilities.length}
* @throws NotPositiveException if probability of at least one value is
* negative.
* @throws MathArithmeticException if the probabilities sum to zero.
* @throws MathIllegalArgumentException if probability of at least one value
* is infinite.
*/
public DiscreteRealDistribution(final double[] singletons, final double[] probabilities)
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
this(new Well19937c(), singletons, probabilities);
}
/**
* Create a discrete distribution using the given random number generator
* and probability mass function definition.
*
* @param rng random number generator.
* @param singletons array of random variable values.
* @param probabilities array of probabilities.
* @throws DimensionMismatchException if
* {@code singletons.length != probabilities.length}
* @throws NotPositiveException if probability of at least one value is
* negative.
* @throws MathArithmeticException if the probabilities sum to zero.
* @throws MathIllegalArgumentException if probability of at least one value
* is infinite.
*/
public DiscreteRealDistribution(final RandomGenerator rng,
final double[] singletons, final double[] probabilities)
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
super(rng);
if (singletons.length != probabilities.length) {
throw new DimensionMismatchException(probabilities.length, singletons.length);
}
List<Pair<Double, Double>> samples = new ArrayList<Pair<Double, Double>>(singletons.length);
for (int i = 0; i < singletons.length; i++) {
samples.add(new Pair<Double, Double>(singletons[i], probabilities[i]));
}
innerDistribution = new DiscreteDistribution<Double>(rng, samples);
}
/**
* {@inheritDoc}
*/
@Override
public double probability(final double x) {
return innerDistribution.probability(x);
}
/**
* For a random variable {@code X} whose values are distributed according to
* this distribution, this method returns {@code P(X = x)}. In other words,
* this method represents the probability mass function (PMF) for the
* distribution.
*
* @param x the point at which the PMF is evaluated
* @return the value of the probability mass function at point {@code x}
*/
public double density(final double x) {
return probability(x);
}
/**
* {@inheritDoc}
*/
public double cumulativeProbability(final double x) {
double probability = 0;
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
if (sample.getKey() <= x) {
probability += sample.getValue();
}
}
return probability;
}
/**
* {@inheritDoc}
*
* @return {@code sum(singletons[i] * probabilities[i])}
*/
public double getNumericalMean() {
double mean = 0;
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
mean += sample.getValue() * sample.getKey();
}
return mean;
}
/**
* {@inheritDoc}
*
* @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
*/
public double getNumericalVariance() {
double mean = 0;
double meanOfSquares = 0;
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
mean += sample.getValue() * sample.getKey();
meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
}
return meanOfSquares - mean * mean;
}
/**
* {@inheritDoc}
*
* Returns the lowest value with non-zero probability.
*
* @return the lowest value with non-zero probability.
*/
public double getSupportLowerBound() {
double min = Double.POSITIVE_INFINITY;
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
if (sample.getKey() < min && sample.getValue() > 0) {
min = sample.getKey();
}
}
return min;
}
/**
* {@inheritDoc}
*
* Returns the highest value with non-zero probability.
*
* @return the highest value with non-zero probability.
*/
public double getSupportUpperBound() {
double max = Double.NEGATIVE_INFINITY;
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
if (sample.getKey() > max && sample.getValue() > 0) {
max = sample.getKey();
}
}
return max;
}
/**
* {@inheritDoc}
*
* The support of this distribution includes the lower bound.
*
* @return {@code true}
*/
public boolean isSupportLowerBoundInclusive() {
return true;
}
/**
* {@inheritDoc}
*
* The support of this distribution includes the upper bound.
*
* @return {@code true}
*/
public boolean isSupportUpperBoundInclusive() {
return true;
}
/**
* {@inheritDoc}
*
* The support of this distribution is connected.
*
* @return {@code true}
*/
public boolean isSupportConnected() {
return true;
}
/**
* {@inheritDoc}
*/
@Override
public double sample() {
return innerDistribution.sample();
}
}

View File

@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.distribution;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
/**
* Test class for {@link DiscreteIntegerDistribution}.
*
* @version $Id: DiscreteIntegerDistributionTest.java 161 2013-03-07 09:47:32Z wydrych $
*/
public class DiscreteIntegerDistributionTest {
/**
* The distribution object used for testing.
*/
private final DiscreteIntegerDistribution testDistribution;
/**
* Creates the default distribution object uded for testing.
*/
public DiscreteIntegerDistributionTest() {
// Non-sorted singleton array with duplicates should be allowed.
// Values with zero-probability do not extend the support.
testDistribution = new DiscreteIntegerDistribution(
new int[]{3, -1, 3, 7, -2, 8},
new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0});
}
/**
* Tests if the {@link DiscreteIntegerDistribution} constructor throws
* exceptions for ivalid data.
*/
@Test
public void testExceptions() {
DiscreteIntegerDistribution invalid = null;
try {
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0});
Assert.fail("Expected DimensionMismatchException");
} catch (DimensionMismatchException e) {
}
try {
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, -1.0});
Assert.fail("Expected NotPositiveException");
} catch (NotPositiveException e) {
}
try {
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, 0.0});
Assert.fail("Expected MathArithmeticException");
} catch (MathArithmeticException e) {
}
try {
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, Double.NaN});
Assert.fail("Expected MathArithmeticException");
} catch (MathArithmeticException e) {
}
try {
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, Double.POSITIVE_INFINITY});
Assert.fail("Expected MathIllegalArgumentException");
} catch (MathIllegalArgumentException e) {
}
Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid);
}
/**
* Tests if the distribution returns proper probability values.
*/
@Test
public void testProbability() {
int[] points = new int[]{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8};
double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
for (int p = 0; p < points.length; p++) {
double probability = testDistribution.probability(points[p]);
Assert.assertEquals(results[p], probability, 0.0);
}
}
/**
* Tests if the distribution returns proper cumulative probability values.
*/
@Test
public void testCumulativeProbability() {
int[] points = new int[]{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8};
double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0};
for (int p = 0; p < points.length; p++) {
double probability = testDistribution.cumulativeProbability(points[p]);
Assert.assertEquals(results[p], probability, 1e-10);
}
}
/**
* Tests if the distribution returns proper mean value.
*/
@Test
public void testGetNumericalMean() {
Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10);
}
/**
* Tests if the distribution returns proper variance.
*/
@Test
public void testGetNumericalVariance() {
Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10);
}
/**
* Tests if the distribution returns proper lower bound.
*/
@Test
public void testGetSupportLowerBound() {
Assert.assertEquals(-1, testDistribution.getSupportLowerBound());
}
/**
* Tests if the distribution returns proper upper bound.
*/
@Test
public void testGetSupportUpperBound() {
Assert.assertEquals(7, testDistribution.getSupportUpperBound());
}
/**
* Tests if the distribution returns properly that the support is connected.
*/
@Test
public void testIsSupportConnected() {
Assert.assertTrue(testDistribution.isSupportConnected());
}
/**
* Tests sampling.
*/
@Test
public void testSample() {
final int n = 1000000;
testDistribution.reseedRandomGenerator(-334759360); // fixed seed
final int[] samples = testDistribution.sample(n);
Assert.assertEquals(n, samples.length);
double sum = 0;
double sumOfSquares = 0;
for (int i = 0; i < samples.length; i++) {
sum += samples[i];
sumOfSquares += samples[i] * samples[i];
}
Assert.assertEquals(testDistribution.getNumericalMean(),
sum / n, 1e-2);
Assert.assertEquals(testDistribution.getNumericalVariance(),
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
}
}

View File

@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.distribution;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
/**
* Test class for {@link DiscreteRealDistribution}.
*
* @version $Id: DiscreteRealDistributionTest.java 161 2013-03-07 09:47:32Z wydrych $
*/
public class DiscreteRealDistributionTest {
/**
* The distribution object used for testing.
*/
private final DiscreteRealDistribution testDistribution;
/**
* Creates the default distribution object uded for testing.
*/
public DiscreteRealDistributionTest() {
// Non-sorted singleton array with duplicates should be allowed.
// Values with zero-probability do not extend the support.
testDistribution = new DiscreteRealDistribution(
new double[]{3.0, -1.0, 3.0, 7.0, -2.0, 8.0},
new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0});
}
/**
* Tests if the {@link DiscreteRealDistribution} constructor throws
* exceptions for ivalid data.
*/
@Test
public void testExceptions() {
DiscreteRealDistribution invalid = null;
try {
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0});
Assert.fail("Expected DimensionMismatchException");
} catch (DimensionMismatchException e) {
}
try{
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, -1.0});
Assert.fail("Expected NotPositiveException");
} catch (NotPositiveException e) {
}
try {
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, 0.0});
Assert.fail("Expected MathArithmeticException");
} catch (MathArithmeticException e) {
}
try {
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.NaN});
Assert.fail("Expected MathArithmeticException");
} catch (MathArithmeticException e) {
}
try {
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.POSITIVE_INFINITY});
Assert.fail("Expected MathIllegalArgumentException");
} catch (MathIllegalArgumentException e) {
}
Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid);
}
/**
* Tests if the distribution returns proper probability values.
*/
@Test
public void testProbability() {
double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
for (int p = 0; p < points.length; p++) {
double density = testDistribution.probability(points[p]);
Assert.assertEquals(results[p], density, 0.0);
}
}
/**
* Tests if the distribution returns proper density values.
*/
@Test
public void testDensity() {
double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
for (int p = 0; p < points.length; p++) {
double density = testDistribution.density(points[p]);
Assert.assertEquals(results[p], density, 0.0);
}
}
/**
* Tests if the distribution returns proper cumulative probability values.
*/
@Test
public void testCumulativeProbability() {
double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0};
for (int p = 0; p < points.length; p++) {
double probability = testDistribution.cumulativeProbability(points[p]);
Assert.assertEquals(results[p], probability, 1e-10);
}
}
/**
* Tests if the distribution returns proper mean value.
*/
@Test
public void testGetNumericalMean() {
Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10);
}
/**
* Tests if the distribution returns proper variance.
*/
@Test
public void testGetNumericalVariance() {
Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10);
}
/**
* Tests if the distribution returns proper lower bound.
*/
@Test
public void testGetSupportLowerBound() {
Assert.assertEquals(-1, testDistribution.getSupportLowerBound(), 0);
}
/**
* Tests if the distribution returns proper upper bound.
*/
@Test
public void testGetSupportUpperBound() {
Assert.assertEquals(7, testDistribution.getSupportUpperBound(), 0);
}
/**
* Tests if the distribution returns properly that the support includes the
* lower bound.
*/
@Test
public void testIsSupportLowerBoundInclusive() {
Assert.assertTrue(testDistribution.isSupportLowerBoundInclusive());
}
/**
* Tests if the distribution returns properly that the support includes the
* upper bound.
*/
@Test
public void testIsSupportUpperBoundInclusive() {
Assert.assertTrue(testDistribution.isSupportUpperBoundInclusive());
}
/**
* Tests if the distribution returns properly that the support is connected.
*/
@Test
public void testIsSupportConnected() {
Assert.assertTrue(testDistribution.isSupportConnected());
}
/**
* Tests sampling.
*/
@Test
public void testSample() {
final int n = 1000000;
testDistribution.reseedRandomGenerator(-334759360); // fixed seed
final double[] samples = testDistribution.sample(n);
Assert.assertEquals(n, samples.length);
double sum = 0;
double sumOfSquares = 0;
for (int i = 0; i < samples.length; i++) {
sum += samples[i];
sumOfSquares += samples[i] * samples[i];
}
Assert.assertEquals(testDistribution.getNumericalMean(),
sum / n, 1e-2);
Assert.assertEquals(testDistribution.getNumericalVariance(),
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
}
}