Use Evaluation instead of PointVectorValuePair
Use Evaluation instead of PointVectorValuePair in the ConvergenceChecker. This gives the checkers access to more information, such as the rms and covariances. The change also simplified the optimizer implementations since they no longer have to keep track of the current function value. A method was added to LeastSquaresFactory to convert between the two types of checkers and a method added to LeastSquaresBuilder so that it can accept either type. I would have prefered to do this through method overloading, but overloading doesn't play well with generics. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569353 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
3e18e999c7
commit
a7a380f934
|
@ -28,7 +28,6 @@ import org.apache.commons.math3.linear.QRDecomposition;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.linear.SingularMatrixException;
|
import org.apache.commons.math3.linear.SingularMatrixException;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
|
||||||
import org.apache.commons.math3.util.Incrementor;
|
import org.apache.commons.math3.util.Incrementor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -123,7 +122,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
|
||||||
//create local evaluation and iteration counts
|
//create local evaluation and iteration counts
|
||||||
final Incrementor evaluationCounter = lsp.getEvaluationCounter();
|
final Incrementor evaluationCounter = lsp.getEvaluationCounter();
|
||||||
final Incrementor iterationCounter = lsp.getIterationCounter();
|
final Incrementor iterationCounter = lsp.getIterationCounter();
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker
|
final ConvergenceChecker<Evaluation> checker
|
||||||
= lsp.getConvergenceChecker();
|
= lsp.getConvergenceChecker();
|
||||||
|
|
||||||
// Computation will be useless without a checker (see "for-loop").
|
// Computation will be useless without a checker (see "for-loop").
|
||||||
|
@ -137,25 +136,23 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
|
||||||
final double[] currentPoint = lsp.getStart();
|
final double[] currentPoint = lsp.getStart();
|
||||||
|
|
||||||
// iterate until convergence is reached
|
// iterate until convergence is reached
|
||||||
PointVectorValuePair current = null;
|
Evaluation current = null;
|
||||||
while (true) {
|
while (true) {
|
||||||
iterationCounter.incrementCount();
|
iterationCounter.incrementCount();
|
||||||
|
|
||||||
// evaluate the objective function and its jacobian
|
// evaluate the objective function and its jacobian
|
||||||
PointVectorValuePair previous = current;
|
Evaluation previous = current;
|
||||||
// Value of the objective function at "currentPoint".
|
// Value of the objective function at "currentPoint".
|
||||||
evaluationCounter.incrementCount();
|
evaluationCounter.incrementCount();
|
||||||
final Evaluation value = lsp.evaluate(currentPoint);
|
current = lsp.evaluate(currentPoint);
|
||||||
final double[] currentObjective = value.computeValue();
|
final double[] currentResiduals = current.computeResiduals();
|
||||||
final double[] currentResiduals = value.computeResiduals();
|
final RealMatrix weightedJacobian = current.computeJacobian();
|
||||||
final RealMatrix weightedJacobian = value.computeJacobian();
|
|
||||||
current = new PointVectorValuePair(currentPoint, currentObjective);
|
|
||||||
|
|
||||||
// Check convergence.
|
// Check convergence.
|
||||||
if (previous != null) {
|
if (previous != null) {
|
||||||
if (checker.converged(iterationCounter.getCount(), previous, current)) {
|
if (checker.converged(iterationCounter.getCount(), previous, current)) {
|
||||||
return new OptimumImpl(
|
return new OptimumImpl(
|
||||||
value,
|
current,
|
||||||
evaluationCounter.getCount(),
|
evaluationCounter.getCount(),
|
||||||
iterationCounter.getCount());
|
iterationCounter.getCount());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package org.apache.commons.math3.fitting.leastsquares;
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
|
||||||
import org.apache.commons.math3.util.Incrementor;
|
import org.apache.commons.math3.util.Incrementor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -54,7 +53,7 @@ public class LeastSquaresAdapter implements LeastSquaresProblem {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public ConvergenceChecker<PointVectorValuePair> getConvergenceChecker() {
|
public ConvergenceChecker<Evaluation> getConvergenceChecker() {
|
||||||
return problem.getConvergenceChecker();
|
return problem.getConvergenceChecker();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||||
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||||
|
@ -19,7 +20,7 @@ public class LeastSquaresBuilder {
|
||||||
/** max iterations */
|
/** max iterations */
|
||||||
private int maxIterations;
|
private int maxIterations;
|
||||||
/** convergence checker */
|
/** convergence checker */
|
||||||
private ConvergenceChecker<PointVectorValuePair> checker;
|
private ConvergenceChecker<Evaluation> checker;
|
||||||
/** model function */
|
/** model function */
|
||||||
private MultivariateVectorFunction model;
|
private MultivariateVectorFunction model;
|
||||||
/** Jacobian function */
|
/** Jacobian function */
|
||||||
|
@ -69,11 +70,23 @@ public class LeastSquaresBuilder {
|
||||||
* @param checker the convergence checker.
|
* @param checker the convergence checker.
|
||||||
* @return this
|
* @return this
|
||||||
*/
|
*/
|
||||||
public LeastSquaresBuilder checker(final ConvergenceChecker<PointVectorValuePair> checker) {
|
public LeastSquaresBuilder checker(final ConvergenceChecker<Evaluation> checker) {
|
||||||
this.checker = checker;
|
this.checker = checker;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure the convergence checker.
|
||||||
|
* <p/>
|
||||||
|
* This function is an overloaded version of {@link #checker(ConvergenceChecker)}.
|
||||||
|
*
|
||||||
|
* @param checker the convergence checker.
|
||||||
|
* @return this
|
||||||
|
*/
|
||||||
|
public LeastSquaresBuilder checkerPair(final ConvergenceChecker<PointVectorValuePair> checker) {
|
||||||
|
return this.checker(LeastSquaresFactory.evaluationChecker(checker));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Configure the model function.
|
* Configure the model function.
|
||||||
*
|
*
|
||||||
|
|
|
@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||||
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||||
|
@ -40,7 +41,7 @@ public class LeastSquaresFactory {
|
||||||
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
|
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
|
||||||
final double[] observed,
|
final double[] observed,
|
||||||
final double[] start,
|
final double[] start,
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
final ConvergenceChecker<Evaluation> checker,
|
||||||
final int maxEvaluations,
|
final int maxEvaluations,
|
||||||
final int maxIterations) {
|
final int maxIterations) {
|
||||||
return new LeastSquaresProblemImpl(
|
return new LeastSquaresProblemImpl(
|
||||||
|
@ -70,7 +71,7 @@ public class LeastSquaresFactory {
|
||||||
final MultivariateMatrixFunction jacobian,
|
final MultivariateMatrixFunction jacobian,
|
||||||
final double[] observed,
|
final double[] observed,
|
||||||
final double[] start,
|
final double[] start,
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
final ConvergenceChecker<Evaluation> checker,
|
||||||
final int maxEvaluations,
|
final int maxEvaluations,
|
||||||
final int maxIterations) {
|
final int maxIterations) {
|
||||||
return create(
|
return create(
|
||||||
|
@ -102,7 +103,7 @@ public class LeastSquaresFactory {
|
||||||
final double[] observed,
|
final double[] observed,
|
||||||
final double[] start,
|
final double[] start,
|
||||||
final RealMatrix weight,
|
final RealMatrix weight,
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
final ConvergenceChecker<Evaluation> checker,
|
||||||
final int maxEvaluations,
|
final int maxEvaluations,
|
||||||
final int maxIterations) {
|
final int maxIterations) {
|
||||||
return weightMatrix(
|
return weightMatrix(
|
||||||
|
@ -174,6 +175,35 @@ public class LeastSquaresFactory {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View a convergence checker specified for a {@link PointVectorValuePair} as one
|
||||||
|
* specified for an {@link Evaluation}.
|
||||||
|
*
|
||||||
|
* @param checker the convergence checker to adapt.
|
||||||
|
* @return a convergence checker that delegates to {@code checker}.
|
||||||
|
*/
|
||||||
|
public static ConvergenceChecker<Evaluation> evaluationChecker(
|
||||||
|
final ConvergenceChecker<PointVectorValuePair> checker
|
||||||
|
) {
|
||||||
|
return new ConvergenceChecker<Evaluation>() {
|
||||||
|
public boolean converged(final int iteration,
|
||||||
|
final Evaluation previous,
|
||||||
|
final Evaluation current) {
|
||||||
|
return checker.converged(
|
||||||
|
iteration,
|
||||||
|
new PointVectorValuePair(
|
||||||
|
previous.getPoint(),
|
||||||
|
previous.computeValue(),
|
||||||
|
false),
|
||||||
|
new PointVectorValuePair(
|
||||||
|
current.getPoint(),
|
||||||
|
current.computeValue(),
|
||||||
|
false)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the square-root of the weight matrix.
|
* Computes the square-root of the weight matrix.
|
||||||
*
|
*
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
package org.apache.commons.math3.fitting.leastsquares;
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||||
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The data necessary to define a non-linear least squares problem. Includes the observed
|
* The data necessary to define a non-linear least squares problem. Includes the observed
|
||||||
|
@ -12,7 +12,7 @@ import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||||
*
|
*
|
||||||
* @version $Id$
|
* @version $Id$
|
||||||
*/
|
*/
|
||||||
public interface LeastSquaresProblem extends OptimizationProblem<PointVectorValuePair> {
|
public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the initial guess.
|
* Gets the initial guess.
|
||||||
|
|
|
@ -17,11 +17,11 @@
|
||||||
package org.apache.commons.math3.fitting.leastsquares;
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||||
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.linear.RealVector;
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.optim.AbstractOptimizationProblem;
|
import org.apache.commons.math3.optim.AbstractOptimizationProblem;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
|
||||||
import org.apache.commons.math3.util.Pair;
|
import org.apache.commons.math3.util.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -32,7 +32,7 @@ import org.apache.commons.math3.util.Pair;
|
||||||
* @since 3.3
|
* @since 3.3
|
||||||
*/
|
*/
|
||||||
class LeastSquaresProblemImpl
|
class LeastSquaresProblemImpl
|
||||||
extends AbstractOptimizationProblem<PointVectorValuePair>
|
extends AbstractOptimizationProblem<Evaluation>
|
||||||
implements LeastSquaresProblem {
|
implements LeastSquaresProblem {
|
||||||
|
|
||||||
/** Target values for the model function at optimum. */
|
/** Target values for the model function at optimum. */
|
||||||
|
@ -45,7 +45,7 @@ class LeastSquaresProblemImpl
|
||||||
LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
|
LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
|
||||||
final double[] target,
|
final double[] target,
|
||||||
final double[] start,
|
final double[] start,
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
final ConvergenceChecker<Evaluation> checker,
|
||||||
final int maxEvaluations,
|
final int maxEvaluations,
|
||||||
final int maxIterations) {
|
final int maxIterations) {
|
||||||
super(maxEvaluations, maxIterations, checker);
|
super(maxEvaluations, maxIterations, checker);
|
||||||
|
|
|
@ -23,7 +23,6 @@ import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.exception.ConvergenceException;
|
import org.apache.commons.math3.exception.ConvergenceException;
|
||||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
|
||||||
import org.apache.commons.math3.util.Incrementor;
|
import org.apache.commons.math3.util.Incrementor;
|
||||||
import org.apache.commons.math3.util.Precision;
|
import org.apache.commons.math3.util.Precision;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
|
@ -303,7 +302,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
final Incrementor iterationCounter = problem.getIterationCounter();
|
final Incrementor iterationCounter = problem.getIterationCounter();
|
||||||
final Incrementor evaluationCounter = problem.getEvaluationCounter();
|
final Incrementor evaluationCounter = problem.getEvaluationCounter();
|
||||||
//convergence criterion
|
//convergence criterion
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker
|
final ConvergenceChecker<Evaluation> checker
|
||||||
= problem.getConvergenceChecker();
|
= problem.getConvergenceChecker();
|
||||||
|
|
||||||
// arrays shared with the other private methods
|
// arrays shared with the other private methods
|
||||||
|
@ -319,7 +318,6 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
double[] diag = new double[nC];
|
double[] diag = new double[nC];
|
||||||
double[] oldX = new double[nC];
|
double[] oldX = new double[nC];
|
||||||
double[] oldRes = new double[nR];
|
double[] oldRes = new double[nR];
|
||||||
double[] oldObj = new double[nR];
|
|
||||||
double[] qtf = new double[nR];
|
double[] qtf = new double[nR];
|
||||||
double[] work1 = new double[nC];
|
double[] work1 = new double[nC];
|
||||||
double[] work2 = new double[nC];
|
double[] work2 = new double[nC];
|
||||||
|
@ -329,23 +327,20 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
// Evaluate the function at the starting point and calculate its norm.
|
// Evaluate the function at the starting point and calculate its norm.
|
||||||
evaluationCounter.incrementCount();
|
evaluationCounter.incrementCount();
|
||||||
//value will be reassigned in the loop
|
//value will be reassigned in the loop
|
||||||
Evaluation value = problem.evaluate(currentPoint);
|
Evaluation current = problem.evaluate(currentPoint);
|
||||||
double[] currentObjective = value.computeValue();
|
double[] currentResiduals = current.computeResiduals();
|
||||||
double[] currentResiduals = value.computeResiduals();
|
double currentCost = current.computeCost();
|
||||||
PointVectorValuePair current = new PointVectorValuePair(currentPoint, currentObjective);
|
|
||||||
double currentCost = value.computeCost();
|
|
||||||
|
|
||||||
// Outer loop.
|
// Outer loop.
|
||||||
boolean firstIteration = true;
|
boolean firstIteration = true;
|
||||||
while (true) {
|
while (true) {
|
||||||
iterationCounter.incrementCount();
|
iterationCounter.incrementCount();
|
||||||
|
|
||||||
final PointVectorValuePair previous = current;
|
final Evaluation previous = current;
|
||||||
final Evaluation previousValue = value;
|
|
||||||
|
|
||||||
// QR decomposition of the jacobian matrix
|
// QR decomposition of the jacobian matrix
|
||||||
final InternalData internalData
|
final InternalData internalData
|
||||||
= qrDecomposition(value.computeJacobian(), solvedCols);
|
= qrDecomposition(current.computeJacobian(), solvedCols);
|
||||||
final double[][] weightedJacobian = internalData.weightedJacobian;
|
final double[][] weightedJacobian = internalData.weightedJacobian;
|
||||||
final int[] permutation = internalData.permutation;
|
final int[] permutation = internalData.permutation;
|
||||||
final double[] diagR = internalData.diagR;
|
final double[] diagR = internalData.diagR;
|
||||||
|
@ -404,7 +399,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
if (maxCosine <= orthoTolerance) {
|
if (maxCosine <= orthoTolerance) {
|
||||||
// Convergence has been reached.
|
// Convergence has been reached.
|
||||||
return new OptimumImpl(
|
return new OptimumImpl(
|
||||||
value,
|
current,
|
||||||
evaluationCounter.getCount(),
|
evaluationCounter.getCount(),
|
||||||
iterationCounter.getCount());
|
iterationCounter.getCount());
|
||||||
}
|
}
|
||||||
|
@ -426,9 +421,6 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
double[] tmpVec = weightedResidual;
|
double[] tmpVec = weightedResidual;
|
||||||
weightedResidual = oldRes;
|
weightedResidual = oldRes;
|
||||||
oldRes = tmpVec;
|
oldRes = tmpVec;
|
||||||
tmpVec = currentObjective;
|
|
||||||
currentObjective = oldObj;
|
|
||||||
oldObj = tmpVec;
|
|
||||||
|
|
||||||
// determine the Levenberg-Marquardt parameter
|
// determine the Levenberg-Marquardt parameter
|
||||||
lmPar = determineLMParameter(qtf, delta, diag,
|
lmPar = determineLMParameter(qtf, delta, diag,
|
||||||
|
@ -452,11 +444,9 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
|
|
||||||
// Evaluate the function at x + p and calculate its norm.
|
// Evaluate the function at x + p and calculate its norm.
|
||||||
evaluationCounter.incrementCount();
|
evaluationCounter.incrementCount();
|
||||||
value = problem.evaluate(currentPoint);
|
current = problem.evaluate(currentPoint);
|
||||||
currentObjective = value.computeValue();
|
currentResiduals = current.computeResiduals();
|
||||||
currentResiduals = value.computeResiduals();
|
currentCost = current.computeCost();
|
||||||
current = new PointVectorValuePair(currentPoint, currentObjective);
|
|
||||||
currentCost = value.computeCost();
|
|
||||||
|
|
||||||
// compute the scaled actual reduction
|
// compute the scaled actual reduction
|
||||||
double actRed = -1.0;
|
double actRed = -1.0;
|
||||||
|
@ -515,7 +505,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
|
|
||||||
// tests for convergence.
|
// tests for convergence.
|
||||||
if (checker != null && checker.converged(iterationCounter.getCount(), previous, current)) {
|
if (checker != null && checker.converged(iterationCounter.getCount(), previous, current)) {
|
||||||
return new OptimumImpl(value, iterationCounter.getCount(), evaluationCounter.getCount());
|
return new OptimumImpl(current, iterationCounter.getCount(), evaluationCounter.getCount());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// failed iteration, reset the previous values
|
// failed iteration, reset the previous values
|
||||||
|
@ -527,12 +517,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
tmpVec = weightedResidual;
|
tmpVec = weightedResidual;
|
||||||
weightedResidual = oldRes;
|
weightedResidual = oldRes;
|
||||||
oldRes = tmpVec;
|
oldRes = tmpVec;
|
||||||
tmpVec = currentObjective;
|
|
||||||
currentObjective = oldObj;
|
|
||||||
oldObj = tmpVec;
|
|
||||||
// Reset "current" to previous values.
|
// Reset "current" to previous values.
|
||||||
current = new PointVectorValuePair(currentPoint, currentObjective);
|
current = previous;
|
||||||
value = previousValue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default convergence criteria.
|
// Default convergence criteria.
|
||||||
|
@ -540,7 +526,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
preRed <= costRelativeTolerance &&
|
preRed <= costRelativeTolerance &&
|
||||||
ratio <= 2.0) ||
|
ratio <= 2.0) ||
|
||||||
delta <= parRelativeTolerance * xNorm) {
|
delta <= parRelativeTolerance * xNorm) {
|
||||||
return new OptimumImpl(value, iterationCounter.getCount(), evaluationCounter.getCount());
|
return new OptimumImpl(current, iterationCounter.getCount(), evaluationCounter.getCount());
|
||||||
}
|
}
|
||||||
|
|
||||||
// tests for termination and stringent tolerances
|
// tests for termination and stringent tolerances
|
||||||
|
|
|
@ -46,7 +46,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
public LeastSquaresBuilder base() {
|
public LeastSquaresBuilder base() {
|
||||||
return new LeastSquaresBuilder()
|
return new LeastSquaresBuilder()
|
||||||
.checker(new SimpleVectorValueChecker(1e-6, 1e-6))
|
.checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
|
||||||
.maxEvaluations(100)
|
.maxEvaluations(100)
|
||||||
.maxIterations(getMaxIterations());
|
.maxIterations(getMaxIterations());
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,7 +96,7 @@ public class GaussNewtonOptimizerTest
|
||||||
circle.addPoint( 45.0, 97.0);
|
circle.addPoint( 45.0, 97.0);
|
||||||
|
|
||||||
LeastSquaresProblem lsp = builder(circle)
|
LeastSquaresProblem lsp = builder(circle)
|
||||||
.checker(new SimpleVectorValueChecker(1e-30, 1e-30))
|
.checkerPair(new SimpleVectorValueChecker(1e-30, 1e-30))
|
||||||
.maxIterations(Integer.MAX_VALUE)
|
.maxIterations(Integer.MAX_VALUE)
|
||||||
.start(new double[]{98.680, 47.345})
|
.start(new double[]{98.680, 47.345})
|
||||||
.build();
|
.build();
|
||||||
|
|
Loading…
Reference in New Issue