Made "GaussNewtonOptimizer" use the new methods in base class
"AbstractLeastSquaresOptimizer" instead of modifying the (now
deprecated) protected fields.


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1407034 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2012-11-08 12:48:54 +00:00
parent 54cfc6ce0a
commit 2ce650bf54
1 changed files with 37 additions and 17 deletions

View File

@ -18,6 +18,8 @@
package org.apache.commons.math3.optimization.general;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.MathInternalError;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.BlockRealMatrix;
@ -43,7 +45,6 @@ import org.apache.commons.math3.optimization.PointVectorValuePair;
* @since 2.0
*
*/
public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
/** Indicator for using LU decomposition. */
private final boolean useLU;
@ -100,10 +101,26 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
/** {@inheritDoc} */
@Override
public PointVectorValuePair doOptimize() {
final ConvergenceChecker<PointVectorValuePair> checker
= getConvergenceChecker();
// Computation will be useless without a checker (see "for-loop").
if (checker == null) {
throw new NullArgumentException();
}
final double[] targetValues = getTarget();
final int nR = targetValues.length; // Number of observed data.
final RealMatrix weightMatrix = getWeight();
// Diagonal of the weight matrix.
final double[] residualsWeights = new double[nR];
for (int i = 0; i < nR; i++) {
residualsWeights[i] = weightMatrix.getEntry(i, i);
}
double[] currentPoint = getStartPoint();
// iterate until convergence is reached
PointVectorValuePair current = null;
int iter = 0;
@ -112,21 +129,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
// evaluate the objective function and its jacobian
PointVectorValuePair previous = current;
updateResidualsAndCost();
updateJacobian();
current = new PointVectorValuePair(point, objective);
final double[] targetValues = getTargetRef();
final double[] residualsWeights = getWeightRef();
// Value of the objective function at "currentPoint".
final double[] currentObjective = computeObjectiveValue(currentPoint);
final double[] currentResiduals = computeResiduals(currentObjective);
final RealMatrix weightedJacobian = computeJacobian(currentPoint);
current = new PointVectorValuePair(currentPoint, currentObjective);
// build the linear problem
final double[] b = new double[cols];
final double[][] a = new double[cols][cols];
for (int i = 0; i < rows; ++i) {
final double[] grad = weightedResidualJacobian[i];
final double[] grad = weightedJacobian.getRow(i);
final double weight = residualsWeights[i];
final double residual = objective[i] - targetValues[i];
// XXX Minus sign could be left out if "weightedJacobian"
// would be defined differently.
final double residual = -currentResiduals[i];
// compute the normal equation
final double wr = weight * residual;
@ -153,20 +171,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
// update the estimated parameters
for (int i = 0; i < cols; ++i) {
point[i] += dX[i];
currentPoint[i] += dX[i];
}
} catch (SingularMatrixException e) {
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
}
// check convergence
if (checker != null) {
if (previous != null) {
converged = checker.converged(iter, previous, current);
// Check convergence.
if (previous != null) {
converged = checker.converged(iter, previous, current);
if (converged) {
cost = computeCost(currentResiduals);
return current;
}
}
}
// we have converged
return current;
// Must never happen.
throw new MathInternalError();
}
}