Allow fitting single component data

This commit is contained in:
Alex Herbert 2024-03-11 21:55:54 +00:00
parent 4957f7570b
commit bab00341a9
2 changed files with 87 additions and 64 deletions

View File

@ -294,7 +294,7 @@ public class MultivariateNormalMixtureExpectationMaximization {
* @return Multivariate normal mixture model estimated from the data
* @throws NumberIsTooLargeException if {@code numComponents} is greater
* than the number of data rows.
* @throws NumberIsTooSmallException if {@code numComponents < 2}.
* @throws NumberIsTooSmallException if {@code numComponents < 1}.
* @throws NotStrictlyPositiveException if data has fewer than 2 rows
* @throws DimensionMismatchException if rows of data have different numbers
* of columns
@ -306,8 +306,8 @@ public class MultivariateNormalMixtureExpectationMaximization {
if (data.length < 2) {
throw new NotStrictlyPositiveException(data.length);
}
if (numComponents < 2) {
throw new NumberIsTooSmallException(numComponents, 2, true);
if (numComponents < 1) {
throw new NumberIsTooSmallException(numComponents, 1, true);
}
if (numComponents > data.length) {
throw new NumberIsTooLargeException(numComponents, data.length, true);

View File

@ -191,60 +191,37 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
}
// Fits a 2-component mixture to the 2-dimensional test samples and checks the
// log-likelihood, weights, means and covariances against reference values.
// NOTE(review): this listing appears to interleave the lines removed and added by the
// commit without +/- markers (e.g. both the testFit() and testFit2Dimensions2Components()
// signatures are shown) - confirm against the committed file before relying on it.
@Test
public void testFit() {
// Test that the loglikelihood, weights, and models are determined and
// fitted correctly
public void testFit2Dimensions2Components() {
final double[][] data = getTestSamples();
final double correctLogLikelihood = -4.292431006791994;
final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
final double[][] correctMeans = new double[][]{
{-1.4213112715121132, 1.6924690505757753},
{4.213612224374709, 7.975621325853645}
// Fit using the test samples using Matlab R2023b (Update 6):
// GMModel = fitgmdist(X,2);
// Expected results use the component order generated by the CM code for convenience
// i.e. ComponentProportion from matlab is reversed: [0.703722, 0.296278]
// NegativeLogLikelihood (the CM code uses the positive log-likelihood divided by the number of observations)
final double logLikelihood = -4.292430883324220e+02 / data.length;
// ComponentProportion
final double[] weights = new double[] {0.2962324189652912, 0.7037675810347089};
// mu
final double[][] means = new double[][]{
{-1.421239458366293, 1.692604555824222},
{4.213949861591596, 7.975974466776790}
};
// Sigma
final double[][][] covar = new double[][][] {
{{1.739441346307267, -0.586740858187563},
{-0.586740858187563, 1.023420964341543}},
{{4.243780645051973, 2.578176622652551},
{2.578176622652551, 3.918302056479298}}
};
final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
{ 1.739356907285747, -0.5867644251487614 },
{ -0.5867644251487614, 1.0232932029324642 } }
);
correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
{ 4.245384898007161, 2.5797798966382155 },
{ 2.5797798966382155, 3.9200272522448367 } });
final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
MultivariateNormalMixtureExpectationMaximization fitter
= new MultivariateNormalMixtureExpectationMaximization(data);
MixtureMultivariateNormalDistribution initialMix
= MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
fitter.fit(initialMix);
MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
Assert.assertEquals(correctLogLikelihood,
fitter.getLogLikelihood(),
Math.ulp(1d));
int i = 0;
for (Pair<Double, MultivariateNormalDistribution> component : components) {
final double weight = component.getFirst();
final MultivariateNormalDistribution mvn = component.getSecond();
final double[] mean = mvn.getMeans();
final RealMatrix covMat = mvn.getCovariances();
Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
Assert.assertArrayEquals(correctMeans[i], mean, 0.0);
Assert.assertEquals(correctCovMats[i], covMat);
i++;
}
// Delegate to the shared helper with a relative error of 1e-3
assertFit(data, 2, logLikelihood, weights, means, covar, 1e-3);
}
@Test
public void testFit1() {
// Test that the fit can be performed on data with a single dimension
public void testFit1Dimension2Components() {
// Use only the first column of the test data
final double[][] data = Arrays.stream(getTestSamples())
.map(x -> new double[] {x[0]}).toArray(double[][]::new);
@ -253,42 +230,88 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
// GMModel = fitgmdist(X,2);
// NegativeLogLikelihood (the CM code uses the positive log-likelihood divided by the number of observations)
final double correctLogLikelihood = -2.512197016873482e+02 / data.length;
final double logLikelihood = -2.512197016873482e+02 / data.length;
// ComponentProportion
final double[] correctWeights = new double[] {0.240510201974078, 0.759489798025922};
final double[] weights = new double[] {0.240510201974078, 0.759489798025922};
// Since data has 1 dimension the means and covariances are single values
// mu
final double[] correctMeans = new double[] {-1.736139126623031, 3.899886984922886};
final double[][] means = new double[][]{
{-1.736139126623031},
{3.899886984922886}
};
// Sigma
final double[] correctCov = new double[] {1.371327786710623, 5.254286022455004};
final double[][][] covar = new double[][][] {
{{1.371327786710623}},
{{5.254286022455004}}
};
assertFit(data, 2, logLikelihood, weights, means, covar, 0.05);
}
/**
 * Fits a single-component mixture to the first column of the test samples and
 * compares the result with reference values obtained from Matlab R2023b
 * (Update 6) using {@code GMModel = fitgmdist(X,1)}.
 */
@Test
public void testFit1Dimension1Component() {
    // Reduce every sample to its first coordinate so the data is 1-dimensional
    final double[][] samples = getTestSamples();
    final double[][] firstColumn = new double[samples.length][];
    for (int row = 0; row < samples.length; row++) {
        firstColumn[row] = new double[] {samples[row][0]};
    }

    // Matlab reports NegativeLogLikelihood; the CM code reports the log-likelihood
    // (opposite sign) divided by the number of observations
    final double expectedLogLikelihood = -2.576329329354790e+02 / firstColumn.length;

    // ComponentProportion: the only component carries all of the weight
    final double[] expectedWeights = {1.0};

    // With 1-dimensional data the mean (mu) and covariance (Sigma) are single values
    final double[][] expectedMeans = {{2.544365206503801}};
    final double[][][] expectedCovariances = {{{10.122711799089901}}};

    final double relativeError = 1e-3;
    assertFit(firstColumn, 1, expectedLogLikelihood, expectedWeights,
              expectedMeans, expectedCovariances, relativeError);
}
/**
 * Fits a mixture model to the data and asserts the fitted log-likelihood and the
 * weight, means and covariances of each fitted component against expected values.
 *
 * <p>NOTE(review): this listing appears to interleave the lines removed and added by
 * the commit without +/- markers (e.g. both {@code estimate(data, 2)} and
 * {@code estimate(data, numComponents)} are shown, and {@code relError} is both a
 * parameter and a local). The description below follows the lines that use this
 * method's parameters - confirm against the committed file.
 *
 * @param data Data used to construct the fitter and to estimate the initial mixture.
 * @param numComponents Number of components passed to {@code estimate}.
 * @param logLikelihood Expected value of {@code fitter.getLogLikelihood()}.
 * @param weights Expected component weights, indexed in the iteration order of
 * {@code fittedMix.getComponents()}.
 * @param means Expected means of each component.
 * @param covar Expected covariance rows of each component.
 * @param relError Relative error; each tolerance is the expected value (or its
 * absolute value) multiplied by this factor.
 */
private static void assertFit(double[][] data, int numComponents,
double logLikelihood, double[] weights,
double[][] means, double[][][] covar, double relError) {
MultivariateNormalMixtureExpectationMaximization fitter
= new MultivariateNormalMixtureExpectationMaximization(data);
MixtureMultivariateNormalDistribution initialMix
= MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
= MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
fitter.fit(initialMix);
MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
final double relError = 0.05;
Assert.assertEquals(correctLogLikelihood,
fitter.getLogLikelihood(),
Math.abs(correctLogLikelihood) * relError);
Assert.assertEquals(logLikelihood,
fitter.getLogLikelihood(),
Math.abs(logLikelihood) * relError);
// NOTE(review): no visible assertion compares the number of fitted components with
// weights.length; the loop only checks the components that are returned - verify
// whether a size check is wanted.
int i = 0;
for (Pair<Double, MultivariateNormalDistribution> component : components) {
final double weight = component.getFirst();
final MultivariateNormalDistribution mvn = component.getSecond();
final double[] mean = mvn.getMeans();
final RealMatrix covMat = mvn.getCovariances();
Assert.assertEquals(correctWeights[i], weight, correctWeights[i] * relError);
Assert.assertEquals(correctMeans[i], mean[0], Math.abs(correctMeans[i]) * relError);
Assert.assertEquals(correctCov[i], covMat.getEntry(0, 0), correctCov[i] * relError);
Assert.assertEquals(weights[i], weight, weights[i] * relError);
assertArrayEquals(means[i], mvn.getMeans(), relError);
// Compare the covariance matrix row by row after checking the row count
final double[][] c = mvn.getCovariances().getData();
Assert.assertEquals(covar[i].length, c.length);
for (int j = 0; j < covar[i].length; j++) {
assertArrayEquals(covar[i][j], c[j], relError);
}
i++;
}
}
/**
 * Asserts that two arrays have the same length and that every actual element is
 * within a relative error of the corresponding expected element.
 *
 * @param e Expected values.
 * @param a Actual values.
 * @param relError Relative error; the tolerance for element {@code i} is
 * {@code Math.abs(e[i]) * relError}.
 */
private static void assertArrayEquals(double[] e, double[] a, double relError) {
    Assert.assertEquals("length", e.length, a.length);
    int index = 0;
    for (final double expected : e) {
        // Scale the tolerance by the magnitude of the expected value
        final double tolerance = Math.abs(expected) * relError;
        Assert.assertEquals(expected, a[index], tolerance);
        index++;
    }
}
private double[][] getTestSamples() {
// generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
// [4, 8.2]