diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java index 5cb0b847f..073ff4d26 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java @@ -2,6 +2,8 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.analysis.MultivariateMatrixFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.EigenDecomposition; import org.apache.commons.math3.linear.RealMatrix; @@ -10,6 +12,7 @@ import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.PointVectorValuePair; import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.Incrementor; +import org.apache.commons.math3.util.Pair; /** * A Factory for creating {@link LeastSquaresProblem}s. @@ -22,6 +25,34 @@ public class LeastSquaresFactory { private LeastSquaresFactory() { } + /** + * Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem} + * from the given elements. There will be no weights applied (Identity weights). + * + * @param model the model function. Produces the computed values. + * @param observed the observed (target) values + * @param start the initial guess. + * @param checker convergence checker + * @param maxEvaluations the maximum number of times to evaluate the model + * @param maxIterations the maximum number to times to iterate in the algorithm + * @return the specified General Least Squares problem. + */ + public static LeastSquaresProblem create(final MultivariateJacobianFunction model, + final double[] observed, + final double[] start, + final ConvergenceChecker checker, + final int maxEvaluations, + final int maxIterations) { + return new LeastSquaresProblemImpl( + model, + observed, + start, + checker, + maxEvaluations, + maxIterations + ); + } + /** * Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem} * from the given elements. There will be no weights applied (Identity weights). @@ -42,9 +73,8 @@ public class LeastSquaresFactory { final ConvergenceChecker checker, final int maxEvaluations, final int maxIterations) { - return new LeastSquaresProblemImpl( - model, - jacobian, + return create( + combine(model, jacobian), observed, start, checker, @@ -163,5 +193,27 @@ public class LeastSquaresFactory { return dec.getSquareRoot(); } } + + /** + * Combine a {@link MultivariateVectorFunction} with a {@link + * MultivariateMatrixFunction} to produce a {@link MultivariateJacobianFunction}. + * + * @param value the vector value function + * @param jacobian the Jacobian function + * @return a function that computes both at the same time + */ + private static MultivariateJacobianFunction combine( + final MultivariateVectorFunction value, + final MultivariateMatrixFunction jacobian + ) { + return new MultivariateJacobianFunction() { + public Pair value(final double[] point) { + //evaluate and use Real* interfaces without copying + return new Pair( + new ArrayRealVector(value.value(point), false), + new Array2DRowRealMatrix(jacobian.value(point), false)); + } + }; + } } diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java index 1cf323c16..15e306705 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java @@ -16,14 +16,13 @@ */ package org.apache.commons.math3.fitting.leastsquares; -import org.apache.commons.math3.analysis.MultivariateMatrixFunction; -import org.apache.commons.math3.analysis.MultivariateVectorFunction; import org.apache.commons.math3.exception.DimensionMismatchException; -import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.optim.AbstractOptimizationProblem; import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.PointVectorValuePair; +import org.apache.commons.math3.util.Pair; /** * A private, "field" immutable (not "real" immutable) implementation of {@link @@ -39,14 +38,11 @@ class LeastSquaresProblemImpl /** Target values for the model function at optimum. */ private double[] target; /** Model function. */ - private MultivariateVectorFunction model; - /** Jacobian of the model function. */ - private MultivariateMatrixFunction jacobian; + private MultivariateJacobianFunction model; /** Initial guess. */ private double[] start; - LeastSquaresProblemImpl(final MultivariateVectorFunction model, - final MultivariateMatrixFunction jacobian, + LeastSquaresProblemImpl(final MultivariateJacobianFunction model, final double[] target, final double[] start, final ConvergenceChecker checker, @@ -55,7 +51,6 @@ class LeastSquaresProblemImpl super(maxEvaluations, maxIterations, checker); this.target = target; this.model = model; - this.jacobian = jacobian; this.start = start; } @@ -72,10 +67,11 @@ class LeastSquaresProblemImpl } public Evaluation evaluate(final double[] point) { - //TODO evaluate value and jacobian in one function call + //evaluate value and jacobian in one function call + final Pair value = this.model.value(point); return new UnweightedEvaluation( - this.model.value(point), - this.jacobian.value(point), + value.getFirst(), + value.getSecond(), this.target, point); } @@ -90,14 +86,14 @@ class LeastSquaresProblemImpl /** the point of evaluation */ private final double[] point; /** value at point */ - private final double[] values; + private final RealVector values; /** deriviative at point */ - private final double[][] jacobian; + private final RealMatrix jacobian; /** reference to the observed values */ private final double[] target; - private UnweightedEvaluation(final double[] values, - final double[][] jacobian, + private UnweightedEvaluation(final RealVector values, + final RealMatrix jacobian, final double[] target, final double[] point) { super(target.length); @@ -109,14 +105,13 @@ class LeastSquaresProblemImpl public double[] computeValue() { - return this.values; + return this.values.toArray(); } public RealMatrix computeJacobian() { - return MatrixUtils.createRealMatrix(this.jacobian); + return this.jacobian; } - public double[] getPoint() { return this.point; } diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction.java new file mode 100644 index 000000000..c79942c14 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction.java @@ -0,0 +1,23 @@ +package org.apache.commons.math3.fitting.leastsquares; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; +import org.apache.commons.math3.util.Pair; + +/** + * A interface for functions that compute a vector of values and can compute their + * derivatives (Jacobian). + * + * @version $Id$ + */ +public interface MultivariateJacobianFunction { + + /** + * Compute the function value and its Jacobian. + * + * @param point the abscissae + * @return the values and their Jacobian of this vector valued function. + */ + Pair value(double[] point); + +}