Fixed tests so they do not use equals on top level classes.
Patch submitted by Jared Becksfort. JIRA: MATH-817 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1460726 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
78f8c198c0
commit
16e0a6d47b
|
@ -17,6 +17,7 @@
|
|||
package org.apache.commons.math3.distribution.fitting;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
|
||||
|
@ -25,10 +26,11 @@ import org.apache.commons.math3.exception.ConvergenceException;
|
|||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.junit.Ignore;
|
||||
|
||||
/**
|
||||
* Test that demonstrates the use of
|
||||
|
@ -36,9 +38,6 @@ import org.junit.Ignore;
|
|||
*/
|
||||
public class MultivariateNormalMixtureExpectationMaximizationTest {
|
||||
|
||||
// TODO reject initial mixes where means/covMats not computable with data
|
||||
// numCols
|
||||
|
||||
@Test(expected = NotStrictlyPositiveException.class)
|
||||
public void testNonEmptyData() {
|
||||
// Should not accept empty data
|
||||
|
@ -144,22 +143,34 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
|
|||
fitter.fit(badInitialMix);
|
||||
}
|
||||
|
||||
@Ignore@Test
|
||||
@Test
|
||||
public void testInitialMixture() {
|
||||
// Testing initial mixture estimated from data
|
||||
double[] correctWeights = new double[] { 0.5, 0.5 };
|
||||
final double[] correctWeights = new double[] { 0.5, 0.5 };
|
||||
|
||||
MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
|
||||
final double[][] correctMeans = new double[][] {
|
||||
{-0.0021722935000328823, 3.5432892936887908},
|
||||
{5.090902706507635, 8.68540656355283},
|
||||
};
|
||||
|
||||
correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
|
||||
-0.0021722935000328823, 3.5432892936887908 },
|
||||
new double[][] {
|
||||
{ 4.537422569229048, 3.5266152281729304 },
|
||||
{ 3.5266152281729304, 6.175448814169779 } });
|
||||
correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
|
||||
5.090902706507635, 8.68540656355283 }, new double[][] {
|
||||
{ 2.886778573963039, 1.5257474543463154 },
|
||||
{ 1.5257474543463154, 3.3794567673616918 } });
|
||||
final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
|
||||
|
||||
correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
|
||||
{ 4.537422569229048, 3.5266152281729304 },
|
||||
{ 3.5266152281729304, 6.175448814169779 } });
|
||||
|
||||
correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
|
||||
{ 2.886778573963039, 1.5257474543463154 },
|
||||
{ 1.5257474543463154, 3.3794567673616918 } });
|
||||
|
||||
final MultivariateNormalDistribution[] correctMVNs = new
|
||||
MultivariateNormalDistribution[2];
|
||||
|
||||
correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
|
||||
correctCovMats[0].getData());
|
||||
|
||||
correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
|
||||
correctCovMats[1].getData());
|
||||
|
||||
final MixtureMultivariateNormalDistribution initialMix
|
||||
= MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
|
||||
|
@ -169,30 +180,41 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
|
|||
.getComponents()) {
|
||||
Assert.assertEquals(correctWeights[i], component.getFirst(),
|
||||
Math.ulp(1d));
|
||||
Assert.assertEquals(correctMVNs[i], component.getSecond());
|
||||
|
||||
final double[] means = component.getValue().getMeans();
|
||||
Assert.assertTrue(Arrays.equals(correctMeans[i], means));
|
||||
|
||||
final RealMatrix covMat = component.getValue().getCovariances();
|
||||
Assert.assertEquals(correctCovMats[i], covMat);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
@Ignore@Test
|
||||
@Test
|
||||
public void testFit() {
|
||||
// Test that the loglikelihood, weights, and models are determined and
|
||||
// fitted correctly
|
||||
double[][] data = getTestSamples();
|
||||
double correctLogLikelihood = -4.292431006791994;
|
||||
double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
|
||||
MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
|
||||
correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
|
||||
-1.4213112715121132, 1.6924690505757753 },
|
||||
new double[][] {
|
||||
{ 1.739356907285747, -0.5867644251487614 },
|
||||
{ -0.5867644251487614, 1.0232932029324642 } });
|
||||
|
||||
correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
|
||||
4.213612224374709, 7.975621325853645 },
|
||||
new double[][] {
|
||||
{ 4.245384898007161, 2.5797798966382155 },
|
||||
{ 2.5797798966382155, 3.9200272522448367 } });
|
||||
final double[][] data = getTestSamples();
|
||||
final double correctLogLikelihood = -4.292431006791994;
|
||||
final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
|
||||
|
||||
final double[][] correctMeans = new double[][]{
|
||||
{-1.4213112715121132, 1.6924690505757753},
|
||||
{4.213612224374709, 7.975621325853645}
|
||||
};
|
||||
|
||||
final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
|
||||
correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
|
||||
{ 1.739356907285747, -0.5867644251487614 },
|
||||
{ -0.5867644251487614, 1.0232932029324642 } }
|
||||
);
|
||||
correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
|
||||
{ 4.245384898007161, 2.5797798966382155 },
|
||||
{ 2.5797798966382155, 3.9200272522448367 } });
|
||||
|
||||
final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
|
||||
correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
|
||||
correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
|
||||
|
||||
MultivariateNormalMixtureExpectationMaximization fitter
|
||||
= new MultivariateNormalMixtureExpectationMaximization(data);
|
||||
|
@ -209,10 +231,13 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
|
|||
|
||||
int i = 0;
|
||||
for (Pair<Double, MultivariateNormalDistribution> component : components) {
|
||||
double weight = component.getFirst();
|
||||
MultivariateNormalDistribution mvn = component.getSecond();
|
||||
final double weight = component.getFirst();
|
||||
final MultivariateNormalDistribution mvn = component.getSecond();
|
||||
final double[] mean = mvn.getMeans();
|
||||
final RealMatrix covMat = mvn.getCovariances();
|
||||
Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
|
||||
Assert.assertEquals(correctMVNs[i], mvn);
|
||||
Assert.assertTrue(Arrays.equals(correctMeans[i], mean));
|
||||
Assert.assertEquals(correctCovMats[i], covMat);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue