Fixed tests so they do not use equals on top level classes.

Patch submitted by Jared Becksfort.

JIRA: MATH-817

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1460726 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2013-03-25 15:47:31 +00:00
parent 78f8c198c0
commit 16e0a6d47b
1 changed files with 61 additions and 36 deletions

View File

@ -17,6 +17,7 @@
package org.apache.commons.math3.distribution.fitting;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
@ -25,10 +26,11 @@ import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.Pair;
import org.junit.Assert;
import org.junit.Test;
import org.junit.Ignore;
/**
* Test that demonstrates the use of
@ -36,9 +38,6 @@ import org.junit.Ignore;
*/
public class MultivariateNormalMixtureExpectationMaximizationTest {
// TODO reject initial mixes where means/covMats not computable with data
// numCols
@Test(expected = NotStrictlyPositiveException.class)
public void testNonEmptyData() {
// Should not accept empty data
@ -144,22 +143,34 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
fitter.fit(badInitialMix);
}
@Ignore@Test
@Test
public void testInitialMixture() {
// Testing initial mixture estimated from data
double[] correctWeights = new double[] { 0.5, 0.5 };
final double[] correctWeights = new double[] { 0.5, 0.5 };
MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
final double[][] correctMeans = new double[][] {
{-0.0021722935000328823, 3.5432892936887908},
{5.090902706507635, 8.68540656355283},
};
correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
-0.0021722935000328823, 3.5432892936887908 },
new double[][] {
{ 4.537422569229048, 3.5266152281729304 },
{ 3.5266152281729304, 6.175448814169779 } });
correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
5.090902706507635, 8.68540656355283 }, new double[][] {
{ 2.886778573963039, 1.5257474543463154 },
{ 1.5257474543463154, 3.3794567673616918 } });
final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
{ 4.537422569229048, 3.5266152281729304 },
{ 3.5266152281729304, 6.175448814169779 } });
correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
{ 2.886778573963039, 1.5257474543463154 },
{ 1.5257474543463154, 3.3794567673616918 } });
final MultivariateNormalDistribution[] correctMVNs = new
MultivariateNormalDistribution[2];
correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
correctCovMats[0].getData());
correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
correctCovMats[1].getData());
final MixtureMultivariateNormalDistribution initialMix
= MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
@ -169,30 +180,41 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
.getComponents()) {
Assert.assertEquals(correctWeights[i], component.getFirst(),
Math.ulp(1d));
Assert.assertEquals(correctMVNs[i], component.getSecond());
final double[] means = component.getValue().getMeans();
Assert.assertTrue(Arrays.equals(correctMeans[i], means));
final RealMatrix covMat = component.getValue().getCovariances();
Assert.assertEquals(correctCovMats[i], covMat);
i++;
}
}
@Ignore@Test
@Test
public void testFit() {
// Test that the loglikelihood, weights, and models are determined and
// fitted correctly
double[][] data = getTestSamples();
double correctLogLikelihood = -4.292431006791994;
double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
-1.4213112715121132, 1.6924690505757753 },
new double[][] {
{ 1.739356907285747, -0.5867644251487614 },
{ -0.5867644251487614, 1.0232932029324642 } });
correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
4.213612224374709, 7.975621325853645 },
new double[][] {
{ 4.245384898007161, 2.5797798966382155 },
{ 2.5797798966382155, 3.9200272522448367 } });
final double[][] data = getTestSamples();
final double correctLogLikelihood = -4.292431006791994;
final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
final double[][] correctMeans = new double[][]{
{-1.4213112715121132, 1.6924690505757753},
{4.213612224374709, 7.975621325853645}
};
final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
{ 1.739356907285747, -0.5867644251487614 },
{ -0.5867644251487614, 1.0232932029324642 } }
);
correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
{ 4.245384898007161, 2.5797798966382155 },
{ 2.5797798966382155, 3.9200272522448367 } });
final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
MultivariateNormalMixtureExpectationMaximization fitter
= new MultivariateNormalMixtureExpectationMaximization(data);
@ -209,10 +231,13 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
int i = 0;
for (Pair<Double, MultivariateNormalDistribution> component : components) {
double weight = component.getFirst();
MultivariateNormalDistribution mvn = component.getSecond();
final double weight = component.getFirst();
final MultivariateNormalDistribution mvn = component.getSecond();
final double[] mean = mvn.getMeans();
final RealMatrix covMat = mvn.getCovariances();
Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
Assert.assertEquals(correctMVNs[i], mvn);
Assert.assertTrue(Arrays.equals(correctMeans[i], mean));
Assert.assertEquals(correctCovMats[i], covMat);
i++;
}
}