diff --git a/src/main/java/org/apache/commons/math3/optimization/general/GaussNewtonOptimizer.java b/src/main/java/org/apache/commons/math3/optimization/general/GaussNewtonOptimizer.java index ef01065e4..33d0d8df0 100644 --- a/src/main/java/org/apache/commons/math3/optimization/general/GaussNewtonOptimizer.java +++ b/src/main/java/org/apache/commons/math3/optimization/general/GaussNewtonOptimizer.java @@ -18,6 +18,8 @@ package org.apache.commons.math3.optimization.general; import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.MathInternalError; import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.BlockRealMatrix; @@ -43,7 +45,6 @@ import org.apache.commons.math3.optimization.PointVectorValuePair; * @since 2.0 * */ - public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { /** Indicator for using LU decomposition. */ private final boolean useLU; @@ -100,10 +101,26 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { /** {@inheritDoc} */ @Override public PointVectorValuePair doOptimize() { - final ConvergenceChecker checker = getConvergenceChecker(); + // Computation will be useless without a checker (see "for-loop"). + if (checker == null) { + throw new NullArgumentException(); + } + + final double[] targetValues = getTarget(); + final int nR = targetValues.length; // Number of observed data. + + final RealMatrix weightMatrix = getWeight(); + // Diagonal of the weight matrix. + final double[] residualsWeights = new double[nR]; + for (int i = 0; i < nR; i++) { + residualsWeights[i] = weightMatrix.getEntry(i, i); + } + + double[] currentPoint = getStartPoint(); + // iterate until convergence is reached PointVectorValuePair current = null; int iter = 0; @@ -112,21 +129,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { // evaluate the objective function and its jacobian PointVectorValuePair previous = current; - updateResidualsAndCost(); - updateJacobian(); - current = new PointVectorValuePair(point, objective); - - final double[] targetValues = getTargetRef(); - final double[] residualsWeights = getWeightRef(); + // Value of the objective function at "currentPoint". + final double[] currentObjective = computeObjectiveValue(currentPoint); + final double[] currentResiduals = computeResiduals(currentObjective); + final RealMatrix weightedJacobian = computeJacobian(currentPoint); + current = new PointVectorValuePair(currentPoint, currentObjective); // build the linear problem final double[] b = new double[cols]; final double[][] a = new double[cols][cols]; for (int i = 0; i < rows; ++i) { - final double[] grad = weightedResidualJacobian[i]; + final double[] grad = weightedJacobian.getRow(i); final double weight = residualsWeights[i]; - final double residual = objective[i] - targetValues[i]; + // XXX Minus sign could be left out if "weightedJacobian" + // would be defined differently. + final double residual = -currentResiduals[i]; // compute the normal equation final double wr = weight * residual; @@ -153,20 +171,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray(); // update the estimated parameters for (int i = 0; i < cols; ++i) { - point[i] += dX[i]; + currentPoint[i] += dX[i]; } } catch (SingularMatrixException e) { throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); } - // check convergence - if (checker != null) { - if (previous != null) { - converged = checker.converged(iter, previous, current); + // Check convergence. + if (previous != null) { + converged = checker.converged(iter, previous, current); + if (converged) { + cost = computeCost(currentResiduals); + return current; } } } - // we have converged - return current; + // Must never happen. + throw new MathInternalError(); } }