From 91ebbb6294771e5a270810dba16ec6dd49f8f9bb Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Fri, 8 Mar 2013 15:59:29 +0000 Subject: [PATCH] Added discrete distributions. Patch contributed by Piotr Wydrych. JIRA: MATH-941 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1454439 13f79535-47bb-0310-9956-ffa450edef68 --- pom.xml | 3 + src/changes/changes.xml | 3 + .../distribution/DiscreteDistribution.java | 197 ++++++++++++++ .../DiscreteIntegerDistribution.java | 209 +++++++++++++++ .../DiscreteRealDistribution.java | 245 ++++++++++++++++++ .../DiscreteIntegerDistributionTest.java | 171 ++++++++++++ .../DiscreteRealDistributionTest.java | 202 +++++++++++++++ 7 files changed, 1030 insertions(+) create mode 100644 src/main/java/org/apache/commons/math3/distribution/DiscreteDistribution.java create mode 100644 src/main/java/org/apache/commons/math3/distribution/DiscreteIntegerDistribution.java create mode 100644 src/main/java/org/apache/commons/math3/distribution/DiscreteRealDistribution.java create mode 100644 src/test/java/org/apache/commons/math3/distribution/DiscreteIntegerDistributionTest.java create mode 100644 src/test/java/org/apache/commons/math3/distribution/DiscreteRealDistributionTest.java diff --git a/pom.xml b/pom.xml index cf660603c..44b242dc6 100644 --- a/pom.xml +++ b/pom.xml @@ -279,6 +279,9 @@ Christian Winter + + Piotr Wydrych + Xiaogang Zhang diff --git a/src/changes/changes.xml b/src/changes/changes.xml index cb8d20ee3..ca334a2e4 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -55,6 +55,9 @@ This is a minor release: It combines bug fixes and new features. Changes to existing features were made in a backwards-compatible way such as to allow drop-in replacement of the v3.1[.1] JAR file. "> + + Added discrete distributions. + Fixed abstract test class naming that broke ant builds. diff --git a/src/main/java/org/apache/commons/math3/distribution/DiscreteDistribution.java b/src/main/java/org/apache/commons/math3/distribution/DiscreteDistribution.java new file mode 100644 index 000000000..8c08dbe36 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/distribution/DiscreteDistribution.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.distribution; + +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.Well19937c; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.Pair; + +/** + * Generic implementation of the discrete distribution. + * + * @param type of the random variable. + * @see Discrete probability distribution (Wikipedia) + * @see Discrete Distribution (MathWorld) + * @version $Id: DiscreteDistribution.java 169 2013-03-08 09:02:38Z wydrych $ + */ +public class DiscreteDistribution { + + /** + * RNG instance used to generate samples from the distribution. + */ + protected final RandomGenerator random; + /** + * List of random variable values. + */ + private final List singletons; + /** + * Normalized array of probabilities of respective random variable values. + */ + private final double[] probabilities; + + /** + * Create a discrete distribution using the given probability mass function + * definition. + * + * @param samples definition of probability mass function in the format of + * list of pairs. + * @throws NotPositiveException if probability of at least one value is + * negative. + * @throws MathArithmeticException if the probabilities sum to zero. + * @throws MathIllegalArgumentException if probability of at least one value + * is infinite. + */ + public DiscreteDistribution(final List> samples) + throws NotPositiveException, MathArithmeticException, MathIllegalArgumentException { + this(new Well19937c(), samples); + } + + /** + * Create a discrete distribution using the given random number generator + * and probability mass function definition. + * + * @param rng random number generator. + * @param samples definition of probability mass function in the format of + * list of pairs. + * @throws NotPositiveException if probability of at least one value is + * negative. + * @throws MathArithmeticException if the probabilities sum to zero. + * @throws MathIllegalArgumentException if probability of at least one value + * is infinite. + */ + public DiscreteDistribution(final RandomGenerator rng, final List> samples) + throws NotPositiveException, MathArithmeticException, MathIllegalArgumentException { + random = rng; + + singletons = new ArrayList(samples.size()); + final double[] probs = new double[samples.size()]; + + for (int i = 0; i < samples.size(); i++) { + final Pair sample = samples.get(i); + singletons.add(sample.getKey()); + if (sample.getValue() < 0) { + throw new NotPositiveException(sample.getValue()); + } + probs[i] = sample.getValue(); + } + + probabilities = MathArrays.normalizeArray(probs, 1.0); + } + + /** + * Reseed the random generator used to generate samples. + * + * @param seed the new seed + */ + public void reseedRandomGenerator(long seed) { + random.setSeed(seed); + } + + /** + * For a random variable {@code X} whose values are distributed according to + * this distribution, this method returns {@code P(X = x)}. In other words, + * this method represents the probability mass function (PMF) for the + * distribution. + * + * @param x the point at which the PMF is evaluated + * @return the value of the probability mass function at {@code x} + */ + double probability(final T x) { + double probability = 0; + + for (int i = 0; i < probabilities.length; i++) { + if ((x == null && singletons.get(i) == null) || + (x != null && x.equals(singletons.get(i)))) { + probability += probabilities[i]; + } + } + + return probability; + } + + /** + * Return the definition of probability mass function in the format of list + * of pairs. + * + * @return definition of probability mass function. + */ + public List> getSamples() { + final List> samples = new ArrayList>(probabilities.length); + + for (int i = 0; i < probabilities.length; i++) { + samples.add(new Pair(singletons.get(i), probabilities[i])); + } + + return samples; + } + + /** + * Generate a random value sampled from this distribution. + * + * @return a random value. + */ + public T sample() { + final double randomValue = random.nextDouble(); + double sum = 0; + + for (int i = 0; i < probabilities.length; i++) { + sum += probabilities[i]; + if (randomValue < sum) { + return singletons.get(i); + } + } + + /* This should never happen, but it ensures we will return a correct + * object in case the loop above has some floating point inequality + * problem on the final iteration. */ + return singletons.get(singletons.size() - 1); + } + + /** + * Generate a random sample from the distribution. + * + * @param sampleSize the number of random values to generate. + * @return an array representing the random sample. + * @throws NotStrictlyPositiveException if {@code sampleSize} is not + * positive. + */ + public T[] sample(int sampleSize) throws NotStrictlyPositiveException { + if (sampleSize <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, + sampleSize); + } + @SuppressWarnings("unchecked") + final T[]out = (T[]) Array.newInstance(singletons.get(0).getClass(), sampleSize); + + for (int i = 0; i < sampleSize; i++) { + out[i] = sample(); + } + + return out; + + } + +} diff --git a/src/main/java/org/apache/commons/math3/distribution/DiscreteIntegerDistribution.java b/src/main/java/org/apache/commons/math3/distribution/DiscreteIntegerDistribution.java new file mode 100644 index 000000000..5e31b0809 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/distribution/DiscreteIntegerDistribution.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.distribution; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.Well19937c; +import org.apache.commons.math3.util.Pair; + +/** + * Implementation of the integer-valued discrete distribution. + * + * Note: values with zero-probability are allowed but they do not extend the + * support. + * + * @see Discrete probability distribution (Wikipedia) + * @see Discrete Distribution (MathWorld) + * @version $Id: DiscreteIntegerDistribution.java 169 2013-03-08 09:02:38Z wydrych $ + */ +public class DiscreteIntegerDistribution extends AbstractIntegerDistribution { + + /** Serializable UID. */ + private static final long serialVersionUID = 20130308L; + + /** + * {@link DiscreteDistribution} instance (using the {@link Integer} wrapper) + * used to generate samples. + */ + protected final DiscreteDistribution innerDistribution; + + /** + * Create a discrete distribution using the given probability mass function + * definition. + * + * @param singletons array of random variable values. + * @param probabilities array of probabilities. + * @throws DimensionMismatchException if + * {@code singletons.length != probabilities.length} + * @throws NotPositiveException if probability of at least one value is + * negative. + * @throws MathArithmeticException if the probabilities sum to zero. + * @throws MathIllegalArgumentException if probability of at least one value + * is infinite. + */ + public DiscreteIntegerDistribution(final int[] singletons, final double[] probabilities) + throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException { + this(new Well19937c(), singletons, probabilities); + } + + /** + * Create a discrete distribution using the given random number generator + * and probability mass function definition. + * + * @param rng random number generator. + * @param singletons array of random variable values. + * @param probabilities array of probabilities. + * @throws DimensionMismatchException if + * {@code singletons.length != probabilities.length} + * @throws NotPositiveException if probability of at least one value is + * negative. + * @throws MathArithmeticException if the probabilities sum to zero. + * @throws MathIllegalArgumentException if probability of at least one value + * is infinite. + */ + public DiscreteIntegerDistribution(final RandomGenerator rng, + final int[] singletons, final double[] probabilities) + throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException { + super(rng); + if (singletons.length != probabilities.length) { + throw new DimensionMismatchException(probabilities.length, singletons.length); + } + + final List> samples = new ArrayList>(singletons.length); + + for (int i = 0; i < singletons.length; i++) { + samples.add(new Pair(singletons[i], probabilities[i])); + } + + innerDistribution = new DiscreteDistribution(rng, samples); + } + + /** + * {@inheritDoc} + */ + public double probability(final int x) { + return innerDistribution.probability(x); + } + + /** + * {@inheritDoc} + */ + public double cumulativeProbability(final int x) { + double probability = 0; + + for (final Pair sample : innerDistribution.getSamples()) { + if (sample.getKey() <= x) { + probability += sample.getValue(); + } + } + + return probability; + } + + /** + * {@inheritDoc} + * + * @return {@code sum(singletons[i] * probabilities[i])} + */ + public double getNumericalMean() { + double mean = 0; + + for (final Pair sample : innerDistribution.getSamples()) { + mean += sample.getValue() * sample.getKey(); + } + + return mean; + } + + /** + * {@inheritDoc} + * + * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])} + */ + public double getNumericalVariance() { + double mean = 0; + double meanOfSquares = 0; + + for (final Pair sample : innerDistribution.getSamples()) { + mean += sample.getValue() * sample.getKey(); + meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey(); + } + + return meanOfSquares - mean * mean; + } + + /** + * {@inheritDoc} + * + * Returns the lowest value with non-zero probability. + * + * @return the lowest value with non-zero probability. + */ + public int getSupportLowerBound() { + int min = Integer.MAX_VALUE; + for (final Pair sample : innerDistribution.getSamples()) { + if (sample.getKey() < min && sample.getValue() > 0) { + min = sample.getKey(); + } + } + + return min; + } + + /** + * {@inheritDoc} + * + * Returns the highest value with non-zero probability. + * + * @return the highest value with non-zero probability. + */ + public int getSupportUpperBound() { + int max = Integer.MIN_VALUE; + for (final Pair sample : innerDistribution.getSamples()) { + if (sample.getKey() > max && sample.getValue() > 0) { + max = sample.getKey(); + } + } + + return max; + } + + /** + * {@inheritDoc} + * + * The support of this distribution is connected. + * + * @return {@code true} + */ + public boolean isSupportConnected() { + return true; + } + + /** + * {@inheritDoc} + */ + @Override + public int sample() { + return innerDistribution.sample(); + } +} diff --git a/src/main/java/org/apache/commons/math3/distribution/DiscreteRealDistribution.java b/src/main/java/org/apache/commons/math3/distribution/DiscreteRealDistribution.java new file mode 100644 index 000000000..a9f046f90 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/distribution/DiscreteRealDistribution.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.distribution; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.Well19937c; +import org.apache.commons.math3.util.Pair; + +/** + * Implementation of the discrete distribution on the reals. + * + * Note: values with zero-probability are allowed but they do not extend the + * support. + * + * @see Discrete probability distribution (Wikipedia) + * @see Discrete Distribution (MathWorld) + * @version $Id: DiscreteRealDistribution.java 169 2013-03-08 09:02:38Z wydrych $ + */ +public class DiscreteRealDistribution extends AbstractRealDistribution { + + /** Serializable UID. */ + private static final long serialVersionUID = 20130308L; + + /** + * {@link DiscreteDistribution} instance (using the {@link Double} wrapper) + * used to generate samples. + */ + protected final DiscreteDistribution innerDistribution; + + /** + * Create a discrete distribution using the given probability mass function + * definition. + * + * @param singletons array of random variable values. + * @param probabilities array of probabilities. + * @throws DimensionMismatchException if + * {@code singletons.length != probabilities.length} + * @throws NotPositiveException if probability of at least one value is + * negative. + * @throws MathArithmeticException if the probabilities sum to zero. + * @throws MathIllegalArgumentException if probability of at least one value + * is infinite. + */ + public DiscreteRealDistribution(final double[] singletons, final double[] probabilities) + throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException { + this(new Well19937c(), singletons, probabilities); + } + + /** + * Create a discrete distribution using the given random number generator + * and probability mass function definition. + * + * @param rng random number generator. + * @param singletons array of random variable values. + * @param probabilities array of probabilities. + * @throws DimensionMismatchException if + * {@code singletons.length != probabilities.length} + * @throws NotPositiveException if probability of at least one value is + * negative. + * @throws MathArithmeticException if the probabilities sum to zero. + * @throws MathIllegalArgumentException if probability of at least one value + * is infinite. + */ + public DiscreteRealDistribution(final RandomGenerator rng, + final double[] singletons, final double[] probabilities) + throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException { + super(rng); + if (singletons.length != probabilities.length) { + throw new DimensionMismatchException(probabilities.length, singletons.length); + } + + List> samples = new ArrayList>(singletons.length); + + for (int i = 0; i < singletons.length; i++) { + samples.add(new Pair(singletons[i], probabilities[i])); + } + + innerDistribution = new DiscreteDistribution(rng, samples); + } + + /** + * {@inheritDoc} + */ + @Override + public double probability(final double x) { + return innerDistribution.probability(x); + } + + /** + * For a random variable {@code X} whose values are distributed according to + * this distribution, this method returns {@code P(X = x)}. In other words, + * this method represents the probability mass function (PMF) for the + * distribution. + * + * @param x the point at which the PMF is evaluated + * @return the value of the probability mass function at point {@code x} + */ + public double density(final double x) { + return probability(x); + } + + /** + * {@inheritDoc} + */ + public double cumulativeProbability(final double x) { + double probability = 0; + + for (final Pair sample : innerDistribution.getSamples()) { + if (sample.getKey() <= x) { + probability += sample.getValue(); + } + } + + return probability; + } + + /** + * {@inheritDoc} + * + * @return {@code sum(singletons[i] * probabilities[i])} + */ + public double getNumericalMean() { + double mean = 0; + + for (final Pair sample : innerDistribution.getSamples()) { + mean += sample.getValue() * sample.getKey(); + } + + return mean; + } + + /** + * {@inheritDoc} + * + * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])} + */ + public double getNumericalVariance() { + double mean = 0; + double meanOfSquares = 0; + + for (final Pair sample : innerDistribution.getSamples()) { + mean += sample.getValue() * sample.getKey(); + meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey(); + } + + return meanOfSquares - mean * mean; + } + + /** + * {@inheritDoc} + * + * Returns the lowest value with non-zero probability. + * + * @return the lowest value with non-zero probability. + */ + public double getSupportLowerBound() { + double min = Double.POSITIVE_INFINITY; + for (final Pair sample : innerDistribution.getSamples()) { + if (sample.getKey() < min && sample.getValue() > 0) { + min = sample.getKey(); + } + } + + return min; + } + + /** + * {@inheritDoc} + * + * Returns the highest value with non-zero probability. + * + * @return the highest value with non-zero probability. + */ + public double getSupportUpperBound() { + double max = Double.NEGATIVE_INFINITY; + for (final Pair sample : innerDistribution.getSamples()) { + if (sample.getKey() > max && sample.getValue() > 0) { + max = sample.getKey(); + } + } + + return max; + } + + /** + * {@inheritDoc} + * + * The support of this distribution includes the lower bound. + * + * @return {@code true} + */ + public boolean isSupportLowerBoundInclusive() { + return true; + } + + /** + * {@inheritDoc} + * + * The support of this distribution includes the upper bound. + * + * @return {@code true} + */ + public boolean isSupportUpperBoundInclusive() { + return true; + } + + /** + * {@inheritDoc} + * + * The support of this distribution is connected. + * + * @return {@code true} + */ + public boolean isSupportConnected() { + return true; + } + + /** + * {@inheritDoc} + */ + @Override + public double sample() { + return innerDistribution.sample(); + } +} diff --git a/src/test/java/org/apache/commons/math3/distribution/DiscreteIntegerDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/DiscreteIntegerDistributionTest.java new file mode 100644 index 000000000..028486ffe --- /dev/null +++ b/src/test/java/org/apache/commons/math3/distribution/DiscreteIntegerDistributionTest.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.distribution; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.util.FastMath; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test class for {@link DiscreteIntegerDistribution}. + * + * @version $Id: DiscreteIntegerDistributionTest.java 161 2013-03-07 09:47:32Z wydrych $ + */ +public class DiscreteIntegerDistributionTest { + + /** + * The distribution object used for testing. + */ + private final DiscreteIntegerDistribution testDistribution; + + /** + * Creates the default distribution object uded for testing. + */ + public DiscreteIntegerDistributionTest() { + // Non-sorted singleton array with duplicates should be allowed. + // Values with zero-probability do not extend the support. + testDistribution = new DiscreteIntegerDistribution( + new int[]{3, -1, 3, 7, -2, 8}, + new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0}); + } + + /** + * Tests if the {@link DiscreteIntegerDistribution} constructor throws + * exceptions for ivalid data. + */ + @Test + public void testExceptions() { + DiscreteIntegerDistribution invalid = null; + try { + invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0}); + Assert.fail("Expected DimensionMismatchException"); + } catch (DimensionMismatchException e) { + } + try { + invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, -1.0}); + Assert.fail("Expected NotPositiveException"); + } catch (NotPositiveException e) { + } + try { + invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, 0.0}); + Assert.fail("Expected MathArithmeticException"); + } catch (MathArithmeticException e) { + } + try { + invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, Double.NaN}); + Assert.fail("Expected MathArithmeticException"); + } catch (MathArithmeticException e) { + } + try { + invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, Double.POSITIVE_INFINITY}); + Assert.fail("Expected MathIllegalArgumentException"); + } catch (MathIllegalArgumentException e) { + } + Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid); + } + + /** + * Tests if the distribution returns proper probability values. + */ + @Test + public void testProbability() { + int[] points = new int[]{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0}; + for (int p = 0; p < points.length; p++) { + double probability = testDistribution.probability(points[p]); + Assert.assertEquals(results[p], probability, 0.0); + } + } + + /** + * Tests if the distribution returns proper cumulative probability values. + */ + @Test + public void testCumulativeProbability() { + int[] points = new int[]{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0}; + for (int p = 0; p < points.length; p++) { + double probability = testDistribution.cumulativeProbability(points[p]); + Assert.assertEquals(results[p], probability, 1e-10); + } + } + + /** + * Tests if the distribution returns proper mean value. + */ + @Test + public void testGetNumericalMean() { + Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10); + } + + /** + * Tests if the distribution returns proper variance. + */ + @Test + public void testGetNumericalVariance() { + Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10); + } + + /** + * Tests if the distribution returns proper lower bound. + */ + @Test + public void testGetSupportLowerBound() { + Assert.assertEquals(-1, testDistribution.getSupportLowerBound()); + } + + /** + * Tests if the distribution returns proper upper bound. + */ + @Test + public void testGetSupportUpperBound() { + Assert.assertEquals(7, testDistribution.getSupportUpperBound()); + } + + /** + * Tests if the distribution returns properly that the support is connected. + */ + @Test + public void testIsSupportConnected() { + Assert.assertTrue(testDistribution.isSupportConnected()); + } + + /** + * Tests sampling. + */ + @Test + public void testSample() { + final int n = 1000000; + testDistribution.reseedRandomGenerator(-334759360); // fixed seed + final int[] samples = testDistribution.sample(n); + Assert.assertEquals(n, samples.length); + double sum = 0; + double sumOfSquares = 0; + for (int i = 0; i < samples.length; i++) { + sum += samples[i]; + sumOfSquares += samples[i] * samples[i]; + } + Assert.assertEquals(testDistribution.getNumericalMean(), + sum / n, 1e-2); + Assert.assertEquals(testDistribution.getNumericalVariance(), + sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2); + } +} diff --git a/src/test/java/org/apache/commons/math3/distribution/DiscreteRealDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/DiscreteRealDistributionTest.java new file mode 100644 index 000000000..1a7ef5302 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/distribution/DiscreteRealDistributionTest.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.distribution; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.util.FastMath; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test class for {@link DiscreteRealDistribution}. + * + * @version $Id: DiscreteRealDistributionTest.java 161 2013-03-07 09:47:32Z wydrych $ + */ +public class DiscreteRealDistributionTest { + + /** + * The distribution object used for testing. + */ + private final DiscreteRealDistribution testDistribution; + + /** + * Creates the default distribution object uded for testing. + */ + public DiscreteRealDistributionTest() { + // Non-sorted singleton array with duplicates should be allowed. + // Values with zero-probability do not extend the support. + testDistribution = new DiscreteRealDistribution( + new double[]{3.0, -1.0, 3.0, 7.0, -2.0, 8.0}, + new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0}); + } + + /** + * Tests if the {@link DiscreteRealDistribution} constructor throws + * exceptions for ivalid data. + */ + @Test + public void testExceptions() { + DiscreteRealDistribution invalid = null; + try { + invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0}); + Assert.fail("Expected DimensionMismatchException"); + } catch (DimensionMismatchException e) { + } + try{ + invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, -1.0}); + Assert.fail("Expected NotPositiveException"); + } catch (NotPositiveException e) { + } + try { + invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, 0.0}); + Assert.fail("Expected MathArithmeticException"); + } catch (MathArithmeticException e) { + } + try { + invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.NaN}); + Assert.fail("Expected MathArithmeticException"); + } catch (MathArithmeticException e) { + } + try { + invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.POSITIVE_INFINITY}); + Assert.fail("Expected MathIllegalArgumentException"); + } catch (MathIllegalArgumentException e) { + } + Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid); + } + + /** + * Tests if the distribution returns proper probability values. + */ + @Test + public void testProbability() { + double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0}; + for (int p = 0; p < points.length; p++) { + double density = testDistribution.probability(points[p]); + Assert.assertEquals(results[p], density, 0.0); + } + } + + /** + * Tests if the distribution returns proper density values. + */ + @Test + public void testDensity() { + double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0}; + for (int p = 0; p < points.length; p++) { + double density = testDistribution.density(points[p]); + Assert.assertEquals(results[p], density, 0.0); + } + } + + /** + * Tests if the distribution returns proper cumulative probability values. + */ + @Test + public void testCumulativeProbability() { + double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0}; + for (int p = 0; p < points.length; p++) { + double probability = testDistribution.cumulativeProbability(points[p]); + Assert.assertEquals(results[p], probability, 1e-10); + } + } + + /** + * Tests if the distribution returns proper mean value. + */ + @Test + public void testGetNumericalMean() { + Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10); + } + + /** + * Tests if the distribution returns proper variance. + */ + @Test + public void testGetNumericalVariance() { + Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10); + } + + /** + * Tests if the distribution returns proper lower bound. + */ + @Test + public void testGetSupportLowerBound() { + Assert.assertEquals(-1, testDistribution.getSupportLowerBound(), 0); + } + + /** + * Tests if the distribution returns proper upper bound. + */ + @Test + public void testGetSupportUpperBound() { + Assert.assertEquals(7, testDistribution.getSupportUpperBound(), 0); + } + + /** + * Tests if the distribution returns properly that the support includes the + * lower bound. + */ + @Test + public void testIsSupportLowerBoundInclusive() { + Assert.assertTrue(testDistribution.isSupportLowerBoundInclusive()); + } + + /** + * Tests if the distribution returns properly that the support includes the + * upper bound. + */ + @Test + public void testIsSupportUpperBoundInclusive() { + Assert.assertTrue(testDistribution.isSupportUpperBoundInclusive()); + } + + /** + * Tests if the distribution returns properly that the support is connected. + */ + @Test + public void testIsSupportConnected() { + Assert.assertTrue(testDistribution.isSupportConnected()); + } + + /** + * Tests sampling. + */ + @Test + public void testSample() { + final int n = 1000000; + testDistribution.reseedRandomGenerator(-334759360); // fixed seed + final double[] samples = testDistribution.sample(n); + Assert.assertEquals(n, samples.length); + double sum = 0; + double sumOfSquares = 0; + for (int i = 0; i < samples.length; i++) { + sum += samples[i]; + sumOfSquares += samples[i] * samples[i]; + } + Assert.assertEquals(testDistribution.getNumericalMean(), + sum / n, 1e-2); + Assert.assertEquals(testDistribution.getNumericalVariance(), + sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2); + } +}