From fc9cb0ce163e54e4bba232d2bbf2f79dc495585d Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Tue, 18 Feb 2014 14:32:06 +0000 Subject: [PATCH] Value and Jacobian evaluated in a single method. A new interface MultivariateJacobianFunction lets the function value and Jacobian be evaluated at the same time. This saves the user from having to cache the result between calls to get the value and the jacobian. A factory method was added to create LeastSquaresProblems from the new interface. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569346 13f79535-47bb-0310-9956-ffa450edef68 --- .../leastsquares/LeastSquaresFactory.java | 58 ++++++++++++++++++- .../leastsquares/LeastSquaresProblemImpl.java | 33 +++++------ .../MultivariateJacobianFunction.java | 23 ++++++++ 3 files changed, 92 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction.java diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java index 5cb0b847f..073ff4d26 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java @@ -2,6 +2,8 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.analysis.MultivariateMatrixFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.EigenDecomposition; import org.apache.commons.math3.linear.RealMatrix; @@ -10,6 +12,7 @@ import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.PointVectorValuePair; import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.Incrementor; +import org.apache.commons.math3.util.Pair; /** * A Factory for creating {@link LeastSquaresProblem}s. @@ -22,6 +25,34 @@ public class LeastSquaresFactory { private LeastSquaresFactory() { } + /** + * Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem} + * from the given elements. There will be no weights applied (Identity weights). + * + * @param model the model function. Produces the computed values. + * @param observed the observed (target) values + * @param start the initial guess. + * @param checker convergence checker + * @param maxEvaluations the maximum number of times to evaluate the model + * @param maxIterations the maximum number to times to iterate in the algorithm + * @return the specified General Least Squares problem. + */ + public static LeastSquaresProblem create(final MultivariateJacobianFunction model, + final double[] observed, + final double[] start, + final ConvergenceChecker checker, + final int maxEvaluations, + final int maxIterations) { + return new LeastSquaresProblemImpl( + model, + observed, + start, + checker, + maxEvaluations, + maxIterations + ); + } + /** * Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem} * from the given elements. There will be no weights applied (Identity weights). @@ -42,9 +73,8 @@ public class LeastSquaresFactory { final ConvergenceChecker checker, final int maxEvaluations, final int maxIterations) { - return new LeastSquaresProblemImpl( - model, - jacobian, + return create( + combine(model, jacobian), observed, start, checker, @@ -163,5 +193,27 @@ public class LeastSquaresFactory { return dec.getSquareRoot(); } } + + /** + * Combine a {@link MultivariateVectorFunction} with a {@link + * MultivariateMatrixFunction} to produce a {@link MultivariateJacobianFunction}. + * + * @param value the vector value function + * @param jacobian the Jacobian function + * @return a function that computes both at the same time + */ + private static MultivariateJacobianFunction combine( + final MultivariateVectorFunction value, + final MultivariateMatrixFunction jacobian + ) { + return new MultivariateJacobianFunction() { + public Pair value(final double[] point) { + //evaluate and use Real* interfaces without copying + return new Pair( + new ArrayRealVector(value.value(point), false), + new Array2DRowRealMatrix(jacobian.value(point), false)); + } + }; + } } diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java index 1cf323c16..15e306705 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java @@ -16,14 +16,13 @@ */ package org.apache.commons.math3.fitting.leastsquares; -import org.apache.commons.math3.analysis.MultivariateMatrixFunction; -import org.apache.commons.math3.analysis.MultivariateVectorFunction; import org.apache.commons.math3.exception.DimensionMismatchException; -import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.optim.AbstractOptimizationProblem; import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.PointVectorValuePair; +import org.apache.commons.math3.util.Pair; /** * A private, "field" immutable (not "real" immutable) implementation of {@link @@ -39,14 +38,11 @@ class LeastSquaresProblemImpl /** Target values for the model function at optimum. */ private double[] target; /** Model function. */ - private MultivariateVectorFunction model; - /** Jacobian of the model function. */ - private MultivariateMatrixFunction jacobian; + private MultivariateJacobianFunction model; /** Initial guess. */ private double[] start; - LeastSquaresProblemImpl(final MultivariateVectorFunction model, - final MultivariateMatrixFunction jacobian, + LeastSquaresProblemImpl(final MultivariateJacobianFunction model, final double[] target, final double[] start, final ConvergenceChecker checker, @@ -55,7 +51,6 @@ class LeastSquaresProblemImpl super(maxEvaluations, maxIterations, checker); this.target = target; this.model = model; - this.jacobian = jacobian; this.start = start; } @@ -72,10 +67,11 @@ class LeastSquaresProblemImpl } public Evaluation evaluate(final double[] point) { - //TODO evaluate value and jacobian in one function call + //evaluate value and jacobian in one function call + final Pair value = this.model.value(point); return new UnweightedEvaluation( - this.model.value(point), - this.jacobian.value(point), + value.getFirst(), + value.getSecond(), this.target, point); } @@ -90,14 +86,14 @@ class LeastSquaresProblemImpl /** the point of evaluation */ private final double[] point; /** value at point */ - private final double[] values; + private final RealVector values; /** deriviative at point */ - private final double[][] jacobian; + private final RealMatrix jacobian; /** reference to the observed values */ private final double[] target; - private UnweightedEvaluation(final double[] values, - final double[][] jacobian, + private UnweightedEvaluation(final RealVector values, + final RealMatrix jacobian, final double[] target, final double[] point) { super(target.length); @@ -109,14 +105,13 @@ class LeastSquaresProblemImpl public double[] computeValue() { - return this.values; + return this.values.toArray(); } public RealMatrix computeJacobian() { - return MatrixUtils.createRealMatrix(this.jacobian); + return this.jacobian; } - public double[] getPoint() { return this.point; } diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction.java new file mode 100644 index 000000000..c79942c14 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction.java @@ -0,0 +1,23 @@ +package org.apache.commons.math3.fitting.leastsquares; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; +import org.apache.commons.math3.util.Pair; + +/** + * A interface for functions that compute a vector of values and can compute their + * derivatives (Jacobian). + * + * @version $Id$ + */ +public interface MultivariateJacobianFunction { + + /** + * Compute the function value and its Jacobian. + * + * @param point the abscissae + * @return the values and their Jacobian of this vector valued function. + */ + Pair value(double[] point); + +}