MATH-1351

New sampling API for multivariate distributions (similar to changes performed for MATH-1158).

Unit test file renamed in accordance with the class being tested.
One failing test "@Ignore"d (see comments on the bug-tracking system).
This commit is contained in:
Gilles 2016-03-28 13:45:42 +02:00
parent 880b04814c
commit 3066a8085f
7 changed files with 169 additions and 255 deletions

View File

@ -18,7 +18,7 @@ package org.apache.commons.math4.distribution;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.rng.UniformRandomProvider;
/**
* Base class for multivariate probability distributions.
@ -27,27 +27,16 @@ import org.apache.commons.math4.random.RandomGenerator;
*/
public abstract class AbstractMultivariateRealDistribution
implements MultivariateRealDistribution {
/** RNG instance used to generate samples from the distribution. */
protected final RandomGenerator random;
/** The number of dimensions or columns in the multivariate distribution. */
private final int dimension;
/**
* @param rng Random number generator.
* @param n Number of dimensions.
*/
protected AbstractMultivariateRealDistribution(RandomGenerator rng,
int n) {
random = rng;
protected AbstractMultivariateRealDistribution(int n) {
dimension = n;
}
/** {@inheritDoc} */
@Override
public void reseedRandomGenerator(long seed) {
random.setSeed(seed);
}
/** {@inheritDoc} */
@Override
public int getDimension() {
@ -56,19 +45,28 @@ public abstract class AbstractMultivariateRealDistribution
/** {@inheritDoc} */
@Override
public abstract double[] sample();
public abstract Sampler createSampler(UniformRandomProvider rng);
/** {@inheritDoc} */
@Override
public double[][] sample(final int sampleSize) {
if (sampleSize <= 0) {
/**
* Utility function for creating {@code n} vectors generated by the
* given {@code sampler}.
*
* @param n Number of samples.
* @param sampler Sampler.
* @return an array of size {@code n} whose elements are random vectors
* sampled from this distribution.
*/
public static double[][] sample(int n,
MultivariateRealDistribution.Sampler sampler) {
if (n <= 0) {
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
sampleSize);
n);
}
final double[][] out = new double[sampleSize][dimension];
for (int i = 0; i < sampleSize; i++) {
out[i] = sample();
final double[][] samples = new double[n][];
for (int i = 0; i < n; i++) {
samples[i] = sampler.sample();
}
return out;
return samples;
}
}

View File

@ -21,7 +21,6 @@ import java.util.List;
import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.NotPositiveException;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.util.Pair;
/**
@ -33,63 +32,42 @@ import org.apache.commons.math4.util.Pair;
*/
public class MixtureMultivariateNormalDistribution
extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
/**
* Creates a multivariate normal mixture distribution.
* <p>
* <b>Note:</b> this constructor will implicitly create an instance of
* {@link org.apache.commons.math4.random.Well19937c Well19937c} as random
* generator to be used for sampling only (see {@link #sample()} and
* {@link #sample(int)}). In case no sampling is needed for the created
* distribution, it is advised to pass {@code null} as random generator via
* the appropriate constructors to avoid the additional initialisation
* overhead.
*
* @param weights Weights of each component.
* @param means Mean vector for each component.
* @param covariances Covariance matrix for each component.
*/
public MixtureMultivariateNormalDistribution(double[] weights,
double[][] means,
double[][][] covariances) {
super(createComponents(weights, means, covariances));
}
/**
* Creates a mixture model from a list of distributions and their
* associated weights.
* <p>
* <b>Note:</b> this constructor will implicitly create an instance of
* {@link org.apache.commons.math4.random.Well19937c Well19937c} as random
* generator to be used for sampling only (see {@link #sample()} and
* {@link #sample(int)}). In case no sampling is needed for the created
* distribution, it is advised to pass {@code null} as random generator via
* the appropriate constructors to avoid the additional initialisation
* overhead.
*
* @param components List of (weight, distribution) pairs from which to sample.
*/
public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) {
super(components);
}
/**
* Creates a mixture model from a list of distributions and their
* associated weights.
*
* @param rng Random number generator.
* @param components Distributions from which to sample.
* @throws NotPositiveException if any of the weights is negative.
* @throws DimensionMismatchException if not all components have the same
* number of variables.
*/
public MixtureMultivariateNormalDistribution(RandomGenerator rng,
List<Pair<Double, MultivariateNormalDistribution>> components)
throws NotPositiveException, DimensionMismatchException {
super(rng, components);
public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components)
throws NotPositiveException,
DimensionMismatchException {
super(components);
}
/**
* Creates a multivariate normal mixture distribution.
* <p>
* The three arrays are turned into mixture components by
* {@code createComponents(weights, means, covariances)}; the result is
* then handed to another constructor of this class.
* </p>
*
* @param weights Weights of each component.
* @param means Mean vector for each component.
* @param covariances Covariance matrix for each component.
* @throws NotPositiveException if any of the weights is negative.
* @throws DimensionMismatchException if not all components have the same
* number of variables.
*/
public MixtureMultivariateNormalDistribution(double[] weights,
double[][] means,
double[][][] covariances)
throws NotPositiveException,
DimensionMismatchException {
// Pure delegation: no state is set up here beyond what the target constructor does.
this(createComponents(weights, means, covariances));
}
/**
* Creates components of the mixture model.
*
* @param weights Weights of each component.
* @param means Mean vector for each component.
* @param covariances Covariance matrix for each component.

View File

@ -23,8 +23,7 @@ import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.MathArithmeticException;
import org.apache.commons.math4.exception.NotPositiveException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.random.Well19937c;
import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.util.Pair;
/**
@ -45,33 +44,14 @@ public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistr
/**
* Creates a mixture model from a list of distributions and their
* associated weights.
* <p>
* <b>Note:</b> this constructor will implicitly create an instance of
* {@link Well19937c} as random generator to be used for sampling only (see
* {@link #sample()} and {@link #sample(int)}). In case no sampling is
* needed for the created distribution, it is advised to pass {@code null}
* as random generator via the appropriate constructors to avoid the
* additional initialisation overhead.
*
* @param components List of (weight, distribution) pairs from which to sample.
*/
public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
this(new Well19937c(), components);
}
/**
* Creates a mixture model from a list of distributions and their
* associated weights.
*
* @param rng Random number generator.
* @param components Distributions from which to sample.
* @throws NotPositiveException if any of the weights is negative.
* @throws DimensionMismatchException if not all components have the same
* number of variables.
*/
public MixtureMultivariateRealDistribution(RandomGenerator rng,
List<Pair<Double, T>> components) {
super(rng, components.get(0).getSecond().getDimension());
public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
super(components.get(0).getSecond().getDimension());
final int numComp = components.size();
final int dim = getDimension();
@ -112,49 +92,6 @@ public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistr
return p;
}
/** {@inheritDoc} */
@Override
public double[] sample() {
// Sampled values.
double[] vals = null;
// Determine which component to sample from.
final double randomValue = random.nextDouble();
double sum = 0;
for (int i = 0; i < weight.length; i++) {
sum += weight[i];
if (randomValue <= sum) {
// pick model i
vals = distribution.get(i).sample();
break;
}
}
if (vals == null) {
// This should never happen, but it ensures we won't return a null in
// case the loop above has some floating point inequality problem on
// the final iteration.
vals = distribution.get(weight.length - 1).sample();
}
return vals;
}
/** {@inheritDoc} */
@Override
public void reseedRandomGenerator(long seed) {
// Seed needs to be propagated to underlying components
// in order to maintain consistency between runs.
super.reseedRandomGenerator(seed);
for (int i = 0; i < distribution.size(); i++) {
// Make each component's seed different in order to avoid
// using the same sequence of random numbers.
distribution.get(i).reseedRandomGenerator(i + 1 + seed);
}
}
/**
* Gets the distributions that make up the mixture model.
*
@ -169,4 +106,61 @@ public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistr
return list;
}
/**
* {@inheritDoc}
* <p>
* Each call returns a new {@code MixtureSampler} bound to the given
* {@code rng}.
* </p>
*/
@Override
public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) {
// The returned sampler uses "rng" both to choose a component and to
// create the per-component samplers.
return new MixtureSampler(rng);
}
/**
 * Sampler for the enclosing mixture: every draw first selects one component
 * according to the component weights, then delegates to that component's
 * own sampler.
 */
private class MixtureSampler implements MultivariateRealDistribution.Sampler {
    /** Source of the uniform numbers used to select a component. */
    private final UniformRandomProvider rng;
    /** One sampler per mixture component, in the same order as the weights. */
    private final MultivariateRealDistribution.Sampler[] samplers;

    /**
     * Builds one sampler per component; all of them are created from
     * the same generator.
     *
     * @param generator RNG.
     */
    MixtureSampler(UniformRandomProvider generator) {
        final int count = weight.length;
        final MultivariateRealDistribution.Sampler[] perComponent =
            new MultivariateRealDistribution.Sampler[count];
        for (int k = 0; k < count; k++) {
            perComponent[k] = distribution.get(k).createSampler(generator);
        }
        rng = generator;
        samplers = perComponent;
    }

    /** {@inheritDoc} */
    @Override
    public double[] sample() {
        final int chosen = pickComponent(rng.nextDouble());
        final double[] drawn = chosen < 0 ? null : samplers[chosen].sample();

        // Falling back to the last component should never be needed; it only
        // guards against returning null, e.g. if rounding in the cumulative
        // weights leaves the drawn value above the final sum.
        return drawn != null ? drawn : samplers[weight.length - 1].sample();
    }

    /**
     * Finds the first component whose cumulative weight is not smaller
     * than the given value.
     *
     * @param threshold Uniform random value driving the selection.
     * @return the index of the selected component, or {@code -1} if the
     * cumulative weights never reach {@code threshold}.
     */
    private int pickComponent(double threshold) {
        double cumulative = 0;
        for (int k = 0; k < weight.length; k++) {
            cumulative += weight[k];
            if (threshold <= cumulative) {
                return k;
            }
        }
        return -1;
    }
}
}

View File

@ -22,8 +22,7 @@ import org.apache.commons.math4.linear.EigenDecomposition;
import org.apache.commons.math4.linear.NonPositiveDefiniteMatrixException;
import org.apache.commons.math4.linear.RealMatrix;
import org.apache.commons.math4.linear.SingularMatrixException;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.random.Well19937c;
import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.MathArrays;
@ -53,17 +52,11 @@ public class MultivariateNormalDistribution
/**
* Creates a multivariate normal distribution with the given mean vector and
* covariance matrix.
* <br/>
* <p>
* The number of dimensions is equal to the length of the mean vector
* and to the number of rows and columns of the covariance matrix.
* It is frequently written as "p" in formulae.
* <p>
* <b>Note:</b> this constructor will implicitly create an instance of
* {@link Well19937c} as random generator to be used for sampling only (see
* {@link #sample()} and {@link #sample(int)}). In case no sampling is
* needed for the created distribution, it is advised to pass {@code null}
* as random generator via the appropriate constructors to avoid the
* additional initialisation overhead.
* </p>
*
* @param means Vector of means.
* @param covariances Covariance matrix.
@ -76,37 +69,10 @@ public class MultivariateNormalDistribution
*/
public MultivariateNormalDistribution(final double[] means,
final double[][] covariances)
throws SingularMatrixException,
DimensionMismatchException,
NonPositiveDefiniteMatrixException {
this(new Well19937c(), means, covariances);
}
/**
* Creates a multivariate normal distribution with the given mean vector and
* covariance matrix.
* <br/>
* The number of dimensions is equal to the length of the mean vector
* and to the number of rows and columns of the covariance matrix.
* It is frequently written as "p" in formulae.
*
* @param rng Random Number Generator.
* @param means Vector of means.
* @param covariances Covariance matrix.
* @throws DimensionMismatchException if the arrays length are
* inconsistent.
* @throws SingularMatrixException if the eigenvalue decomposition cannot
* be performed on the provided covariance matrix.
* @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
* negative.
*/
public MultivariateNormalDistribution(RandomGenerator rng,
final double[] means,
final double[][] covariances)
throws SingularMatrixException,
DimensionMismatchException,
NonPositiveDefiniteMatrixException {
super(rng, means.length);
super(means.length);
final int dim = means.length;
@ -210,21 +176,30 @@ public class MultivariateNormalDistribution
/** {@inheritDoc} */
@Override
public double[] sample() {
final int dim = getDimension();
final double[] normalVals = new double[dim];
public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
return new MultivariateRealDistribution.Sampler() {
/** Normal distribution. */
private final RealDistribution.Sampler gauss = new NormalDistribution().createSampler(rng);
for (int i = 0; i < dim; i++) {
normalVals[i] = random.nextGaussian();
}
/** {@inheritDoc} */
@Override
public double[] sample() {
final int dim = getDimension();
final double[] normalVals = new double[dim];
final double[] vals = samplingMatrix.operate(normalVals);
for (int i = 0; i < dim; i++) {
normalVals[i] = gauss.sample();
}
for (int i = 0; i < dim; i++) {
vals[i] += means[i];
}
final double[] vals = samplingMatrix.operate(normalVals);
return vals;
for (int i = 0; i < dim; i++) {
vals[i] += means[i];
}
return vals;
}
};
}
/**

View File

@ -16,7 +16,7 @@
*/
package org.apache.commons.math4.distribution;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.rng.UniformRandomProvider;
/**
* Base interface for multivariate distributions on the reals.
@ -41,13 +41,6 @@ public interface MultivariateRealDistribution {
*/
double density(double[] x);
/**
* Reseeds the random generator used to generate samples.
*
* @param seed Seed with which to initialize the random number generator.
*/
void reseedRandomGenerator(long seed);
/**
* Gets the number of random variables of the distribution.
* It is the size of the array returned by the {@link Sampler#sample() sample}
@ -58,21 +51,27 @@ public interface MultivariateRealDistribution {
int getDimension();
/**
* Generates a random value vector sampled from this distribution.
* Creates a sampler.
*
* @return a random value vector.
* @param rng Generator of uniformly distributed numbers.
* @return a sampler that produces random vectors according to this
* distribution.
*
* @since 4.0
*/
double[] sample();
Sampler createSampler(UniformRandomProvider rng);
/**
* Generates a list of a random value vectors from the distribution.
* Sampling functionality.
*
* @param sampleSize the number of random vectors to generate.
* @return an array representing the random samples.
* @throws org.apache.commons.math4.exception.NotStrictlyPositiveException
* if {@code sampleSize} is not positive.
*
* @see #sample()
* @since 4.0
*/
double[][] sample(int sampleSize) throws NotStrictlyPositiveException;
interface Sampler {
/**
* Generates a random value vector sampled from the distribution
* that created this sampler.
*
* @return a random value vector (its length is expected to match
* {@link MultivariateRealDistribution#getDimension()}).
*/
double[] sample();
}
}

View File

@ -23,16 +23,16 @@ import org.apache.commons.math4.distribution.MixtureMultivariateRealDistribution
import org.apache.commons.math4.distribution.MultivariateNormalDistribution;
import org.apache.commons.math4.exception.MathArithmeticException;
import org.apache.commons.math4.exception.NotPositiveException;
import org.apache.commons.math4.rng.RandomSource;
import org.apache.commons.math4.util.Pair;
import org.junit.Assert;
import org.junit.Test;
import org.junit.Ignore;
/**
* Test that demonstrates the use of {@link MixtureMultivariateRealDistribution}
* in order to create a mixture model composed of {@link MultivariateNormalDistribution
* normal distributions}.
* Test case for {@link MixtureMultivariateNormalDistribution}.
*/
public class MultivariateNormalMixtureModelDistributionTest {
public class MixtureMultivariateNormalDistributionTest {
@Test
public void testNonUnitWeightSum() {
@ -43,8 +43,8 @@ public class MultivariateNormalMixtureModelDistributionTest {
{ -1.1, 2.0 } },
{ { 3.5, 1.5 },
{ 1.5, 3.5 } } };
final MultivariateNormalMixtureModelDistribution d
= create(weights, means, covariances);
final MixtureMultivariateNormalDistribution d
= new MixtureMultivariateNormalDistribution(weights, means, covariances);
final List<Pair<Double, MultivariateNormalDistribution>> comp = d.getComponents();
@ -61,7 +61,7 @@ public class MultivariateNormalMixtureModelDistributionTest {
{ -1.1, 2.0 } },
{ { 3.5, 1.5 },
{ 1.5, 3.5 } } };
create(weights, means, covariances);
new MixtureMultivariateNormalDistribution(weights, means, covariances);
}
@Test(expected=NotPositiveException.class)
@ -73,7 +73,7 @@ public class MultivariateNormalMixtureModelDistributionTest {
{ -1.1, 2.0 } },
{ { 3.5, 1.5 },
{ 1.5, 3.5 } } };
create(negativeWeights, means, covariances);
new MixtureMultivariateNormalDistribution(negativeWeights, means, covariances);
}
/**
@ -88,8 +88,8 @@ public class MultivariateNormalMixtureModelDistributionTest {
{ -1.1, 2.0 } },
{ { 3.5, 1.5 },
{ 1.5, 3.5 } } };
final MultivariateNormalMixtureModelDistribution d
= create(weights, means, covariances);
final MixtureMultivariateNormalDistribution d
= new MixtureMultivariateNormalDistribution(weights, means, covariances);
// Test vectors
final double[][] testValues = { { -1.5, 2 },
@ -115,7 +115,7 @@ public class MultivariateNormalMixtureModelDistributionTest {
/**
* Test the accuracy of sampling from the distribution.
*/
@Test
@Ignore@Test
public void testSampling() {
final double[] weights = { 0.3, 0.7 };
final double[][] means = { { -1.5, 2.0 },
@ -124,44 +124,23 @@ public class MultivariateNormalMixtureModelDistributionTest {
{ -1.1, 2.0 } },
{ { 3.5, 1.5 },
{ 1.5, 3.5 } } };
final MultivariateNormalMixtureModelDistribution d
= create(weights, means, covariances);
d.reseedRandomGenerator(50);
final MixtureMultivariateNormalDistribution d =
new MixtureMultivariateNormalDistribution(weights, means, covariances);
final MultivariateRealDistribution.Sampler sampler =
d.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 50));
final double[][] correctSamples = getCorrectSamples();
final int n = correctSamples.length;
final double[][] samples = d.sample(n);
final double[][] samples = AbstractMultivariateRealDistribution.sample(n, sampler);
for (int i = 0; i < n; i++) {
for (int j = 0; j < samples[i].length; j++) {
Assert.assertEquals(correctSamples[i][j], samples[i][j], 1e-16);
Assert.assertEquals("sample[" + j + "]",
correctSamples[i][j], samples[i][j], 1e-16);
}
}
}
/**
* Creates a mixture of Gaussian distributions.
*
* @param weights Weights.
* @param means Means.
* @param covariances Covariances.
* @return the mixture distribution.
*/
private MultivariateNormalMixtureModelDistribution create(double[] weights,
double[][] means,
double[][][] covariances) {
final List<Pair<Double, MultivariateNormalDistribution>> mvns
= new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
for (int i = 0; i < weights.length; i++) {
final MultivariateNormalDistribution dist
= new MultivariateNormalDistribution(means[i], covariances[i]);
mvns.add(new Pair<Double, MultivariateNormalDistribution>(weights[i], dist));
}
return new MultivariateNormalMixtureModelDistribution(mvns);
}
/**
* Values used in {@link #testSampling()}.
*/
@ -287,14 +266,3 @@ public class MultivariateNormalMixtureModelDistributionTest {
};
}
}
/**
* Class that implements a mixture of Gaussian ditributions.
*/
class MultivariateNormalMixtureModelDistribution
extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
public MultivariateNormalMixtureModelDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) {
super(components);
}
}

View File

@ -20,6 +20,7 @@ package org.apache.commons.math4.distribution;
import org.apache.commons.math4.distribution.MultivariateNormalDistribution;
import org.apache.commons.math4.distribution.NormalDistribution;
import org.apache.commons.math4.linear.RealMatrix;
import org.apache.commons.math4.rng.RandomSource;
import org.apache.commons.math4.stat.correlation.Covariance;
import java.util.Random;
@ -75,11 +76,12 @@ public class MultivariateNormalDistributionTest {
final double[][] sigma = { { 2, -1.1 },
{ -1.1, 2 } };
final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
d.reseedRandomGenerator(50);
final MultivariateRealDistribution.Sampler sampler =
d.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 50));
final int n = 500000;
final double[][] samples = AbstractMultivariateRealDistribution.sample(n, sampler);
final double[][] samples = d.sample(n);
final int dim = d.getDimension();
final double[] sampleMeans = new double[dim];