Added discrete distributions.
Patch contributed by Piotr Wydrych. JIRA: MATH-941 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1454439 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
e6a5e6dc9c
commit
91ebbb6294
3
pom.xml
3
pom.xml
|
@ -279,6 +279,9 @@
|
|||
<contributor>
|
||||
<name>Christian Winter</name>
|
||||
</contributor>
|
||||
<contributor>
|
||||
<name>Piotr Wydrych</name>
|
||||
</contributor>
|
||||
<contributor>
|
||||
<name>Xiaogang Zhang</name>
|
||||
</contributor>
|
||||
|
|
|
@ -55,6 +55,9 @@ This is a minor release: It combines bug fixes and new features.
|
|||
Changes to existing features were made in a backwards-compatible
|
||||
way such as to allow drop-in replacement of the v3.1[.1] JAR file.
|
||||
">
|
||||
<action dev="luc" type="add" issue="MATH-941" due-to="Piotr Wydrych" >
|
||||
Added discrete distributions.
|
||||
</action>
|
||||
<action dev="luc" type="fix" issue="MATH-940" due-to="Piotr Wydrych" >
|
||||
Fixed abstract test class naming that broke ant builds.
|
||||
</action>
|
||||
|
|
|
@ -0,0 +1,197 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.distribution;
|
||||
|
||||
import java.lang.reflect.Array;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math3.exception.NotPositiveException;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.Well19937c;
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* Generic implementation of the discrete distribution.
|
||||
*
|
||||
* @param <T> type of the random variable.
|
||||
* @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">Discrete probability distribution (Wikipedia)</a>
|
||||
* @see <a href="http://mathworld.wolfram.com/DiscreteDistribution.html">Discrete Distribution (MathWorld)</a>
|
||||
* @version $Id: DiscreteDistribution.java 169 2013-03-08 09:02:38Z wydrych $
|
||||
*/
|
||||
public class DiscreteDistribution<T> {
|
||||
|
||||
/**
|
||||
* RNG instance used to generate samples from the distribution.
|
||||
*/
|
||||
protected final RandomGenerator random;
|
||||
/**
|
||||
* List of random variable values.
|
||||
*/
|
||||
private final List<T> singletons;
|
||||
/**
|
||||
* Normalized array of probabilities of respective random variable values.
|
||||
*/
|
||||
private final double[] probabilities;
|
||||
|
||||
/**
|
||||
* Create a discrete distribution using the given probability mass function
|
||||
* definition.
|
||||
*
|
||||
* @param samples definition of probability mass function in the format of
|
||||
* list of pairs.
|
||||
* @throws NotPositiveException if probability of at least one value is
|
||||
* negative.
|
||||
* @throws MathArithmeticException if the probabilities sum to zero.
|
||||
* @throws MathIllegalArgumentException if probability of at least one value
|
||||
* is infinite.
|
||||
*/
|
||||
public DiscreteDistribution(final List<Pair<T, Double>> samples)
|
||||
throws NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
|
||||
this(new Well19937c(), samples);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a discrete distribution using the given random number generator
|
||||
* and probability mass function definition.
|
||||
*
|
||||
* @param rng random number generator.
|
||||
* @param samples definition of probability mass function in the format of
|
||||
* list of pairs.
|
||||
* @throws NotPositiveException if probability of at least one value is
|
||||
* negative.
|
||||
* @throws MathArithmeticException if the probabilities sum to zero.
|
||||
* @throws MathIllegalArgumentException if probability of at least one value
|
||||
* is infinite.
|
||||
*/
|
||||
public DiscreteDistribution(final RandomGenerator rng, final List<Pair<T, Double>> samples)
|
||||
throws NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
|
||||
random = rng;
|
||||
|
||||
singletons = new ArrayList<T>(samples.size());
|
||||
final double[] probs = new double[samples.size()];
|
||||
|
||||
for (int i = 0; i < samples.size(); i++) {
|
||||
final Pair<T, Double> sample = samples.get(i);
|
||||
singletons.add(sample.getKey());
|
||||
if (sample.getValue() < 0) {
|
||||
throw new NotPositiveException(sample.getValue());
|
||||
}
|
||||
probs[i] = sample.getValue();
|
||||
}
|
||||
|
||||
probabilities = MathArrays.normalizeArray(probs, 1.0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reseed the random generator used to generate samples.
|
||||
*
|
||||
* @param seed the new seed
|
||||
*/
|
||||
public void reseedRandomGenerator(long seed) {
|
||||
random.setSeed(seed);
|
||||
}
|
||||
|
||||
/**
|
||||
* For a random variable {@code X} whose values are distributed according to
|
||||
* this distribution, this method returns {@code P(X = x)}. In other words,
|
||||
* this method represents the probability mass function (PMF) for the
|
||||
* distribution.
|
||||
*
|
||||
* @param x the point at which the PMF is evaluated
|
||||
* @return the value of the probability mass function at {@code x}
|
||||
*/
|
||||
double probability(final T x) {
|
||||
double probability = 0;
|
||||
|
||||
for (int i = 0; i < probabilities.length; i++) {
|
||||
if ((x == null && singletons.get(i) == null) ||
|
||||
(x != null && x.equals(singletons.get(i)))) {
|
||||
probability += probabilities[i];
|
||||
}
|
||||
}
|
||||
|
||||
return probability;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the definition of probability mass function in the format of list
|
||||
* of pairs.
|
||||
*
|
||||
* @return definition of probability mass function.
|
||||
*/
|
||||
public List<Pair<T, Double>> getSamples() {
|
||||
final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length);
|
||||
|
||||
for (int i = 0; i < probabilities.length; i++) {
|
||||
samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i]));
|
||||
}
|
||||
|
||||
return samples;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a random value sampled from this distribution.
|
||||
*
|
||||
* @return a random value.
|
||||
*/
|
||||
public T sample() {
|
||||
final double randomValue = random.nextDouble();
|
||||
double sum = 0;
|
||||
|
||||
for (int i = 0; i < probabilities.length; i++) {
|
||||
sum += probabilities[i];
|
||||
if (randomValue < sum) {
|
||||
return singletons.get(i);
|
||||
}
|
||||
}
|
||||
|
||||
/* This should never happen, but it ensures we will return a correct
|
||||
* object in case the loop above has some floating point inequality
|
||||
* problem on the final iteration. */
|
||||
return singletons.get(singletons.size() - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a random sample from the distribution.
|
||||
*
|
||||
* @param sampleSize the number of random values to generate.
|
||||
* @return an array representing the random sample.
|
||||
* @throws NotStrictlyPositiveException if {@code sampleSize} is not
|
||||
* positive.
|
||||
*/
|
||||
public T[] sample(int sampleSize) throws NotStrictlyPositiveException {
|
||||
if (sampleSize <= 0) {
|
||||
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
|
||||
sampleSize);
|
||||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
final T[]out = (T[]) Array.newInstance(singletons.get(0).getClass(), sampleSize);
|
||||
|
||||
for (int i = 0; i < sampleSize; i++) {
|
||||
out[i] = sample();
|
||||
}
|
||||
|
||||
return out;
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.distribution;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math3.exception.NotPositiveException;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.Well19937c;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* Implementation of the integer-valued discrete distribution.
|
||||
*
|
||||
* Note: values with zero-probability are allowed but they do not extend the
|
||||
* support.
|
||||
*
|
||||
* @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">Discrete probability distribution (Wikipedia)</a>
|
||||
* @see <a href="http://mathworld.wolfram.com/DiscreteDistribution.html">Discrete Distribution (MathWorld)</a>
|
||||
* @version $Id: DiscreteIntegerDistribution.java 169 2013-03-08 09:02:38Z wydrych $
|
||||
*/
|
||||
public class DiscreteIntegerDistribution extends AbstractIntegerDistribution {
|
||||
|
||||
/** Serializable UID. */
|
||||
private static final long serialVersionUID = 20130308L;
|
||||
|
||||
/**
|
||||
* {@link DiscreteDistribution} instance (using the {@link Integer} wrapper)
|
||||
* used to generate samples.
|
||||
*/
|
||||
protected final DiscreteDistribution<Integer> innerDistribution;
|
||||
|
||||
/**
|
||||
* Create a discrete distribution using the given probability mass function
|
||||
* definition.
|
||||
*
|
||||
* @param singletons array of random variable values.
|
||||
* @param probabilities array of probabilities.
|
||||
* @throws DimensionMismatchException if
|
||||
* {@code singletons.length != probabilities.length}
|
||||
* @throws NotPositiveException if probability of at least one value is
|
||||
* negative.
|
||||
* @throws MathArithmeticException if the probabilities sum to zero.
|
||||
* @throws MathIllegalArgumentException if probability of at least one value
|
||||
* is infinite.
|
||||
*/
|
||||
public DiscreteIntegerDistribution(final int[] singletons, final double[] probabilities)
|
||||
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
|
||||
this(new Well19937c(), singletons, probabilities);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a discrete distribution using the given random number generator
|
||||
* and probability mass function definition.
|
||||
*
|
||||
* @param rng random number generator.
|
||||
* @param singletons array of random variable values.
|
||||
* @param probabilities array of probabilities.
|
||||
* @throws DimensionMismatchException if
|
||||
* {@code singletons.length != probabilities.length}
|
||||
* @throws NotPositiveException if probability of at least one value is
|
||||
* negative.
|
||||
* @throws MathArithmeticException if the probabilities sum to zero.
|
||||
* @throws MathIllegalArgumentException if probability of at least one value
|
||||
* is infinite.
|
||||
*/
|
||||
public DiscreteIntegerDistribution(final RandomGenerator rng,
|
||||
final int[] singletons, final double[] probabilities)
|
||||
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
|
||||
super(rng);
|
||||
if (singletons.length != probabilities.length) {
|
||||
throw new DimensionMismatchException(probabilities.length, singletons.length);
|
||||
}
|
||||
|
||||
final List<Pair<Integer, Double>> samples = new ArrayList<Pair<Integer, Double>>(singletons.length);
|
||||
|
||||
for (int i = 0; i < singletons.length; i++) {
|
||||
samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i]));
|
||||
}
|
||||
|
||||
innerDistribution = new DiscreteDistribution<Integer>(rng, samples);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
public double probability(final int x) {
|
||||
return innerDistribution.probability(x);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
public double cumulativeProbability(final int x) {
|
||||
double probability = 0;
|
||||
|
||||
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
|
||||
if (sample.getKey() <= x) {
|
||||
probability += sample.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
return probability;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return {@code sum(singletons[i] * probabilities[i])}
|
||||
*/
|
||||
public double getNumericalMean() {
|
||||
double mean = 0;
|
||||
|
||||
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
|
||||
mean += sample.getValue() * sample.getKey();
|
||||
}
|
||||
|
||||
return mean;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
|
||||
*/
|
||||
public double getNumericalVariance() {
|
||||
double mean = 0;
|
||||
double meanOfSquares = 0;
|
||||
|
||||
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
|
||||
mean += sample.getValue() * sample.getKey();
|
||||
meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
|
||||
}
|
||||
|
||||
return meanOfSquares - mean * mean;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* Returns the lowest value with non-zero probability.
|
||||
*
|
||||
* @return the lowest value with non-zero probability.
|
||||
*/
|
||||
public int getSupportLowerBound() {
|
||||
int min = Integer.MAX_VALUE;
|
||||
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
|
||||
if (sample.getKey() < min && sample.getValue() > 0) {
|
||||
min = sample.getKey();
|
||||
}
|
||||
}
|
||||
|
||||
return min;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* Returns the highest value with non-zero probability.
|
||||
*
|
||||
* @return the highest value with non-zero probability.
|
||||
*/
|
||||
public int getSupportUpperBound() {
|
||||
int max = Integer.MIN_VALUE;
|
||||
for (final Pair<Integer, Double> sample : innerDistribution.getSamples()) {
|
||||
if (sample.getKey() > max && sample.getValue() > 0) {
|
||||
max = sample.getKey();
|
||||
}
|
||||
}
|
||||
|
||||
return max;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* The support of this distribution is connected.
|
||||
*
|
||||
* @return {@code true}
|
||||
*/
|
||||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public int sample() {
|
||||
return innerDistribution.sample();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,245 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.distribution;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math3.exception.NotPositiveException;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.Well19937c;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* Implementation of the discrete distribution on the reals.
|
||||
*
|
||||
* Note: values with zero-probability are allowed but they do not extend the
|
||||
* support.
|
||||
*
|
||||
* @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">Discrete probability distribution (Wikipedia)</a>
|
||||
* @see <a href="http://mathworld.wolfram.com/DiscreteDistribution.html">Discrete Distribution (MathWorld)</a>
|
||||
* @version $Id: DiscreteRealDistribution.java 169 2013-03-08 09:02:38Z wydrych $
|
||||
*/
|
||||
public class DiscreteRealDistribution extends AbstractRealDistribution {
|
||||
|
||||
/** Serializable UID. */
|
||||
private static final long serialVersionUID = 20130308L;
|
||||
|
||||
/**
|
||||
* {@link DiscreteDistribution} instance (using the {@link Double} wrapper)
|
||||
* used to generate samples.
|
||||
*/
|
||||
protected final DiscreteDistribution<Double> innerDistribution;
|
||||
|
||||
/**
|
||||
* Create a discrete distribution using the given probability mass function
|
||||
* definition.
|
||||
*
|
||||
* @param singletons array of random variable values.
|
||||
* @param probabilities array of probabilities.
|
||||
* @throws DimensionMismatchException if
|
||||
* {@code singletons.length != probabilities.length}
|
||||
* @throws NotPositiveException if probability of at least one value is
|
||||
* negative.
|
||||
* @throws MathArithmeticException if the probabilities sum to zero.
|
||||
* @throws MathIllegalArgumentException if probability of at least one value
|
||||
* is infinite.
|
||||
*/
|
||||
public DiscreteRealDistribution(final double[] singletons, final double[] probabilities)
|
||||
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
|
||||
this(new Well19937c(), singletons, probabilities);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a discrete distribution using the given random number generator
|
||||
* and probability mass function definition.
|
||||
*
|
||||
* @param rng random number generator.
|
||||
* @param singletons array of random variable values.
|
||||
* @param probabilities array of probabilities.
|
||||
* @throws DimensionMismatchException if
|
||||
* {@code singletons.length != probabilities.length}
|
||||
* @throws NotPositiveException if probability of at least one value is
|
||||
* negative.
|
||||
* @throws MathArithmeticException if the probabilities sum to zero.
|
||||
* @throws MathIllegalArgumentException if probability of at least one value
|
||||
* is infinite.
|
||||
*/
|
||||
public DiscreteRealDistribution(final RandomGenerator rng,
|
||||
final double[] singletons, final double[] probabilities)
|
||||
throws DimensionMismatchException, NotPositiveException, MathArithmeticException, MathIllegalArgumentException {
|
||||
super(rng);
|
||||
if (singletons.length != probabilities.length) {
|
||||
throw new DimensionMismatchException(probabilities.length, singletons.length);
|
||||
}
|
||||
|
||||
List<Pair<Double, Double>> samples = new ArrayList<Pair<Double, Double>>(singletons.length);
|
||||
|
||||
for (int i = 0; i < singletons.length; i++) {
|
||||
samples.add(new Pair<Double, Double>(singletons[i], probabilities[i]));
|
||||
}
|
||||
|
||||
innerDistribution = new DiscreteDistribution<Double>(rng, samples);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public double probability(final double x) {
|
||||
return innerDistribution.probability(x);
|
||||
}
|
||||
|
||||
/**
|
||||
* For a random variable {@code X} whose values are distributed according to
|
||||
* this distribution, this method returns {@code P(X = x)}. In other words,
|
||||
* this method represents the probability mass function (PMF) for the
|
||||
* distribution.
|
||||
*
|
||||
* @param x the point at which the PMF is evaluated
|
||||
* @return the value of the probability mass function at point {@code x}
|
||||
*/
|
||||
public double density(final double x) {
|
||||
return probability(x);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
public double cumulativeProbability(final double x) {
|
||||
double probability = 0;
|
||||
|
||||
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
|
||||
if (sample.getKey() <= x) {
|
||||
probability += sample.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
return probability;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return {@code sum(singletons[i] * probabilities[i])}
|
||||
*/
|
||||
public double getNumericalMean() {
|
||||
double mean = 0;
|
||||
|
||||
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
|
||||
mean += sample.getValue() * sample.getKey();
|
||||
}
|
||||
|
||||
return mean;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
|
||||
*/
|
||||
public double getNumericalVariance() {
|
||||
double mean = 0;
|
||||
double meanOfSquares = 0;
|
||||
|
||||
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
|
||||
mean += sample.getValue() * sample.getKey();
|
||||
meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
|
||||
}
|
||||
|
||||
return meanOfSquares - mean * mean;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* Returns the lowest value with non-zero probability.
|
||||
*
|
||||
* @return the lowest value with non-zero probability.
|
||||
*/
|
||||
public double getSupportLowerBound() {
|
||||
double min = Double.POSITIVE_INFINITY;
|
||||
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
|
||||
if (sample.getKey() < min && sample.getValue() > 0) {
|
||||
min = sample.getKey();
|
||||
}
|
||||
}
|
||||
|
||||
return min;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* Returns the highest value with non-zero probability.
|
||||
*
|
||||
* @return the highest value with non-zero probability.
|
||||
*/
|
||||
public double getSupportUpperBound() {
|
||||
double max = Double.NEGATIVE_INFINITY;
|
||||
for (final Pair<Double, Double> sample : innerDistribution.getSamples()) {
|
||||
if (sample.getKey() > max && sample.getValue() > 0) {
|
||||
max = sample.getKey();
|
||||
}
|
||||
}
|
||||
|
||||
return max;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* The support of this distribution includes the lower bound.
|
||||
*
|
||||
* @return {@code true}
|
||||
*/
|
||||
public boolean isSupportLowerBoundInclusive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* The support of this distribution includes the upper bound.
|
||||
*
|
||||
* @return {@code true}
|
||||
*/
|
||||
public boolean isSupportUpperBoundInclusive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* The support of this distribution is connected.
|
||||
*
|
||||
* @return {@code true}
|
||||
*/
|
||||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public double sample() {
|
||||
return innerDistribution.sample();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.distribution;
|
||||
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math3.exception.NotPositiveException;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Test class for {@link DiscreteIntegerDistribution}.
|
||||
*
|
||||
* @version $Id: DiscreteIntegerDistributionTest.java 161 2013-03-07 09:47:32Z wydrych $
|
||||
*/
|
||||
public class DiscreteIntegerDistributionTest {
|
||||
|
||||
/**
|
||||
* The distribution object used for testing.
|
||||
*/
|
||||
private final DiscreteIntegerDistribution testDistribution;
|
||||
|
||||
/**
|
||||
* Creates the default distribution object uded for testing.
|
||||
*/
|
||||
public DiscreteIntegerDistributionTest() {
|
||||
// Non-sorted singleton array with duplicates should be allowed.
|
||||
// Values with zero-probability do not extend the support.
|
||||
testDistribution = new DiscreteIntegerDistribution(
|
||||
new int[]{3, -1, 3, 7, -2, 8},
|
||||
new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0});
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the {@link DiscreteIntegerDistribution} constructor throws
|
||||
* exceptions for ivalid data.
|
||||
*/
|
||||
@Test
|
||||
public void testExceptions() {
|
||||
DiscreteIntegerDistribution invalid = null;
|
||||
try {
|
||||
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0});
|
||||
Assert.fail("Expected DimensionMismatchException");
|
||||
} catch (DimensionMismatchException e) {
|
||||
}
|
||||
try {
|
||||
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, -1.0});
|
||||
Assert.fail("Expected NotPositiveException");
|
||||
} catch (NotPositiveException e) {
|
||||
}
|
||||
try {
|
||||
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, 0.0});
|
||||
Assert.fail("Expected MathArithmeticException");
|
||||
} catch (MathArithmeticException e) {
|
||||
}
|
||||
try {
|
||||
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, Double.NaN});
|
||||
Assert.fail("Expected MathArithmeticException");
|
||||
} catch (MathArithmeticException e) {
|
||||
}
|
||||
try {
|
||||
invalid = new DiscreteIntegerDistribution(new int[]{1, 2}, new double[]{0.0, Double.POSITIVE_INFINITY});
|
||||
Assert.fail("Expected MathIllegalArgumentException");
|
||||
} catch (MathIllegalArgumentException e) {
|
||||
}
|
||||
Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper probability values.
|
||||
*/
|
||||
@Test
|
||||
public void testProbability() {
|
||||
int[] points = new int[]{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
|
||||
for (int p = 0; p < points.length; p++) {
|
||||
double probability = testDistribution.probability(points[p]);
|
||||
Assert.assertEquals(results[p], probability, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper cumulative probability values.
|
||||
*/
|
||||
@Test
|
||||
public void testCumulativeProbability() {
|
||||
int[] points = new int[]{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0};
|
||||
for (int p = 0; p < points.length; p++) {
|
||||
double probability = testDistribution.cumulativeProbability(points[p]);
|
||||
Assert.assertEquals(results[p], probability, 1e-10);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper mean value.
|
||||
*/
|
||||
@Test
|
||||
public void testGetNumericalMean() {
|
||||
Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper variance.
|
||||
*/
|
||||
@Test
|
||||
public void testGetNumericalVariance() {
|
||||
Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper lower bound.
|
||||
*/
|
||||
@Test
|
||||
public void testGetSupportLowerBound() {
|
||||
Assert.assertEquals(-1, testDistribution.getSupportLowerBound());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper upper bound.
|
||||
*/
|
||||
@Test
|
||||
public void testGetSupportUpperBound() {
|
||||
Assert.assertEquals(7, testDistribution.getSupportUpperBound());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns properly that the support is connected.
|
||||
*/
|
||||
@Test
|
||||
public void testIsSupportConnected() {
|
||||
Assert.assertTrue(testDistribution.isSupportConnected());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests sampling.
|
||||
*/
|
||||
@Test
|
||||
public void testSample() {
|
||||
final int n = 1000000;
|
||||
testDistribution.reseedRandomGenerator(-334759360); // fixed seed
|
||||
final int[] samples = testDistribution.sample(n);
|
||||
Assert.assertEquals(n, samples.length);
|
||||
double sum = 0;
|
||||
double sumOfSquares = 0;
|
||||
for (int i = 0; i < samples.length; i++) {
|
||||
sum += samples[i];
|
||||
sumOfSquares += samples[i] * samples[i];
|
||||
}
|
||||
Assert.assertEquals(testDistribution.getNumericalMean(),
|
||||
sum / n, 1e-2);
|
||||
Assert.assertEquals(testDistribution.getNumericalVariance(),
|
||||
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.distribution;
|
||||
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.MathArithmeticException;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math3.exception.NotPositiveException;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Test class for {@link DiscreteRealDistribution}.
|
||||
*
|
||||
* @version $Id: DiscreteRealDistributionTest.java 161 2013-03-07 09:47:32Z wydrych $
|
||||
*/
|
||||
public class DiscreteRealDistributionTest {
|
||||
|
||||
/**
|
||||
* The distribution object used for testing.
|
||||
*/
|
||||
private final DiscreteRealDistribution testDistribution;
|
||||
|
||||
/**
|
||||
* Creates the default distribution object uded for testing.
|
||||
*/
|
||||
public DiscreteRealDistributionTest() {
|
||||
// Non-sorted singleton array with duplicates should be allowed.
|
||||
// Values with zero-probability do not extend the support.
|
||||
testDistribution = new DiscreteRealDistribution(
|
||||
new double[]{3.0, -1.0, 3.0, 7.0, -2.0, 8.0},
|
||||
new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0});
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the {@link DiscreteRealDistribution} constructor throws
|
||||
* exceptions for ivalid data.
|
||||
*/
|
||||
@Test
|
||||
public void testExceptions() {
|
||||
DiscreteRealDistribution invalid = null;
|
||||
try {
|
||||
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0});
|
||||
Assert.fail("Expected DimensionMismatchException");
|
||||
} catch (DimensionMismatchException e) {
|
||||
}
|
||||
try{
|
||||
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, -1.0});
|
||||
Assert.fail("Expected NotPositiveException");
|
||||
} catch (NotPositiveException e) {
|
||||
}
|
||||
try {
|
||||
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, 0.0});
|
||||
Assert.fail("Expected MathArithmeticException");
|
||||
} catch (MathArithmeticException e) {
|
||||
}
|
||||
try {
|
||||
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.NaN});
|
||||
Assert.fail("Expected MathArithmeticException");
|
||||
} catch (MathArithmeticException e) {
|
||||
}
|
||||
try {
|
||||
invalid = new DiscreteRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.POSITIVE_INFINITY});
|
||||
Assert.fail("Expected MathIllegalArgumentException");
|
||||
} catch (MathIllegalArgumentException e) {
|
||||
}
|
||||
Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper probability values.
|
||||
*/
|
||||
@Test
|
||||
public void testProbability() {
|
||||
double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
|
||||
double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
|
||||
for (int p = 0; p < points.length; p++) {
|
||||
double density = testDistribution.probability(points[p]);
|
||||
Assert.assertEquals(results[p], density, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper density values.
|
||||
*/
|
||||
@Test
|
||||
public void testDensity() {
|
||||
double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
|
||||
double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
|
||||
for (int p = 0; p < points.length; p++) {
|
||||
double density = testDistribution.density(points[p]);
|
||||
Assert.assertEquals(results[p], density, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper cumulative probability values.
|
||||
*/
|
||||
@Test
|
||||
public void testCumulativeProbability() {
|
||||
double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
|
||||
double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0};
|
||||
for (int p = 0; p < points.length; p++) {
|
||||
double probability = testDistribution.cumulativeProbability(points[p]);
|
||||
Assert.assertEquals(results[p], probability, 1e-10);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper mean value.
|
||||
*/
|
||||
@Test
|
||||
public void testGetNumericalMean() {
|
||||
Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper variance.
|
||||
*/
|
||||
@Test
|
||||
public void testGetNumericalVariance() {
|
||||
Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper lower bound.
|
||||
*/
|
||||
@Test
|
||||
public void testGetSupportLowerBound() {
|
||||
Assert.assertEquals(-1, testDistribution.getSupportLowerBound(), 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns proper upper bound.
|
||||
*/
|
||||
@Test
|
||||
public void testGetSupportUpperBound() {
|
||||
Assert.assertEquals(7, testDistribution.getSupportUpperBound(), 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns properly that the support includes the
|
||||
* lower bound.
|
||||
*/
|
||||
@Test
|
||||
public void testIsSupportLowerBoundInclusive() {
|
||||
Assert.assertTrue(testDistribution.isSupportLowerBoundInclusive());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns properly that the support includes the
|
||||
* upper bound.
|
||||
*/
|
||||
@Test
|
||||
public void testIsSupportUpperBoundInclusive() {
|
||||
Assert.assertTrue(testDistribution.isSupportUpperBoundInclusive());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests if the distribution returns properly that the support is connected.
|
||||
*/
|
||||
@Test
|
||||
public void testIsSupportConnected() {
|
||||
Assert.assertTrue(testDistribution.isSupportConnected());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests sampling.
|
||||
*/
|
||||
@Test
|
||||
public void testSample() {
|
||||
final int n = 1000000;
|
||||
testDistribution.reseedRandomGenerator(-334759360); // fixed seed
|
||||
final double[] samples = testDistribution.sample(n);
|
||||
Assert.assertEquals(n, samples.length);
|
||||
double sum = 0;
|
||||
double sumOfSquares = 0;
|
||||
for (int i = 0; i < samples.length; i++) {
|
||||
sum += samples[i];
|
||||
sumOfSquares += samples[i] * samples[i];
|
||||
}
|
||||
Assert.assertEquals(testDistribution.getNumericalMean(),
|
||||
sum / n, 1e-2);
|
||||
Assert.assertEquals(testDistribution.getNumericalVariance(),
|
||||
sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue