From 16e0a6d47b6498555b9691a190960bd72ea32e56 Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Mon, 25 Mar 2013 15:47:31 +0000 Subject: [PATCH] Fixed tests so they do not use equals on top level classes. Patch submitted by Jared Becksfort. JIRA: MATH-817 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1460726 13f79535-47bb-0310-9956-ffa450edef68 --- ...malMixtureExpectationMaximizationTest.java | 97 ++++++++++++------- 1 file changed, 61 insertions(+), 36 deletions(-) diff --git a/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java b/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java index 2a18fb86f..75328fc2c 100644 --- a/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java +++ b/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java @@ -17,6 +17,7 @@ package org.apache.commons.math3.distribution.fitting; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution; @@ -25,10 +26,11 @@ import org.apache.commons.math3.exception.ConvergenceException; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.NotStrictlyPositiveException; import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.util.Pair; import org.junit.Assert; import org.junit.Test; -import org.junit.Ignore; /** * Test that demonstrates the use of @@ -36,9 +38,6 @@ import org.junit.Ignore; */ public class MultivariateNormalMixtureExpectationMaximizationTest { - // TODO reject initial mixes where means/covMats not computable with data - // numCols - @Test(expected = NotStrictlyPositiveException.class) public void testNonEmptyData() { // Should not accept empty data @@ -144,22 +143,34 @@ public class MultivariateNormalMixtureExpectationMaximizationTest { fitter.fit(badInitialMix); } - @Ignore@Test + @Test public void testInitialMixture() { // Testing initial mixture estimated from data - double[] correctWeights = new double[] { 0.5, 0.5 }; + final double[] correctWeights = new double[] { 0.5, 0.5 }; - MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2]; + final double[][] correctMeans = new double[][] { + {-0.0021722935000328823, 3.5432892936887908}, + {5.090902706507635, 8.68540656355283}, + }; - correctMVNs[0] = new MultivariateNormalDistribution(new double[] { - -0.0021722935000328823, 3.5432892936887908 }, - new double[][] { - { 4.537422569229048, 3.5266152281729304 }, - { 3.5266152281729304, 6.175448814169779 } }); - correctMVNs[1] = new MultivariateNormalDistribution(new double[] { - 5.090902706507635, 8.68540656355283 }, new double[][] { - { 2.886778573963039, 1.5257474543463154 }, - { 1.5257474543463154, 3.3794567673616918 } }); + final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2]; + + correctCovMats[0] = new Array2DRowRealMatrix(new double[][] { + { 4.537422569229048, 3.5266152281729304 }, + { 3.5266152281729304, 6.175448814169779 } }); + + correctCovMats[1] = new Array2DRowRealMatrix( new double[][] { + { 2.886778573963039, 1.5257474543463154 }, + { 1.5257474543463154, 3.3794567673616918 } }); + + final MultivariateNormalDistribution[] correctMVNs = new + MultivariateNormalDistribution[2]; + + correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], + correctCovMats[0].getData()); + + correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], + correctCovMats[1].getData()); final MixtureMultivariateNormalDistribution initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2); @@ -169,30 +180,41 @@ public class MultivariateNormalMixtureExpectationMaximizationTest { .getComponents()) { Assert.assertEquals(correctWeights[i], component.getFirst(), Math.ulp(1d)); - Assert.assertEquals(correctMVNs[i], component.getSecond()); + + final double[] means = component.getValue().getMeans(); + Assert.assertTrue(Arrays.equals(correctMeans[i], means)); + + final RealMatrix covMat = component.getValue().getCovariances(); + Assert.assertEquals(correctCovMats[i], covMat); i++; } } - @Ignore@Test + @Test public void testFit() { // Test that the loglikelihood, weights, and models are determined and // fitted correctly - double[][] data = getTestSamples(); - double correctLogLikelihood = -4.292431006791994; - double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 }; - MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2]; - correctMVNs[0] = new MultivariateNormalDistribution(new double[] { - -1.4213112715121132, 1.6924690505757753 }, - new double[][] { - { 1.739356907285747, -0.5867644251487614 }, - { -0.5867644251487614, 1.0232932029324642 } }); - - correctMVNs[1] = new MultivariateNormalDistribution(new double[] { - 4.213612224374709, 7.975621325853645 }, - new double[][] { - { 4.245384898007161, 2.5797798966382155 }, - { 2.5797798966382155, 3.9200272522448367 } }); + final double[][] data = getTestSamples(); + final double correctLogLikelihood = -4.292431006791994; + final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 }; + + final double[][] correctMeans = new double[][]{ + {-1.4213112715121132, 1.6924690505757753}, + {4.213612224374709, 7.975621325853645} + }; + + final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2]; + correctCovMats[0] = new Array2DRowRealMatrix(new double[][] { + { 1.739356907285747, -0.5867644251487614 }, + { -0.5867644251487614, 1.0232932029324642 } } + ); + correctCovMats[1] = new Array2DRowRealMatrix(new double[][] { + { 4.245384898007161, 2.5797798966382155 }, + { 2.5797798966382155, 3.9200272522448367 } }); + + final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2]; + correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData()); + correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData()); MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(data); @@ -209,10 +231,13 @@ public class MultivariateNormalMixtureExpectationMaximizationTest { int i = 0; for (Pair component : components) { - double weight = component.getFirst(); - MultivariateNormalDistribution mvn = component.getSecond(); + final double weight = component.getFirst(); + final MultivariateNormalDistribution mvn = component.getSecond(); + final double[] mean = mvn.getMeans(); + final RealMatrix covMat = mvn.getCovariances(); Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d)); - Assert.assertEquals(correctMVNs[i], mvn); + Assert.assertTrue(Arrays.equals(correctMeans[i], mean)); + Assert.assertEquals(correctCovMats[i], covMat); i++; } }