MATH-887
Made "GaussNewtonOptimizer" use the new methods in base class "AbstractLeastSquaresOptimizer" instead of modifying the (now deprecated) protected fields. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1407034 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
54cfc6ce0a
commit
2ce650bf54
|
@ -18,6 +18,8 @@
|
||||||
package org.apache.commons.math3.optimization.general;
|
package org.apache.commons.math3.optimization.general;
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.ConvergenceException;
|
import org.apache.commons.math3.exception.ConvergenceException;
|
||||||
|
import org.apache.commons.math3.exception.NullArgumentException;
|
||||||
|
import org.apache.commons.math3.exception.MathInternalError;
|
||||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import org.apache.commons.math3.linear.BlockRealMatrix;
|
import org.apache.commons.math3.linear.BlockRealMatrix;
|
||||||
|
@ -43,7 +45,6 @@ import org.apache.commons.math3.optimization.PointVectorValuePair;
|
||||||
* @since 2.0
|
* @since 2.0
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
||||||
/** Indicator for using LU decomposition. */
|
/** Indicator for using LU decomposition. */
|
||||||
private final boolean useLU;
|
private final boolean useLU;
|
||||||
|
@ -100,10 +101,26 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
@Override
|
@Override
|
||||||
public PointVectorValuePair doOptimize() {
|
public PointVectorValuePair doOptimize() {
|
||||||
|
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker
|
final ConvergenceChecker<PointVectorValuePair> checker
|
||||||
= getConvergenceChecker();
|
= getConvergenceChecker();
|
||||||
|
|
||||||
|
// Computation will be useless without a checker (see "for-loop").
|
||||||
|
if (checker == null) {
|
||||||
|
throw new NullArgumentException();
|
||||||
|
}
|
||||||
|
|
||||||
|
final double[] targetValues = getTarget();
|
||||||
|
final int nR = targetValues.length; // Number of observed data.
|
||||||
|
|
||||||
|
final RealMatrix weightMatrix = getWeight();
|
||||||
|
// Diagonal of the weight matrix.
|
||||||
|
final double[] residualsWeights = new double[nR];
|
||||||
|
for (int i = 0; i < nR; i++) {
|
||||||
|
residualsWeights[i] = weightMatrix.getEntry(i, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
double[] currentPoint = getStartPoint();
|
||||||
|
|
||||||
// iterate until convergence is reached
|
// iterate until convergence is reached
|
||||||
PointVectorValuePair current = null;
|
PointVectorValuePair current = null;
|
||||||
int iter = 0;
|
int iter = 0;
|
||||||
|
@ -112,21 +129,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
||||||
|
|
||||||
// evaluate the objective function and its jacobian
|
// evaluate the objective function and its jacobian
|
||||||
PointVectorValuePair previous = current;
|
PointVectorValuePair previous = current;
|
||||||
updateResidualsAndCost();
|
// Value of the objective function at "currentPoint".
|
||||||
updateJacobian();
|
final double[] currentObjective = computeObjectiveValue(currentPoint);
|
||||||
current = new PointVectorValuePair(point, objective);
|
final double[] currentResiduals = computeResiduals(currentObjective);
|
||||||
|
final RealMatrix weightedJacobian = computeJacobian(currentPoint);
|
||||||
final double[] targetValues = getTargetRef();
|
current = new PointVectorValuePair(currentPoint, currentObjective);
|
||||||
final double[] residualsWeights = getWeightRef();
|
|
||||||
|
|
||||||
// build the linear problem
|
// build the linear problem
|
||||||
final double[] b = new double[cols];
|
final double[] b = new double[cols];
|
||||||
final double[][] a = new double[cols][cols];
|
final double[][] a = new double[cols][cols];
|
||||||
for (int i = 0; i < rows; ++i) {
|
for (int i = 0; i < rows; ++i) {
|
||||||
|
|
||||||
final double[] grad = weightedResidualJacobian[i];
|
final double[] grad = weightedJacobian.getRow(i);
|
||||||
final double weight = residualsWeights[i];
|
final double weight = residualsWeights[i];
|
||||||
final double residual = objective[i] - targetValues[i];
|
// XXX Minus sign could be left out if "weightedJacobian"
|
||||||
|
// would be defined differently.
|
||||||
|
final double residual = -currentResiduals[i];
|
||||||
|
|
||||||
// compute the normal equation
|
// compute the normal equation
|
||||||
final double wr = weight * residual;
|
final double wr = weight * residual;
|
||||||
|
@ -153,20 +171,22 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
||||||
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
|
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
|
||||||
// update the estimated parameters
|
// update the estimated parameters
|
||||||
for (int i = 0; i < cols; ++i) {
|
for (int i = 0; i < cols; ++i) {
|
||||||
point[i] += dX[i];
|
currentPoint[i] += dX[i];
|
||||||
}
|
}
|
||||||
} catch (SingularMatrixException e) {
|
} catch (SingularMatrixException e) {
|
||||||
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
|
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
|
||||||
}
|
}
|
||||||
|
|
||||||
// check convergence
|
// Check convergence.
|
||||||
if (checker != null) {
|
if (previous != null) {
|
||||||
if (previous != null) {
|
converged = checker.converged(iter, previous, current);
|
||||||
converged = checker.converged(iter, previous, current);
|
if (converged) {
|
||||||
|
cost = computeCost(currentResiduals);
|
||||||
|
return current;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// we have converged
|
// Must never happen.
|
||||||
return current;
|
throw new MathInternalError();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue