diff --git a/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java b/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java index bc1363159..3af449ce2 100644 --- a/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java +++ b/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java @@ -17,16 +17,17 @@ package org.apache.commons.math3.optimization.general; -import org.apache.commons.math3.exception.NumberIsTooSmallException; -import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction; -import org.apache.commons.math3.analysis.MultivariateMatrixFunction; +import org.apache.commons.math3.analysis.FunctionUtils; +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; import org.apache.commons.math3.analysis.differentiation.JacobianFunction; import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; import org.apache.commons.math3.exception.util.LocalizedFormats; -import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.optimization.ConvergenceChecker; import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer; import org.apache.commons.math3.optimization.PointVectorValuePair; @@ -75,7 +76,7 @@ public abstract class AbstractLeastSquaresOptimizer /** Cost value (square root of the sum of the residuals). */ protected double cost; /** Objective function derivatives. */ - private MultivariateMatrixFunction jF; + private MultivariateDifferentiableVectorFunction jF; /** Number of evaluations of the Jacobian. */ private int jacobianEvaluations; @@ -110,9 +111,22 @@ public abstract class AbstractLeastSquaresOptimizer */ protected void updateJacobian() { ++jacobianEvaluations; - weightedResidualJacobian = jF.value(point); - if (weightedResidualJacobian.length != rows) { - throw new DimensionMismatchException(weightedResidualJacobian.length, rows); + + DerivativeStructure[] dsPoint = new DerivativeStructure[point.length]; + for (int i = 0; i < point.length; ++i) { + dsPoint[i] = new DerivativeStructure(point.length, 1, i, point[i]); + } + DerivativeStructure[] dsValue = jF.value(dsPoint); + if (dsValue.length != rows) { + throw new DimensionMismatchException(dsValue.length, rows); + } + for (int i = 0; i < rows; ++i) { + int[] orders = new int[point.length]; + for (int j = 0; j < point.length; ++j) { + orders[j] = 1; + weightedResidualJacobian[i][j] = dsValue[i].getPartialDerivative(orders); + orders[j] = 0; + } } final double[] residualsWeights = getWeightRef(); @@ -303,23 +317,8 @@ public abstract class AbstractLeastSquaresOptimizer final DifferentiableMultivariateVectorFunction f, final double[] target, final double[] weights, final double[] startPoint) { - // Reset counter. - jacobianEvaluations = 0; - - // Store least squares problem characteristics. - jF = f.jacobian(); - - // Arrays shared with the other private methods. - point = startPoint.clone(); - rows = target.length; - cols = point.length; - - weightedResidualJacobian = new double[rows][cols]; - this.weightedResiduals = new double[rows]; - - cost = Double.POSITIVE_INFINITY; - - return optimizeInternal(maxEval, f, target, weights, startPoint); + return optimize(maxEval, FunctionUtils.toMultivariateDifferentiableVectorFunction(f), + target, weights, startPoint); } /** @@ -351,7 +350,7 @@ public abstract class AbstractLeastSquaresOptimizer jacobianEvaluations = 0; // Store least squares problem characteristics. - jF = new JacobianFunction(f); + jF = f; // Arrays shared with the other private methods. point = startPoint.clone();