From 1a6879a0023a2892323221622e0d1e759613c68f Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Wed, 19 Feb 2014 20:31:47 +0000 Subject: [PATCH] Make QR in GaussNewton faster and more accurate Re-factored the code in GaussNewtonOptimizer so that the decomposition algorithm sees the Jacobian and residuals instead of the normal equation. This lets the QR algorithm operate directly on the Jacobian matrix, which is faster and less sensitive to numerical errors. As a result, one test case that threw a singular matrix exception now passes with the QR decomposition. The refactoring also includes a speed improvement when computing the normal matrix for the LU decomposition. Since the normal matrix is symmetric, only half of it is computed, which results in a factor of 2 speed up in computing the normal matrix for problems with many more measurements than states. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569905 13f79535-47bb-0310-9956-ffa450edef68 --- src/changes/changes.xml | 3 + .../leastsquares/GaussNewtonOptimizer.java | 143 ++++++++++++------ .../GaussNewtonOptimizerWithQRTest.java | 14 -- 3 files changed, 98 insertions(+), 62 deletions(-) diff --git a/src/changes/changes.xml b/src/changes/changes.xml index c558bf1d3..e8254943b 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces! + + Make QR in GaussNewton faster and more accurate. + The sparse vector and matrix classes have been un-deprecated. 
This is a reversal of a former decision, as we now think we should adopt a generally accepted diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java index 6f440554a..a22fef0fd 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java @@ -21,15 +21,15 @@ import org.apache.commons.math3.exception.NullArgumentException; import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.linear.ArrayRealVector; -import org.apache.commons.math3.linear.BlockRealMatrix; -import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.LUDecomposition; +import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.SingularMatrixException; import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.util.Incrementor; +import org.apache.commons.math3.util.Pair; /** * Gauss-Newton least-squares solver.

This class solve a least-square problem by @@ -46,28 +46,65 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { /** The decomposition algorithm to use to solve the normal equations. */ //TODO move to linear package and expand options? public static enum Decomposition { - /** Use {@link LUDecomposition}. */ + /** + * Solve by forming the normal equations (J<sup>T</sup>Jx=J<sup>T</sup>r) and + * using the {@link LUDecomposition}. + * + *

Theoretically this method takes mn<sup>2</sup>/2 operations to compute the + * normal matrix and n<sup>3</sup>/3 operations (m > n) to solve the system using + * the LU decomposition.

+ */ LU { @Override - protected DecompositionSolver getSolver(final RealMatrix matrix) { - return new LUDecomposition(matrix, SINGULARITY_THRESHOLD).getSolver(); + protected RealVector solve(final RealMatrix jacobian, + final RealVector residuals) { + try { + final Pair normalEquation = + computeNormalMatrix(jacobian, residuals); + final RealMatrix normal = normalEquation.getFirst(); + final RealVector jTr = normalEquation.getSecond(); + return new LUDecomposition(normal, SINGULARITY_THRESHOLD) + .getSolver() + .solve(jTr); + } catch (SingularMatrixException e) { + throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); + } } }, - /** Use {@link QRDecomposition}. */ + /** + * Solve the linear least squares problem (Jx=r) using the {@link + * QRDecomposition}. + * + *

Theoretically this method takes mn<sup>2</sup> - n<sup>3</sup>/3 operations + * (m > n) and has better numerical accuracy than any method that forms the normal + * equations.

+ */ QR { @Override - protected DecompositionSolver getSolver(final RealMatrix matrix) { - return new QRDecomposition(matrix, SINGULARITY_THRESHOLD).getSolver(); + protected RealVector solve(final RealMatrix jacobian, + final RealVector residuals) { + try { + return new QRDecomposition(jacobian, SINGULARITY_THRESHOLD) + .getSolver() + .solve(residuals); + } catch (SingularMatrixException e) { + throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); + } } }; /** - * Decompose the normal equations. + * Solve the linear least squares problem Jx=r. * - * @param matrix the normal matrix. - * @return a solver. + * @param jacobian the Jacobian matrix, J. the number of rows >= the number or + * columns. + * @param residuals the computed residuals, r. + * @return the solution x, to the linear least squares problem Jx=r. + * @throws ConvergenceException if the matrix properties (e.g. singular) do not + * permit a solution. */ - protected abstract DecompositionSolver getSolver(RealMatrix matrix); + protected abstract RealVector solve(RealMatrix jacobian, + RealVector residuals); } /** @@ -132,7 +169,6 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { throw new NullArgumentException(); } - final int nR = lsp.getObservationSize(); // Number of observed data. 
final int nC = lsp.getParameterSize(); final RealVector currentPoint = lsp.getStart(); @@ -160,41 +196,11 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { } } - // build the linear problem - final double[] b = new double[nC]; - final double[][] a = new double[nC][nC]; - for (int i = 0; i < nR; ++i) { - - final double[] grad = weightedJacobian.getRow(i); - final double residual = currentResiduals.getEntry(i); - - // compute the normal equation - //residual is already weighted - for (int j = 0; j < nC; ++j) { - b[j] += residual * grad[j]; - } - - // build the contribution matrix for measurement i - for (int k = 0; k < nC; ++k) { - double[] ak = a[k]; - //Jacobian/gradient is already weighted - for (int l = 0; l < nC; ++l) { - ak[l] += grad[k] * grad[l]; - } - } - } - - try { - // solve the linearized least squares problem - RealMatrix mA = new BlockRealMatrix(a); - DecompositionSolver solver = this.decomposition.getSolver(mA); - final RealVector dX = solver.solve(new ArrayRealVector(b, false)); - // update the estimated parameters - for (int i = 0; i < nC; ++i) { - currentPoint.setEntry(i, currentPoint.getEntry(i) + dX.getEntry(i)); - } - } catch (SingularMatrixException e) { - throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); + // solve the linearized least squares problem + final RealVector dX = this.decomposition.solve(weightedJacobian, currentResiduals); + // update the estimated parameters + for (int i = 0; i < nC; ++i) { + currentPoint.setEntry(i, currentPoint.getEntry(i) + dX.getEntry(i)); } } } @@ -206,4 +212,45 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { '}'; } + /** + * Compute the normal matrix, JTJ. + * + * @param jacobian the m by n jacobian matrix, J. Input. + * @param residuals the m by 1 residual vector, r. Input. + * @return the n by n normal matrix and the n by 1 JTr vector. 
+ */ + private static Pair computeNormalMatrix(final RealMatrix jacobian, + final RealVector residuals) { + //since the normal matrix is symmetric, we only need to compute half of it. + final int nR = jacobian.getRowDimension(); + final int nC = jacobian.getColumnDimension(); + //allocate space for return values + final RealMatrix normal = MatrixUtils.createRealMatrix(nC, nC); + final RealVector jTr = new ArrayRealVector(nC); + //for each measurement + for (int i = 0; i < nR; ++i) { + //compute JTr for measurement i + for (int j = 0; j < nC; j++) { + jTr.setEntry(j, jTr.getEntry(j) + + residuals.getEntry(i) * jacobian.getEntry(i, j)); + } + + // add the the contribution to the normal matrix for measurement i + for (int k = 0; k < nC; ++k) { + //only compute the upper triangular part + for (int l = k; l < nC; ++l) { + normal.setEntry(k, l, normal.getEntry(k, l) + + jacobian.getEntry(i, k) * jacobian.getEntry(i, l)); + } + } + } + //copy the upper triangular part to the lower triangular part. 
+ for (int i = 0; i < nC; i++) { + for (int j = 0; j < i; j++) { + normal.setEntry(i, j, normal.getEntry(j, i)); + } + } + return new Pair(normal, jTr); + } + } diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java index 8631101d1..79a74881d 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerWithQRTest.java @@ -47,20 +47,6 @@ public class GaussNewtonOptimizerWithQRTest return new GaussNewtonOptimizer(Decomposition.QR); } - @Override - @Test - public void testMoreEstimatedParametersSimple() { - /* - * Exception is expected with this optimizer - */ - try { - super.testMoreEstimatedParametersSimple(); - fail(optimizer); - } catch (ConvergenceException e) { - //expected - } - } - @Override @Test public void testMoreEstimatedParametersUnsorted() {