MATH-887
Made "GaussNewtonOptimizer" use the new methods in base class "AbstractLeastSquaresOptimizer" instead of modifying the (now deprecated) protected fields. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1407034 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
54cfc6ce0a
commit
2ce650bf54
|
@ -18,6 +18,8 @@
|
|||
package org.apache.commons.math3.optimization.general;
|
||||
|
||||
import org.apache.commons.math3.exception.ConvergenceException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.exception.MathInternalError;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.linear.BlockRealMatrix;
|
||||
|
@ -43,7 +45,6 @@ import org.apache.commons.math3.optimization.PointVectorValuePair;
|
|||
* @since 2.0
|
||||
*
|
||||
*/
|
||||
|
||||
public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
||||
/** Indicator for using LU decomposition. */
|
||||
private final boolean useLU;
|
||||
|
@ -100,10 +101,26 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
|||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public PointVectorValuePair doOptimize() {
|
||||
|
||||
final ConvergenceChecker<PointVectorValuePair> checker
|
||||
= getConvergenceChecker();
|
||||
|
||||
// Computation will be useless without a checker (see "for-loop").
|
||||
if (checker == null) {
|
||||
throw new NullArgumentException();
|
||||
}
|
||||
|
||||
final double[] targetValues = getTarget();
|
||||
final int nR = targetValues.length; // Number of observed data.
|
||||
|
||||
final RealMatrix weightMatrix = getWeight();
|
||||
// Diagonal of the weight matrix.
|
||||
final double[] residualsWeights = new double[nR];
|
||||
for (int i = 0; i < nR; i++) {
|
||||
residualsWeights[i] = weightMatrix.getEntry(i, i);
|
||||
}
|
||||
|
||||
double[] currentPoint = getStartPoint();
|
||||
|
||||
// iterate until convergence is reached
|
||||
PointVectorValuePair current = null;
|
||||
int iter = 0;
|
||||
|
@ -112,21 +129,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
|||
|
||||
// evaluate the objective function and its jacobian
|
||||
PointVectorValuePair previous = current;
|
||||
updateResidualsAndCost();
|
||||
updateJacobian();
|
||||
current = new PointVectorValuePair(point, objective);
|
||||
|
||||
final double[] targetValues = getTargetRef();
|
||||
final double[] residualsWeights = getWeightRef();
|
||||
// Value of the objective function at "currentPoint".
|
||||
final double[] currentObjective = computeObjectiveValue(currentPoint);
|
||||
final double[] currentResiduals = computeResiduals(currentObjective);
|
||||
final RealMatrix weightedJacobian = computeJacobian(currentPoint);
|
||||
current = new PointVectorValuePair(currentPoint, currentObjective);
|
||||
|
||||
// build the linear problem
|
||||
final double[] b = new double[cols];
|
||||
final double[][] a = new double[cols][cols];
|
||||
for (int i = 0; i < rows; ++i) {
|
||||
|
||||
final double[] grad = weightedResidualJacobian[i];
|
||||
final double[] grad = weightedJacobian.getRow(i);
|
||||
final double weight = residualsWeights[i];
|
||||
final double residual = objective[i] - targetValues[i];
|
||||
// XXX Minus sign could be left out if "weightedJacobian"
|
||||
// would be defined differently.
|
||||
final double residual = -currentResiduals[i];
|
||||
|
||||
// compute the normal equation
|
||||
final double wr = weight * residual;
|
||||
|
@ -153,20 +171,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
|||
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
|
||||
// update the estimated parameters
|
||||
for (int i = 0; i < cols; ++i) {
|
||||
point[i] += dX[i];
|
||||
currentPoint[i] += dX[i];
|
||||
}
|
||||
} catch (SingularMatrixException e) {
|
||||
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
|
||||
}
|
||||
|
||||
// check convergence
|
||||
if (checker != null) {
|
||||
if (previous != null) {
|
||||
converged = checker.converged(iter, previous, current);
|
||||
// Check convergence.
|
||||
if (previous != null) {
|
||||
converged = checker.converged(iter, previous, current);
|
||||
if (converged) {
|
||||
cost = computeCost(currentResiduals);
|
||||
return current;
|
||||
}
|
||||
}
|
||||
}
|
||||
// we have converged
|
||||
return current;
|
||||
// Must never happen.
|
||||
throw new MathInternalError();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue