Made "GaussNewtonOptimizer" use the new methods in base class
"AbstractLeastSquaresOptimizer" instead of modifying the (now
deprecated) protected fields.


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1407034 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2012-11-08 12:48:54 +00:00
parent 54cfc6ce0a
commit 2ce650bf54
1 changed files with 37 additions and 17 deletions

View File

@ -18,6 +18,8 @@
package org.apache.commons.math3.optimization.general; package org.apache.commons.math3.optimization.general;
import org.apache.commons.math3.exception.ConvergenceException; import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.MathInternalError;
import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.BlockRealMatrix; import org.apache.commons.math3.linear.BlockRealMatrix;
@ -43,7 +45,6 @@ import org.apache.commons.math3.optimization.PointVectorValuePair;
* @since 2.0 * @since 2.0
* *
*/ */
public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
/** Indicator for using LU decomposition. */ /** Indicator for using LU decomposition. */
private final boolean useLU; private final boolean useLU;
@ -100,10 +101,26 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public PointVectorValuePair doOptimize() { public PointVectorValuePair doOptimize() {
final ConvergenceChecker<PointVectorValuePair> checker final ConvergenceChecker<PointVectorValuePair> checker
= getConvergenceChecker(); = getConvergenceChecker();
// Computation will be useless without a checker (see "for-loop").
if (checker == null) {
throw new NullArgumentException();
}
final double[] targetValues = getTarget();
final int nR = targetValues.length; // Number of observed data.
final RealMatrix weightMatrix = getWeight();
// Diagonal of the weight matrix.
final double[] residualsWeights = new double[nR];
for (int i = 0; i < nR; i++) {
residualsWeights[i] = weightMatrix.getEntry(i, i);
}
double[] currentPoint = getStartPoint();
// iterate until convergence is reached // iterate until convergence is reached
PointVectorValuePair current = null; PointVectorValuePair current = null;
int iter = 0; int iter = 0;
@ -112,21 +129,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
// evaluate the objective function and its jacobian // evaluate the objective function and its jacobian
PointVectorValuePair previous = current; PointVectorValuePair previous = current;
updateResidualsAndCost(); // Value of the objective function at "currentPoint".
updateJacobian(); final double[] currentObjective = computeObjectiveValue(currentPoint);
current = new PointVectorValuePair(point, objective); final double[] currentResiduals = computeResiduals(currentObjective);
final RealMatrix weightedJacobian = computeJacobian(currentPoint);
final double[] targetValues = getTargetRef(); current = new PointVectorValuePair(currentPoint, currentObjective);
final double[] residualsWeights = getWeightRef();
// build the linear problem // build the linear problem
final double[] b = new double[cols]; final double[] b = new double[cols];
final double[][] a = new double[cols][cols]; final double[][] a = new double[cols][cols];
for (int i = 0; i < rows; ++i) { for (int i = 0; i < rows; ++i) {
final double[] grad = weightedResidualJacobian[i]; final double[] grad = weightedJacobian.getRow(i);
final double weight = residualsWeights[i]; final double weight = residualsWeights[i];
final double residual = objective[i] - targetValues[i]; // XXX Minus sign could be left out if "weightedJacobian"
// would be defined differently.
final double residual = -currentResiduals[i];
// compute the normal equation // compute the normal equation
final double wr = weight * residual; final double wr = weight * residual;
@ -153,20 +171,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray(); final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
// update the estimated parameters // update the estimated parameters
for (int i = 0; i < cols; ++i) { for (int i = 0; i < cols; ++i) {
point[i] += dX[i]; currentPoint[i] += dX[i];
} }
} catch (SingularMatrixException e) { } catch (SingularMatrixException e) {
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
} }
// check convergence // Check convergence.
if (checker != null) { if (previous != null) {
if (previous != null) { converged = checker.converged(iter, previous, current);
converged = checker.converged(iter, previous, current); if (converged) {
cost = computeCost(currentResiduals);
return current;
} }
} }
} }
// we have converged // Must never happen.
return current; throw new MathInternalError();
} }
} }