Allow fitting single component data
This commit is contained in:
parent
4957f7570b
commit
bab00341a9
|
@ -294,7 +294,7 @@ public class MultivariateNormalMixtureExpectationMaximization {
|
|||
* @return Multivariate normal mixture model estimated from the data
|
||||
* @throws NumberIsTooLargeException if {@code numComponents} is greater
|
||||
* than the number of data rows.
|
||||
* @throws NumberIsTooSmallException if {@code numComponents < 2}.
|
||||
* @throws NumberIsTooSmallException if {@code numComponents < 1}.
|
||||
* @throws NotStrictlyPositiveException if data has fewer than 2 rows
|
||||
* @throws DimensionMismatchException if rows of data have different numbers
|
||||
* of columns
|
||||
|
@ -306,8 +306,8 @@ public class MultivariateNormalMixtureExpectationMaximization {
|
|||
if (data.length < 2) {
|
||||
throw new NotStrictlyPositiveException(data.length);
|
||||
}
|
||||
if (numComponents < 2) {
|
||||
throw new NumberIsTooSmallException(numComponents, 2, true);
|
||||
if (numComponents < 1) {
|
||||
throw new NumberIsTooSmallException(numComponents, 1, true);
|
||||
}
|
||||
if (numComponents > data.length) {
|
||||
throw new NumberIsTooLargeException(numComponents, data.length, true);
|
||||
|
|
|
@ -191,60 +191,37 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testFit() {
|
||||
// Test that the loglikelihood, weights, and models are determined and
|
||||
// fitted correctly
|
||||
public void testFit2Dimensions2Components() {
|
||||
final double[][] data = getTestSamples();
|
||||
final double correctLogLikelihood = -4.292431006791994;
|
||||
final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
|
||||
|
||||
final double[][] correctMeans = new double[][]{
|
||||
{-1.4213112715121132, 1.6924690505757753},
|
||||
{4.213612224374709, 7.975621325853645}
|
||||
// Fit using the test samples using Matlab R2023b (Update 6):
|
||||
// GMModel = fitgmdist(X,2);
|
||||
|
||||
// Expected results use the component order generated by the CM code for convenience
|
||||
// i.e. ComponentProportion from matlab is reversed: [0.703722, 0.296278]
|
||||
|
||||
// NegativeLogLikelihood (CM code uses the positive log-likelihood divided by the number of observations)
|
||||
final double logLikelihood = -4.292430883324220e+02 / data.length;
|
||||
// ComponentProportion
|
||||
final double[] weights = new double[] {0.2962324189652912, 0.7037675810347089};
|
||||
// mu
|
||||
final double[][] means = new double[][]{
|
||||
{-1.421239458366293, 1.692604555824222},
|
||||
{4.213949861591596, 7.975974466776790}
|
||||
};
|
||||
// Sigma
|
||||
final double[][][] covar = new double[][][] {
|
||||
{{1.739441346307267, -0.586740858187563},
|
||||
{-0.586740858187563, 1.023420964341543}},
|
||||
{{4.243780645051973, 2.578176622652551},
|
||||
{2.578176622652551, 3.918302056479298}}
|
||||
};
|
||||
|
||||
final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
|
||||
correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
|
||||
{ 1.739356907285747, -0.5867644251487614 },
|
||||
{ -0.5867644251487614, 1.0232932029324642 } }
|
||||
);
|
||||
correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
|
||||
{ 4.245384898007161, 2.5797798966382155 },
|
||||
{ 2.5797798966382155, 3.9200272522448367 } });
|
||||
|
||||
final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
|
||||
correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
|
||||
correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
|
||||
|
||||
MultivariateNormalMixtureExpectationMaximization fitter
|
||||
= new MultivariateNormalMixtureExpectationMaximization(data);
|
||||
|
||||
MixtureMultivariateNormalDistribution initialMix
|
||||
= MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
|
||||
fitter.fit(initialMix);
|
||||
MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
|
||||
List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
|
||||
|
||||
Assert.assertEquals(correctLogLikelihood,
|
||||
fitter.getLogLikelihood(),
|
||||
Math.ulp(1d));
|
||||
|
||||
int i = 0;
|
||||
for (Pair<Double, MultivariateNormalDistribution> component : components) {
|
||||
final double weight = component.getFirst();
|
||||
final MultivariateNormalDistribution mvn = component.getSecond();
|
||||
final double[] mean = mvn.getMeans();
|
||||
final RealMatrix covMat = mvn.getCovariances();
|
||||
Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
|
||||
Assert.assertArrayEquals(correctMeans[i], mean, 0.0);
|
||||
Assert.assertEquals(correctCovMats[i], covMat);
|
||||
i++;
|
||||
}
|
||||
assertFit(data, 2, logLikelihood, weights, means, covar, 1e-3);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFit1() {
|
||||
// Test that the fit can be performed on data with a single dimension
|
||||
public void testFit1Dimension2Components() {
|
||||
// Use only the first column of the test data
|
||||
final double[][] data = Arrays.stream(getTestSamples())
|
||||
.map(x -> new double[] {x[0]}).toArray(double[][]::new);
|
||||
|
@ -253,42 +230,88 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
|
|||
// GMModel = fitgmdist(X,2);
|
||||
|
||||
// NegativeLogLikelihood (CM code uses the positive log-likelihood divided by the number of observations)
|
||||
final double correctLogLikelihood = -2.512197016873482e+02 / data.length;
|
||||
final double logLikelihood = -2.512197016873482e+02 / data.length;
|
||||
// ComponentProportion
|
||||
final double[] correctWeights = new double[] {0.240510201974078, 0.759489798025922};
|
||||
final double[] weights = new double[] {0.240510201974078, 0.759489798025922};
|
||||
// Since data has 1 dimension the means and covariances are single values
|
||||
// mu
|
||||
final double[] correctMeans = new double[] {-1.736139126623031, 3.899886984922886};
|
||||
final double[][] means = new double[][]{
|
||||
{-1.736139126623031},
|
||||
{3.899886984922886}
|
||||
};
|
||||
// Sigma
|
||||
final double[] correctCov = new double[] {1.371327786710623, 5.254286022455004};
|
||||
final double[][][] covar = new double[][][] {
|
||||
{{1.371327786710623}},
|
||||
{{5.254286022455004}}
|
||||
};
|
||||
|
||||
assertFit(data, 2, logLikelihood, weights, means, covar, 0.05);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFit1Dimension1Component() {
|
||||
// Use only the first column of the test data
|
||||
final double[][] data = Arrays.stream(getTestSamples())
|
||||
.map(x -> new double[] {x[0]}).toArray(double[][]::new);
|
||||
|
||||
// Fit the first column of test samples using Matlab R2023b (Update 6):
|
||||
// GMModel = fitgmdist(X,1);
|
||||
|
||||
// NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations)
|
||||
final double logLikelihood = -2.576329329354790e+02 / data.length;
|
||||
// ComponentProportion
|
||||
final double[] weights = new double[] {1.0};
|
||||
// Since data has 1 dimension the means and covariances are single values
|
||||
// mu
|
||||
final double[][] means = new double[][]{
|
||||
{2.544365206503801},
|
||||
};
|
||||
// Sigma
|
||||
final double[][][] covar = new double[][][] {
|
||||
{{10.122711799089901}},
|
||||
};
|
||||
|
||||
assertFit(data, 1, logLikelihood, weights, means, covar, 1e-3);
|
||||
}
|
||||
|
||||
private static void assertFit(double[][] data, int numComponents,
|
||||
double logLikelihood, double[] weights,
|
||||
double[][] means, double[][][] covar, double relError) {
|
||||
MultivariateNormalMixtureExpectationMaximization fitter
|
||||
= new MultivariateNormalMixtureExpectationMaximization(data);
|
||||
|
||||
MixtureMultivariateNormalDistribution initialMix
|
||||
= MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
|
||||
= MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
|
||||
fitter.fit(initialMix);
|
||||
MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
|
||||
List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
|
||||
|
||||
final double relError = 0.05;
|
||||
Assert.assertEquals(correctLogLikelihood,
|
||||
fitter.getLogLikelihood(),
|
||||
Math.abs(correctLogLikelihood) * relError);
|
||||
Assert.assertEquals(logLikelihood,
|
||||
fitter.getLogLikelihood(),
|
||||
Math.abs(logLikelihood) * relError);
|
||||
|
||||
int i = 0;
|
||||
for (Pair<Double, MultivariateNormalDistribution> component : components) {
|
||||
final double weight = component.getFirst();
|
||||
final MultivariateNormalDistribution mvn = component.getSecond();
|
||||
final double[] mean = mvn.getMeans();
|
||||
final RealMatrix covMat = mvn.getCovariances();
|
||||
Assert.assertEquals(correctWeights[i], weight, correctWeights[i] * relError);
|
||||
Assert.assertEquals(correctMeans[i], mean[0], Math.abs(correctMeans[i]) * relError);
|
||||
Assert.assertEquals(correctCov[i], covMat.getEntry(0, 0), correctCov[i] * relError);
|
||||
Assert.assertEquals(weights[i], weight, weights[i] * relError);
|
||||
assertArrayEquals(means[i], mvn.getMeans(), relError);
|
||||
final double[][] c = mvn.getCovariances().getData();
|
||||
Assert.assertEquals(covar[i].length, c.length);
|
||||
for (int j = 0; j < covar[i].length; j++) {
|
||||
assertArrayEquals(covar[i][j], c[j], relError);
|
||||
}
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
private static void assertArrayEquals(double[] e, double[] a, double relError) {
|
||||
Assert.assertEquals("length", e.length, a.length);
|
||||
for (int i = 0; i < e.length; i++) {
|
||||
Assert.assertEquals(e[i], a[i], Math.abs(e[i]) * relError);
|
||||
}
|
||||
}
|
||||
|
||||
private double[][] getTestSamples() {
|
||||
// generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
|
||||
// [4, 8.2]
|
||||
|
|
Loading…
Reference in New Issue