diff --git a/src/changes/changes.xml b/src/changes/changes.xml index c558bf1d3..e8254943b 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces! + + Make QR in GaussNewton faster and more accurate. + The sparse vector and matrix classes have been un-deprecated. This is a reversal of a former decision, as we now think we should adopt a generally accepted diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java index 6f440554a..a22fef0fd 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java @@ -21,15 +21,15 @@ import org.apache.commons.math3.exception.NullArgumentException; import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.linear.ArrayRealVector; -import org.apache.commons.math3.linear.BlockRealMatrix; -import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.LUDecomposition; +import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.SingularMatrixException; import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.util.Incrementor; +import org.apache.commons.math3.util.Pair; /** * Gauss-Newton least-squares solver.

This class solve a least-square problem by @@ -46,28 +46,65 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { /** The decomposition algorithm to use to solve the normal equations. */ //TODO move to linear package and expand options? public static enum Decomposition { - /** Use {@link LUDecomposition}. */ + /** + * Solve by forming the normal equations (JTJx=JTr) and + * using the {@link LUDecomposition}. + * + *

Theoretically this method takes mn2>/2 operations to compute the + * normal matrix and n3/3 operations (m > n) to solve the system using + * the LU decomposition.

+ */ LU { @Override - protected DecompositionSolver getSolver(final RealMatrix matrix) { - return new LUDecomposition(matrix, SINGULARITY_THRESHOLD).getSolver(); + protected RealVector solve(final RealMatrix jacobian, + final RealVector residuals) { + try { + final Pair normalEquation = + computeNormalMatrix(jacobian, residuals); + final RealMatrix normal = normalEquation.getFirst(); + final RealVector jTr = normalEquation.getSecond(); + return new LUDecomposition(normal, SINGULARITY_THRESHOLD) + .getSolver() + .solve(jTr); + } catch (SingularMatrixException e) { + throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); + } } }, - /** Use {@link QRDecomposition}. */ + /** + * Solve the linear least squares problem (Jx=r) using the {@link + * QRDecomposition}. + * + *

Theoretically this method takes mn2 - n3/3 operations + * (m > n) and has better numerical accuracy than any method that forms the normal + * equations.

+ */ QR { @Override - protected DecompositionSolver getSolver(final RealMatrix matrix) { - return new QRDecomposition(matrix, SINGULARITY_THRESHOLD).getSolver(); + protected RealVector solve(final RealMatrix jacobian, + final RealVector residuals) { + try { + return new QRDecomposition(jacobian, SINGULARITY_THRESHOLD) + .getSolver() + .solve(residuals); + } catch (SingularMatrixException e) { + throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); + } } }; /** - * Decompose the normal equations. + * Solve the linear least squares problem Jx=r. * - * @param matrix the normal matrix. - * @return a solver. + * @param jacobian the Jacobian matrix, J. the number of rows >= the number or + * columns. + * @param residuals the computed residuals, r. + * @return the solution x, to the linear least squares problem Jx=r. + * @throws ConvergenceException if the matrix properties (e.g. singular) do not + * permit a solution. */ - protected abstract DecompositionSolver getSolver(RealMatrix matrix); + protected abstract RealVector solve(RealMatrix jacobian, + RealVector residuals); } /** @@ -132,7 +169,6 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { throw new NullArgumentException(); } - final int nR = lsp.getObservationSize(); // Number of observed data. final int nC = lsp.getParameterSize(); final RealVector currentPoint = lsp.getStart(); @@ -160,41 +196,11 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { } } - // build the linear problem - final double[] b = new double[nC]; - final double[][] a = new double[nC][nC]; - for (int i = 0; i < nR; ++i) { - - final double[] grad = weightedJacobian.getRow(i); - final double residual = currentResiduals.getEntry(i); - - // compute the normal equation - //residual is already weighted - for (int j = 0; j < nC; ++j) { - b[j] += residual * grad[j]; - } - - // build the contribution matrix for measurement i - for (int k = 0; k < nC; ++k) { - double[] ak = a[k]; - //Jacobian/gradient is already weighted - for (int l = 0; l < nC; ++l) { - ak[l] += grad[k] * grad[l]; - } - } - } - - try { - // solve the linearized least squares problem - RealMatrix mA = new BlockRealMatrix(a); - DecompositionSolver solver = this.decomposition.getSolver(mA); - final RealVector dX = solver.solve(new ArrayRealVector(b, false)); - // update the estimated parameters - for (int i = 0; i < nC; ++i) { - currentPoint.setEntry(i, currentPoint.getEntry(i) + dX.getEntry(i)); - } - } catch (SingularMatrixException e) { - throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); + // solve the linearized least squares problem + final RealVector dX = this.decomposition.solve(weightedJacobian, currentResiduals); + // update the estimated parameters + for (int i = 0; i < nC; ++i) { + currentPoint.setEntry(i, currentPoint.getEntry(i) + dX.getEntry(i)); } } } @@ -206,4 +212,45 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { '}'; } + /** + * Compute the normal matrix, JTJ. + * + * @param jacobian the m by n jacobian matrix, J. Input. + * @param residuals the m by 1 residual vector, r. Input. + * @return the n by n normal matrix and the n by 1 JTr vector. + */ + private static Pair computeNormalMatrix(final RealMatrix jacobian, + final RealVector residuals) { + //since the normal matrix is symmetric, we only need to compute half of it. + final int nR = jacobian.getRowDimension(); + final int nC = jacobian.getColumnDimension(); + //allocate space for return values + final RealMatrix normal = MatrixUtils.createRealMatrix(nC, nC); + final RealVector jTr = new ArrayRealVector(nC); + //for each measurement + for (int i = 0; i < nR; ++i) { + //compute JTr for measurement i + for (int j = 0; j < nC; j++) { + jTr.setEntry(j, jTr.getEntry(j) + + residuals.getEntry(i) * jacobian.getEntry(i, j)); + } + + // add the the contribution to the normal matrix for measurement i + for (int k = 0; k < nC; ++k) { + //only compute the upper triangular part + for (int l = k; l < nC; ++l) { + normal.setEntry(k, l, normal.getEntry(k, l) + + jacobian.getEntry(i, k) * jacobian.getEntry(i, l)); + } + } + } + //copy the upper triangular part to the lower triangular part. + for (int i = 0; i < nC; i++) { + for (int j = 0; j < i; j++) { + normal.setEntry(i, j, normal.getEntry(j, i)); + } + } + return new Pair(normal, jTr); + } + } diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java index 8631101d1..79a74881d 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java @@ -47,20 +47,6 @@ public class GaussNewtonOptimizerWithQRTest return new GaussNewtonOptimizer(Decomposition.QR); } - @Override - @Test - public void testMoreEstimatedParametersSimple() { - /* - * Exception is expected with this optimizer - */ - try { - super.testMoreEstimatedParametersSimple(); - fail(optimizer); - } catch (ConvergenceException e) { - //expected - } - } - @Override @Test public void testMoreEstimatedParametersUnsorted() {