MATH-1144

Point must be retrieved after the call to "evaluate" (to ensure that
the validated parameters are actually used by the optimizer).
This commit is contained in:
Gilles 2014-10-15 22:14:46 +02:00
parent 4c5cda210e
commit 566c4d59a1
3 changed files with 9 additions and 7 deletions

View File

@ -230,6 +230,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
current = lsp.evaluate(currentPoint); current = lsp.evaluate(currentPoint);
final RealVector currentResiduals = current.getResiduals(); final RealVector currentResiduals = current.getResiduals();
final RealMatrix weightedJacobian = current.getJacobian(); final RealMatrix weightedJacobian = current.getJacobian();
currentPoint = current.getPoint();
// Check convergence. // Check convergence.
if (previous != null) { if (previous != null) {

View File

@ -294,16 +294,14 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
/** {@inheritDoc} */ /** {@inheritDoc} */
public Optimum optimize(final LeastSquaresProblem problem) { public Optimum optimize(final LeastSquaresProblem problem) {
//pull in relevant data from the problem as locals // Pull in relevant data from the problem as locals.
final int nR = problem.getObservationSize(); // Number of observed data. final int nR = problem.getObservationSize(); // Number of observed data.
final int nC = problem.getParameterSize(); // Number of parameters. final int nC = problem.getParameterSize(); // Number of parameters.
final double[] currentPoint = problem.getStart().toArray(); // Counters.
//counters
final Incrementor iterationCounter = problem.getIterationCounter(); final Incrementor iterationCounter = problem.getIterationCounter();
final Incrementor evaluationCounter = problem.getEvaluationCounter(); final Incrementor evaluationCounter = problem.getEvaluationCounter();
//convergence criterion // Convergence criterion.
final ConvergenceChecker<Evaluation> checker final ConvergenceChecker<Evaluation> checker = problem.getConvergenceChecker();
= problem.getConvergenceChecker();
// arrays shared with the other private methods // arrays shared with the other private methods
final int solvedCols = FastMath.min(nR, nC); final int solvedCols = FastMath.min(nR, nC);
@ -327,9 +325,10 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
// Evaluate the function at the starting point and calculate its norm. // Evaluate the function at the starting point and calculate its norm.
evaluationCounter.incrementCount(); evaluationCounter.incrementCount();
//value will be reassigned in the loop //value will be reassigned in the loop
Evaluation current = problem.evaluate(new ArrayRealVector(currentPoint)); Evaluation current = problem.evaluate(problem.getStart());
double[] currentResiduals = current.getResiduals().toArray(); double[] currentResiduals = current.getResiduals().toArray();
double currentCost = current.getCost(); double currentCost = current.getCost();
double[] currentPoint = current.getPoint().toArray();
// Outer loop. // Outer loop.
boolean firstIteration = true; boolean firstIteration = true;
@ -447,6 +446,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
current = problem.evaluate(new ArrayRealVector(currentPoint)); current = problem.evaluate(new ArrayRealVector(currentPoint));
currentResiduals = current.getResiduals().toArray(); currentResiduals = current.getResiduals().toArray();
currentCost = current.getCost(); currentCost = current.getCost();
currentPoint = current.getPoint().toArray();
// compute the scaled actual reduction // compute the scaled actual reduction
double actRed = -1.0; double actRed = -1.0;

View File

@ -307,6 +307,7 @@ public class LevenbergMarquardtOptimizerTest
= optimizer.optimize(builder(circle).maxIterations(50).start(init).parameterValidator(cheatValidator).build()); = optimizer.optimize(builder(circle).maxIterations(50).start(init).parameterValidator(cheatValidator).build());
final int cheatNumEval = cheatOptimum.getEvaluations(); final int cheatNumEval = cheatOptimum.getEvaluations();
Assert.assertTrue(cheatNumEval < numEval); Assert.assertTrue(cheatNumEval < numEval);
System.out.println("n=" + numEval + " nc=" + cheatNumEval);
} }
@Test @Test