From a7a380f93478356b287791b917d6e68e89d20a8f Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Tue, 18 Feb 2014 14:32:44 +0000 Subject: [PATCH] Use Evaluation instead of PointVectorValuePair Use Evaluation instead of PointVectorValuePair in the ConvergenceChecker. This gives the checkers access to more information, such as the rms and covariances. The change also simplified the optimizer implementations since they no longer have to keep track of the current function value. A method was added to LeastSquaresFactory to convert between the two types of checkers and a method added to LeastSquaresBuilder so that it can accept either type. I would have prefered to do this through method overloading, but overloading doesn't play well with generics. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569353 13f79535-47bb-0310-9956-ffa450edef68 --- .../leastsquares/GaussNewtonOptimizer.java | 17 ++++---- .../leastsquares/LeastSquaresAdapter.java | 3 +- .../leastsquares/LeastSquaresBuilder.java | 17 +++++++- .../leastsquares/LeastSquaresFactory.java | 36 +++++++++++++++-- .../leastsquares/LeastSquaresProblem.java | 4 +- .../leastsquares/LeastSquaresProblemImpl.java | 6 +-- .../LevenbergMarquardtOptimizer.java | 40 ++++++------------- ...ractLeastSquaresOptimizerAbstractTest.java | 2 +- .../GaussNewtonOptimizerTest.java | 2 +- 9 files changed, 76 insertions(+), 51 deletions(-) diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java index d9246b835..e9d535823 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java @@ -28,7 +28,6 @@ import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.SingularMatrixException; import org.apache.commons.math3.optim.ConvergenceChecker; -import org.apache.commons.math3.optim.PointVectorValuePair; import org.apache.commons.math3.util.Incrementor; /** @@ -123,7 +122,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { //create local evaluation and iteration counts final Incrementor evaluationCounter = lsp.getEvaluationCounter(); final Incrementor iterationCounter = lsp.getIterationCounter(); - final ConvergenceChecker checker + final ConvergenceChecker checker = lsp.getConvergenceChecker(); // Computation will be useless without a checker (see "for-loop"). @@ -137,25 +136,23 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer { final double[] currentPoint = lsp.getStart(); // iterate until convergence is reached - PointVectorValuePair current = null; + Evaluation current = null; while (true) { iterationCounter.incrementCount(); // evaluate the objective function and its jacobian - PointVectorValuePair previous = current; + Evaluation previous = current; // Value of the objective function at "currentPoint". evaluationCounter.incrementCount(); - final Evaluation value = lsp.evaluate(currentPoint); - final double[] currentObjective = value.computeValue(); - final double[] currentResiduals = value.computeResiduals(); - final RealMatrix weightedJacobian = value.computeJacobian(); - current = new PointVectorValuePair(currentPoint, currentObjective); + current = lsp.evaluate(currentPoint); + final double[] currentResiduals = current.computeResiduals(); + final RealMatrix weightedJacobian = current.computeJacobian(); // Check convergence. if (previous != null) { if (checker.converged(iterationCounter.getCount(), previous, current)) { return new OptimumImpl( - value, + current, evaluationCounter.getCount(), iterationCounter.getCount()); } diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java index 3e88d035c..f77280c4c 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java @@ -1,7 +1,6 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.optim.ConvergenceChecker; -import org.apache.commons.math3.optim.PointVectorValuePair; import org.apache.commons.math3.util.Incrementor; /** @@ -54,7 +53,7 @@ public class LeastSquaresAdapter implements LeastSquaresProblem { } /** {@inheritDoc} */ - public ConvergenceChecker getConvergenceChecker() { + public ConvergenceChecker getConvergenceChecker() { return problem.getConvergenceChecker(); } } diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java index 485eeea37..647bc4d2f 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java @@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.analysis.MultivariateMatrixFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction; +import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.PointVectorValuePair; @@ -19,7 +20,7 @@ public class LeastSquaresBuilder { /** max iterations */ private int maxIterations; /** convergence checker */ - private ConvergenceChecker checker; + private ConvergenceChecker checker; /** model function */ private MultivariateVectorFunction model; /** Jacobian function */ @@ -69,11 +70,23 @@ public class LeastSquaresBuilder { * @param checker the convergence checker. * @return this */ - public LeastSquaresBuilder checker(final ConvergenceChecker checker) { + public LeastSquaresBuilder checker(final ConvergenceChecker checker) { this.checker = checker; return this; } + /** + * Configure the convergence checker. + *

+ * This function is an overloaded version of {@link #checker(ConvergenceChecker)}. + * + * @param checker the convergence checker. + * @return this + */ + public LeastSquaresBuilder checkerPair(final ConvergenceChecker checker) { + return this.checker(LeastSquaresFactory.evaluationChecker(checker)); + } + /** * Configure the model function. * diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java index 073ff4d26..8cce2c488 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java @@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.analysis.MultivariateMatrixFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction; +import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.DiagonalMatrix; @@ -40,7 +41,7 @@ public class LeastSquaresFactory { public static LeastSquaresProblem create(final MultivariateJacobianFunction model, final double[] observed, final double[] start, - final ConvergenceChecker checker, + final ConvergenceChecker checker, final int maxEvaluations, final int maxIterations) { return new LeastSquaresProblemImpl( @@ -70,7 +71,7 @@ public class LeastSquaresFactory { final MultivariateMatrixFunction jacobian, final double[] observed, final double[] start, - final ConvergenceChecker checker, + final ConvergenceChecker checker, final int maxEvaluations, final int maxIterations) { return create( @@ -102,7 +103,7 @@ public class LeastSquaresFactory { final double[] observed, final double[] start, final RealMatrix weight, - final ConvergenceChecker checker, + final ConvergenceChecker checker, final int maxEvaluations, final int maxIterations) { return weightMatrix( @@ -174,6 +175,35 @@ public class LeastSquaresFactory { }; } + /** + * View a convergence checker specified for a {@link PointVectorValuePair} as one + * specified for an {@link Evaluation}. + * + * @param checker the convergence checker to adapt. + * @return a convergence checker that delegates to {@code checker}. + */ + public static ConvergenceChecker evaluationChecker( + final ConvergenceChecker checker + ) { + return new ConvergenceChecker() { + public boolean converged(final int iteration, + final Evaluation previous, + final Evaluation current) { + return checker.converged( + iteration, + new PointVectorValuePair( + previous.getPoint(), + previous.computeValue(), + false), + new PointVectorValuePair( + current.getPoint(), + current.computeValue(), + false) + ); + } + }; + } + /** * Computes the square-root of the weight matrix. * diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java index f61bc32cf..427090caa 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java @@ -1,8 +1,8 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.linear.RealMatrix; -import org.apache.commons.math3.optim.PointVectorValuePair; /** * The data necessary to define a non-linear least squares problem. Includes the observed @@ -12,7 +12,7 @@ import org.apache.commons.math3.optim.PointVectorValuePair; * * @version $Id$ */ -public interface LeastSquaresProblem extends OptimizationProblem { +public interface LeastSquaresProblem extends OptimizationProblem { /** * Gets the initial guess. diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java index 15e306705..fb89d9b2c 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java @@ -17,11 +17,11 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.optim.AbstractOptimizationProblem; import org.apache.commons.math3.optim.ConvergenceChecker; -import org.apache.commons.math3.optim.PointVectorValuePair; import org.apache.commons.math3.util.Pair; /** @@ -32,7 +32,7 @@ import org.apache.commons.math3.util.Pair; * @since 3.3 */ class LeastSquaresProblemImpl - extends AbstractOptimizationProblem + extends AbstractOptimizationProblem implements LeastSquaresProblem { /** Target values for the model function at optimum. */ @@ -45,7 +45,7 @@ class LeastSquaresProblemImpl LeastSquaresProblemImpl(final MultivariateJacobianFunction model, final double[] target, final double[] start, - final ConvergenceChecker checker, + final ConvergenceChecker checker, final int maxEvaluations, final int maxIterations) { super(maxEvaluations, maxIterations, checker); diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java index 67155b092..fbc4d7009 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java @@ -23,7 +23,6 @@ import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.exception.ConvergenceException; import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.optim.ConvergenceChecker; -import org.apache.commons.math3.optim.PointVectorValuePair; import org.apache.commons.math3.util.Incrementor; import org.apache.commons.math3.util.Precision; import org.apache.commons.math3.util.FastMath; @@ -303,7 +302,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { final Incrementor iterationCounter = problem.getIterationCounter(); final Incrementor evaluationCounter = problem.getEvaluationCounter(); //convergence criterion - final ConvergenceChecker checker + final ConvergenceChecker checker = problem.getConvergenceChecker(); // arrays shared with the other private methods @@ -319,7 +318,6 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { double[] diag = new double[nC]; double[] oldX = new double[nC]; double[] oldRes = new double[nR]; - double[] oldObj = new double[nR]; double[] qtf = new double[nR]; double[] work1 = new double[nC]; double[] work2 = new double[nC]; @@ -329,23 +327,20 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { // Evaluate the function at the starting point and calculate its norm. evaluationCounter.incrementCount(); //value will be reassigned in the loop - Evaluation value = problem.evaluate(currentPoint); - double[] currentObjective = value.computeValue(); - double[] currentResiduals = value.computeResiduals(); - PointVectorValuePair current = new PointVectorValuePair(currentPoint, currentObjective); - double currentCost = value.computeCost(); + Evaluation current = problem.evaluate(currentPoint); + double[] currentResiduals = current.computeResiduals(); + double currentCost = current.computeCost(); // Outer loop. boolean firstIteration = true; while (true) { iterationCounter.incrementCount(); - final PointVectorValuePair previous = current; - final Evaluation previousValue = value; + final Evaluation previous = current; // QR decomposition of the jacobian matrix final InternalData internalData - = qrDecomposition(value.computeJacobian(), solvedCols); + = qrDecomposition(current.computeJacobian(), solvedCols); final double[][] weightedJacobian = internalData.weightedJacobian; final int[] permutation = internalData.permutation; final double[] diagR = internalData.diagR; @@ -404,7 +399,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { if (maxCosine <= orthoTolerance) { // Convergence has been reached. return new OptimumImpl( - value, + current, evaluationCounter.getCount(), iterationCounter.getCount()); } @@ -426,9 +421,6 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { double[] tmpVec = weightedResidual; weightedResidual = oldRes; oldRes = tmpVec; - tmpVec = currentObjective; - currentObjective = oldObj; - oldObj = tmpVec; // determine the Levenberg-Marquardt parameter lmPar = determineLMParameter(qtf, delta, diag, @@ -452,11 +444,9 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { // Evaluate the function at x + p and calculate its norm. evaluationCounter.incrementCount(); - value = problem.evaluate(currentPoint); - currentObjective = value.computeValue(); - currentResiduals = value.computeResiduals(); - current = new PointVectorValuePair(currentPoint, currentObjective); - currentCost = value.computeCost(); + current = problem.evaluate(currentPoint); + currentResiduals = current.computeResiduals(); + currentCost = current.computeCost(); // compute the scaled actual reduction double actRed = -1.0; @@ -515,7 +505,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { // tests for convergence. if (checker != null && checker.converged(iterationCounter.getCount(), previous, current)) { - return new OptimumImpl(value, iterationCounter.getCount(), evaluationCounter.getCount()); + return new OptimumImpl(current, iterationCounter.getCount(), evaluationCounter.getCount()); } } else { // failed iteration, reset the previous values @@ -527,12 +517,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { tmpVec = weightedResidual; weightedResidual = oldRes; oldRes = tmpVec; - tmpVec = currentObjective; - currentObjective = oldObj; - oldObj = tmpVec; // Reset "current" to previous values. - current = new PointVectorValuePair(currentPoint, currentObjective); - value = previousValue; + current = previous; } // Default convergence criteria. @@ -540,7 +526,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer { preRed <= costRelativeTolerance && ratio <= 2.0) || delta <= parRelativeTolerance * xNorm) { - return new OptimumImpl(value, iterationCounter.getCount(), evaluationCounter.getCount()); + return new OptimumImpl(current, iterationCounter.getCount(), evaluationCounter.getCount()); } // tests for termination and stringent tolerances diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java index 05b50abf3..ad34f84c0 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java @@ -46,7 +46,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest { public LeastSquaresBuilder base() { return new LeastSquaresBuilder() - .checker(new SimpleVectorValueChecker(1e-6, 1e-6)) + .checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6)) .maxEvaluations(100) .maxIterations(getMaxIterations()); } diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java index 17a0f3cd8..1afa44898 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java @@ -96,7 +96,7 @@ public class GaussNewtonOptimizerTest circle.addPoint( 45.0, 97.0); LeastSquaresProblem lsp = builder(circle) - .checker(new SimpleVectorValueChecker(1e-30, 1e-30)) + .checkerPair(new SimpleVectorValueChecker(1e-30, 1e-30)) .maxIterations(Integer.MAX_VALUE) .start(new double[]{98.680, 47.345}) .build();