Value and Jacobian evaluated in a single method.
A new interface MultivariateJacobianFunction lets the function value and Jacobian be evaluated at the same time. This saves the user from having to cache the result between calls to get the value and the jacobian. A factory method was added to create LeastSquaresProblems from the new interface. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569346 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
4158f97463
commit
fc9cb0ce16
|
@ -2,6 +2,8 @@ package org.apache.commons.math3.fitting.leastsquares;
|
|||
|
||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||
import org.apache.commons.math3.linear.EigenDecomposition;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
|
@ -10,6 +12,7 @@ import org.apache.commons.math3.optim.ConvergenceChecker;
|
|||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.Incrementor;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* A Factory for creating {@link LeastSquaresProblem}s.
|
||||
|
@ -22,6 +25,34 @@ public class LeastSquaresFactory {
|
|||
private LeastSquaresFactory() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
||||
* from the given elements. There will be no weights applied (Identity weights).
|
||||
*
|
||||
* @param model the model function. Produces the computed values.
|
||||
* @param observed the observed (target) values
|
||||
* @param start the initial guess.
|
||||
* @param checker convergence checker
|
||||
* @param maxEvaluations the maximum number of times to evaluate the model
|
||||
* @param maxIterations the maximum number to times to iterate in the algorithm
|
||||
* @return the specified General Least Squares problem.
|
||||
*/
|
||||
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
|
||||
final double[] observed,
|
||||
final double[] start,
|
||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
||||
final int maxEvaluations,
|
||||
final int maxIterations) {
|
||||
return new LeastSquaresProblemImpl(
|
||||
model,
|
||||
observed,
|
||||
start,
|
||||
checker,
|
||||
maxEvaluations,
|
||||
maxIterations
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
||||
* from the given elements. There will be no weights applied (Identity weights).
|
||||
|
@ -42,9 +73,8 @@ public class LeastSquaresFactory {
|
|||
final ConvergenceChecker<PointVectorValuePair> checker,
|
||||
final int maxEvaluations,
|
||||
final int maxIterations) {
|
||||
return new LeastSquaresProblemImpl(
|
||||
model,
|
||||
jacobian,
|
||||
return create(
|
||||
combine(model, jacobian),
|
||||
observed,
|
||||
start,
|
||||
checker,
|
||||
|
@ -163,5 +193,27 @@ public class LeastSquaresFactory {
|
|||
return dec.getSquareRoot();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Combine a {@link MultivariateVectorFunction} with a {@link
|
||||
* MultivariateMatrixFunction} to produce a {@link MultivariateJacobianFunction}.
|
||||
*
|
||||
* @param value the vector value function
|
||||
* @param jacobian the Jacobian function
|
||||
* @return a function that computes both at the same time
|
||||
*/
|
||||
private static MultivariateJacobianFunction combine(
|
||||
final MultivariateVectorFunction value,
|
||||
final MultivariateMatrixFunction jacobian
|
||||
) {
|
||||
return new MultivariateJacobianFunction() {
|
||||
public Pair<RealVector, RealMatrix> value(final double[] point) {
|
||||
//evaluate and use Real* interfaces without copying
|
||||
return new Pair<RealVector, RealMatrix>(
|
||||
new ArrayRealVector(value.value(point), false),
|
||||
new Array2DRowRealMatrix(jacobian.value(point), false));
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,14 +16,13 @@
|
|||
*/
|
||||
package org.apache.commons.math3.fitting.leastsquares;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.optim.AbstractOptimizationProblem;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* A private, "field" immutable (not "real" immutable) implementation of {@link
|
||||
|
@ -39,14 +38,11 @@ class LeastSquaresProblemImpl
|
|||
/** Target values for the model function at optimum. */
|
||||
private double[] target;
|
||||
/** Model function. */
|
||||
private MultivariateVectorFunction model;
|
||||
/** Jacobian of the model function. */
|
||||
private MultivariateMatrixFunction jacobian;
|
||||
private MultivariateJacobianFunction model;
|
||||
/** Initial guess. */
|
||||
private double[] start;
|
||||
|
||||
LeastSquaresProblemImpl(final MultivariateVectorFunction model,
|
||||
final MultivariateMatrixFunction jacobian,
|
||||
LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
|
||||
final double[] target,
|
||||
final double[] start,
|
||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
||||
|
@ -55,7 +51,6 @@ class LeastSquaresProblemImpl
|
|||
super(maxEvaluations, maxIterations, checker);
|
||||
this.target = target;
|
||||
this.model = model;
|
||||
this.jacobian = jacobian;
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
|
@ -72,10 +67,11 @@ class LeastSquaresProblemImpl
|
|||
}
|
||||
|
||||
public Evaluation evaluate(final double[] point) {
|
||||
//TODO evaluate value and jacobian in one function call
|
||||
//evaluate value and jacobian in one function call
|
||||
final Pair<RealVector, RealMatrix> value = this.model.value(point);
|
||||
return new UnweightedEvaluation(
|
||||
this.model.value(point),
|
||||
this.jacobian.value(point),
|
||||
value.getFirst(),
|
||||
value.getSecond(),
|
||||
this.target,
|
||||
point);
|
||||
}
|
||||
|
@ -90,14 +86,14 @@ class LeastSquaresProblemImpl
|
|||
/** the point of evaluation */
|
||||
private final double[] point;
|
||||
/** value at point */
|
||||
private final double[] values;
|
||||
private final RealVector values;
|
||||
/** deriviative at point */
|
||||
private final double[][] jacobian;
|
||||
private final RealMatrix jacobian;
|
||||
/** reference to the observed values */
|
||||
private final double[] target;
|
||||
|
||||
private UnweightedEvaluation(final double[] values,
|
||||
final double[][] jacobian,
|
||||
private UnweightedEvaluation(final RealVector values,
|
||||
final RealMatrix jacobian,
|
||||
final double[] target,
|
||||
final double[] point) {
|
||||
super(target.length);
|
||||
|
@ -109,14 +105,13 @@ class LeastSquaresProblemImpl
|
|||
|
||||
|
||||
public double[] computeValue() {
|
||||
return this.values;
|
||||
return this.values.toArray();
|
||||
}
|
||||
|
||||
public RealMatrix computeJacobian() {
|
||||
return MatrixUtils.createRealMatrix(this.jacobian);
|
||||
return this.jacobian;
|
||||
}
|
||||
|
||||
|
||||
public double[] getPoint() {
|
||||
return this.point;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package org.apache.commons.math3.fitting.leastsquares;
|
||||
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* A interface for functions that compute a vector of values and can compute their
|
||||
* derivatives (Jacobian).
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public interface MultivariateJacobianFunction {
|
||||
|
||||
/**
|
||||
* Compute the function value and its Jacobian.
|
||||
*
|
||||
* @param point the abscissae
|
||||
* @return the values and their Jacobian of this vector valued function.
|
||||
*/
|
||||
Pair<RealVector, RealMatrix> value(double[] point);
|
||||
|
||||
}
|
Loading…
Reference in New Issue