Value and Jacobian evaluated in a single method.
A new interface MultivariateJacobianFunction lets the function value and Jacobian be evaluated at the same time. This saves the user from having to cache the result between calls to get the value and the jacobian. A factory method was added to create LeastSquaresProblems from the new interface. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569346 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
4158f97463
commit
fc9cb0ce16
|
@ -2,6 +2,8 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||||
|
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||||
import org.apache.commons.math3.linear.EigenDecomposition;
|
import org.apache.commons.math3.linear.EigenDecomposition;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
@ -10,6 +12,7 @@ import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.apache.commons.math3.util.Incrementor;
|
import org.apache.commons.math3.util.Incrementor;
|
||||||
|
import org.apache.commons.math3.util.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A Factory for creating {@link LeastSquaresProblem}s.
|
* A Factory for creating {@link LeastSquaresProblem}s.
|
||||||
|
@ -22,6 +25,34 @@ public class LeastSquaresFactory {
|
||||||
private LeastSquaresFactory() {
|
private LeastSquaresFactory() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
||||||
|
* from the given elements. There will be no weights applied (Identity weights).
|
||||||
|
*
|
||||||
|
* @param model the model function. Produces the computed values.
|
||||||
|
* @param observed the observed (target) values
|
||||||
|
* @param start the initial guess.
|
||||||
|
* @param checker convergence checker
|
||||||
|
* @param maxEvaluations the maximum number of times to evaluate the model
|
||||||
|
* @param maxIterations the maximum number to times to iterate in the algorithm
|
||||||
|
* @return the specified General Least Squares problem.
|
||||||
|
*/
|
||||||
|
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
|
||||||
|
final double[] observed,
|
||||||
|
final double[] start,
|
||||||
|
final ConvergenceChecker<PointVectorValuePair> checker,
|
||||||
|
final int maxEvaluations,
|
||||||
|
final int maxIterations) {
|
||||||
|
return new LeastSquaresProblemImpl(
|
||||||
|
model,
|
||||||
|
observed,
|
||||||
|
start,
|
||||||
|
checker,
|
||||||
|
maxEvaluations,
|
||||||
|
maxIterations
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
||||||
* from the given elements. There will be no weights applied (Identity weights).
|
* from the given elements. There will be no weights applied (Identity weights).
|
||||||
|
@ -42,9 +73,8 @@ public class LeastSquaresFactory {
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
final ConvergenceChecker<PointVectorValuePair> checker,
|
||||||
final int maxEvaluations,
|
final int maxEvaluations,
|
||||||
final int maxIterations) {
|
final int maxIterations) {
|
||||||
return new LeastSquaresProblemImpl(
|
return create(
|
||||||
model,
|
combine(model, jacobian),
|
||||||
jacobian,
|
|
||||||
observed,
|
observed,
|
||||||
start,
|
start,
|
||||||
checker,
|
checker,
|
||||||
|
@ -163,5 +193,27 @@ public class LeastSquaresFactory {
|
||||||
return dec.getSquareRoot();
|
return dec.getSquareRoot();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Combine a {@link MultivariateVectorFunction} with a {@link
|
||||||
|
* MultivariateMatrixFunction} to produce a {@link MultivariateJacobianFunction}.
|
||||||
|
*
|
||||||
|
* @param value the vector value function
|
||||||
|
* @param jacobian the Jacobian function
|
||||||
|
* @return a function that computes both at the same time
|
||||||
|
*/
|
||||||
|
private static MultivariateJacobianFunction combine(
|
||||||
|
final MultivariateVectorFunction value,
|
||||||
|
final MultivariateMatrixFunction jacobian
|
||||||
|
) {
|
||||||
|
return new MultivariateJacobianFunction() {
|
||||||
|
public Pair<RealVector, RealMatrix> value(final double[] point) {
|
||||||
|
//evaluate and use Real* interfaces without copying
|
||||||
|
return new Pair<RealVector, RealMatrix>(
|
||||||
|
new ArrayRealVector(value.value(point), false),
|
||||||
|
new Array2DRowRealMatrix(jacobian.value(point), false));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,14 +16,13 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.commons.math3.fitting.leastsquares;
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
|
||||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||||
import org.apache.commons.math3.linear.MatrixUtils;
|
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.optim.AbstractOptimizationProblem;
|
import org.apache.commons.math3.optim.AbstractOptimizationProblem;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||||
|
import org.apache.commons.math3.util.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A private, "field" immutable (not "real" immutable) implementation of {@link
|
* A private, "field" immutable (not "real" immutable) implementation of {@link
|
||||||
|
@ -39,14 +38,11 @@ class LeastSquaresProblemImpl
|
||||||
/** Target values for the model function at optimum. */
|
/** Target values for the model function at optimum. */
|
||||||
private double[] target;
|
private double[] target;
|
||||||
/** Model function. */
|
/** Model function. */
|
||||||
private MultivariateVectorFunction model;
|
private MultivariateJacobianFunction model;
|
||||||
/** Jacobian of the model function. */
|
|
||||||
private MultivariateMatrixFunction jacobian;
|
|
||||||
/** Initial guess. */
|
/** Initial guess. */
|
||||||
private double[] start;
|
private double[] start;
|
||||||
|
|
||||||
LeastSquaresProblemImpl(final MultivariateVectorFunction model,
|
LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
|
||||||
final MultivariateMatrixFunction jacobian,
|
|
||||||
final double[] target,
|
final double[] target,
|
||||||
final double[] start,
|
final double[] start,
|
||||||
final ConvergenceChecker<PointVectorValuePair> checker,
|
final ConvergenceChecker<PointVectorValuePair> checker,
|
||||||
|
@ -55,7 +51,6 @@ class LeastSquaresProblemImpl
|
||||||
super(maxEvaluations, maxIterations, checker);
|
super(maxEvaluations, maxIterations, checker);
|
||||||
this.target = target;
|
this.target = target;
|
||||||
this.model = model;
|
this.model = model;
|
||||||
this.jacobian = jacobian;
|
|
||||||
this.start = start;
|
this.start = start;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,10 +67,11 @@ class LeastSquaresProblemImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
public Evaluation evaluate(final double[] point) {
|
public Evaluation evaluate(final double[] point) {
|
||||||
//TODO evaluate value and jacobian in one function call
|
//evaluate value and jacobian in one function call
|
||||||
|
final Pair<RealVector, RealMatrix> value = this.model.value(point);
|
||||||
return new UnweightedEvaluation(
|
return new UnweightedEvaluation(
|
||||||
this.model.value(point),
|
value.getFirst(),
|
||||||
this.jacobian.value(point),
|
value.getSecond(),
|
||||||
this.target,
|
this.target,
|
||||||
point);
|
point);
|
||||||
}
|
}
|
||||||
|
@ -90,14 +86,14 @@ class LeastSquaresProblemImpl
|
||||||
/** the point of evaluation */
|
/** the point of evaluation */
|
||||||
private final double[] point;
|
private final double[] point;
|
||||||
/** value at point */
|
/** value at point */
|
||||||
private final double[] values;
|
private final RealVector values;
|
||||||
/** deriviative at point */
|
/** deriviative at point */
|
||||||
private final double[][] jacobian;
|
private final RealMatrix jacobian;
|
||||||
/** reference to the observed values */
|
/** reference to the observed values */
|
||||||
private final double[] target;
|
private final double[] target;
|
||||||
|
|
||||||
private UnweightedEvaluation(final double[] values,
|
private UnweightedEvaluation(final RealVector values,
|
||||||
final double[][] jacobian,
|
final RealMatrix jacobian,
|
||||||
final double[] target,
|
final double[] target,
|
||||||
final double[] point) {
|
final double[] point) {
|
||||||
super(target.length);
|
super(target.length);
|
||||||
|
@ -109,14 +105,13 @@ class LeastSquaresProblemImpl
|
||||||
|
|
||||||
|
|
||||||
public double[] computeValue() {
|
public double[] computeValue() {
|
||||||
return this.values;
|
return this.values.toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
public RealMatrix computeJacobian() {
|
public RealMatrix computeJacobian() {
|
||||||
return MatrixUtils.createRealMatrix(this.jacobian);
|
return this.jacobian;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public double[] getPoint() {
|
public double[] getPoint() {
|
||||||
return this.point;
|
return this.point;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
|
import org.apache.commons.math3.util.Pair;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A interface for functions that compute a vector of values and can compute their
|
||||||
|
* derivatives (Jacobian).
|
||||||
|
*
|
||||||
|
* @version $Id$
|
||||||
|
*/
|
||||||
|
public interface MultivariateJacobianFunction {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute the function value and its Jacobian.
|
||||||
|
*
|
||||||
|
* @param point the abscissae
|
||||||
|
* @return the values and their Jacobian of this vector valued function.
|
||||||
|
*/
|
||||||
|
Pair<RealVector, RealMatrix> value(double[] point);
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue