Use the new differentation API for all optimizers.

The older API is still supported as of version 3.1, but is implemented
by wrapping the user function into the new API and then calling the new
code.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1401838 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2012-10-24 19:40:08 +00:00
parent 7b5a64c0bb
commit f5765cf99d
1 changed files with 25 additions and 26 deletions

View File

@ -17,16 +17,17 @@
package org.apache.commons.math3.optimization.general;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.analysis.FunctionUtils;
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math3.analysis.differentiation.JacobianFunction;
import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.optimization.ConvergenceChecker;
import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
import org.apache.commons.math3.optimization.PointVectorValuePair;
@ -75,7 +76,7 @@ public abstract class AbstractLeastSquaresOptimizer
/** Cost value (square root of the sum of the residuals). */
protected double cost;
/** Objective function derivatives. */
private MultivariateMatrixFunction jF;
private MultivariateDifferentiableVectorFunction jF;
/** Number of evaluations of the Jacobian. */
private int jacobianEvaluations;
@ -110,9 +111,22 @@ public abstract class AbstractLeastSquaresOptimizer
*/
protected void updateJacobian() {
++jacobianEvaluations;
weightedResidualJacobian = jF.value(point);
if (weightedResidualJacobian.length != rows) {
throw new DimensionMismatchException(weightedResidualJacobian.length, rows);
DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
for (int i = 0; i < point.length; ++i) {
dsPoint[i] = new DerivativeStructure(point.length, 1, i, point[i]);
}
DerivativeStructure[] dsValue = jF.value(dsPoint);
if (dsValue.length != rows) {
throw new DimensionMismatchException(dsValue.length, rows);
}
for (int i = 0; i < rows; ++i) {
int[] orders = new int[point.length];
for (int j = 0; j < point.length; ++j) {
orders[j] = 1;
weightedResidualJacobian[i][j] = dsValue[i].getPartialDerivative(orders);
orders[j] = 0;
}
}
final double[] residualsWeights = getWeightRef();
@ -303,23 +317,8 @@ public abstract class AbstractLeastSquaresOptimizer
final DifferentiableMultivariateVectorFunction f,
final double[] target, final double[] weights,
final double[] startPoint) {
// Reset counter.
jacobianEvaluations = 0;
// Store least squares problem characteristics.
jF = f.jacobian();
// Arrays shared with the other private methods.
point = startPoint.clone();
rows = target.length;
cols = point.length;
weightedResidualJacobian = new double[rows][cols];
this.weightedResiduals = new double[rows];
cost = Double.POSITIVE_INFINITY;
return optimizeInternal(maxEval, f, target, weights, startPoint);
return optimize(maxEval, FunctionUtils.toMultivariateDifferentiableVectorFunction(f),
target, weights, startPoint);
}
/**
@ -351,7 +350,7 @@ public abstract class AbstractLeastSquaresOptimizer
jacobianEvaluations = 0;
// Store least squares problem characteristics.
jF = new JacobianFunction(f);
jF = f;
// Arrays shared with the other private methods.
point = startPoint.clone();