# Patch summary (commentary only; everything from the "diff --git" header onward is unchanged).
#
# Target: src/main/java/org/apache/commons/math3/distribution/fitting/
#         MultivariateNormalMixtureExpectationMaximization.java
#
# Hunk 1 (-31,7 +31,6): removes the FastMath import; the only FastMath calls
#   visible in this patch are the ones deleted by hunk 2.
#
# Hunk 2 (-323,28 +322,23): reworks how sorted rows are split into bins.
#   - Drops the local alias totalBins and uses numComponents directly.
#   - Replaces the 1-based loop counter binNumber with the 0-based binIndex.
#   - minIndex is now the inclusive lower bound (binIndex * numRows) / numComponents.
#   - maxIndex is now the EXCLUSIVE upper bound ((binIndex + 1) * numRows) / numComponents,
#     so numBinRows becomes maxIndex - minIndex and the copy loop runs while i < maxIndex.
#   - NOTE(review): the removed code passed an already-divided quotient to
#     FastMath.floor / FastMath.ceil. The operands look like ints (the replacement
#     assigns the same quotient to an int without a cast), so the rounding calls
#     presumably had no effect and the new bounds select the same rows - confirm.
#   - NOTE(review): the products are computed before the division; if the operands
#     are ints this can overflow for very large numRows * numComponents - confirm
#     whether inputs of that size are possible.
#
# Hunk 3 (-426,6 +420,27): adds equals(Object) and hashCode() next to the existing
#   compareTo, in the class that holds the "row" field (DataRow, per the instanceof check).
#   - equals returns true for the same reference, otherwise compares the two row
#     arrays with MathArrays.equals; any non-DataRow argument yields false.
#   - hashCode returns Arrays.hashCode(row).
#   - NOTE(review): equals and hashCode rely on different helpers. Verify that
#     MathArrays.equals never reports two arrays as equal when Arrays.hashCode
#     would hash them differently (e.g. tolerance-based comparison, signed zero);
#     otherwise equal rows could land in different hash buckets.
#   - NOTE(review): the context line shows compareTo ordering by "mean" while
#     equals compares "row", so the ordering may be inconsistent with equals - confirm
#     this is acceptable for the sort performed on sortedData.
#
# NOTE(review): this copy of the patch has lost its line breaks, and the context line
#   "final List> components = new ArrayList>();" appears to have lost its generic
#   type arguments; restore both from the original patch before applying.
#
diff --git a/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java b/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java index 2a1e7d1dd..2f63b63d3 100644 --- a/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java +++ b/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java @@ -31,7 +31,6 @@ import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.SingularMatrixException; import org.apache.commons.math3.stat.correlation.Covariance; -import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.MathArrays; import org.apache.commons.math3.util.Pair; @@ -323,28 +322,23 @@ public class MultivariateNormalMixtureExpectationMaximization { } Arrays.sort(sortedData); - final int totalBins = numComponents; - // uniform weight for each bin - final double weight = 1d / totalBins; + final double weight = 1d / numComponents; // components of mixture model to be created final List> components = new ArrayList>(); // create a component based on data in each bin - for (int binNumber = 1; binNumber <= totalBins; binNumber++) { - // minimum index from sorted data for this bin - final int minIndex - = (int) FastMath.max(0, - FastMath.floor((binNumber - 1) * numRows / totalBins)); + for (int binIndex = 0; binIndex < numComponents; binIndex++) { + // minimum index (inclusive) from sorted data for this bin + final int minIndex = (binIndex * numRows) / numComponents; - // maximum index from sorted data for this bin - final int maxIndex - = (int) FastMath.ceil(binNumber * numRows / numComponents) - 1; + // maximum index (exclusive) from sorted data for this bin + final int maxIndex = ((binIndex + 1) * numRows) / numComponents; // number of 
data records that will be in this bin - final int numBinRows = maxIndex - minIndex + 1; + final int numBinRows = maxIndex - minIndex; // data for this bin final double[][] binData = new double[numBinRows][numCols]; @@ -353,7 +347,7 @@ public class MultivariateNormalMixtureExpectationMaximization { final double[] columnMeans = new double[numCols]; // populate bin and create component - for (int i = minIndex, iBin = 0; i <= maxIndex; i++, iBin++) { + for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) { for (int j = 0; j < numCols; j++) { final double val = sortedData[i].getRow()[j]; columnMeans[j] += val; @@ -426,6 +420,27 @@ public class MultivariateNormalMixtureExpectationMaximization { return mean.compareTo(other.mean); } + /** {@inheritDoc} */ + @Override + public boolean equals(Object other) { + + if (this == other) { + return true; + } + + if (other instanceof DataRow) { + return MathArrays.equals(row, ((DataRow) other).row); + } + + return false; + + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(row); + } /** * Get a data row. * @return data row array