Clean up LeastSquaresBuilder
Provide methods for using old and new interfaces. Data is stored internally using the new interfaces now. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569356 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
4dcae270e6
commit
8916830e8a
|
@ -3,7 +3,9 @@ package org.apache.commons.math3.fitting.leastsquares;
|
|||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
|
||||
|
@ -22,13 +24,11 @@ public class LeastSquaresBuilder {
|
|||
/** convergence checker */
|
||||
private ConvergenceChecker<Evaluation> checker;
|
||||
/** model function */
|
||||
private MultivariateVectorFunction model;
|
||||
/** Jacobian function */
|
||||
private MultivariateMatrixFunction jacobian;
|
||||
private MultivariateJacobianFunction model;
|
||||
/** observed values */
|
||||
private double[] target;
|
||||
private RealVector target;
|
||||
/** initial guess */
|
||||
private double[] start;
|
||||
private RealVector start;
|
||||
/** weight matrix */
|
||||
private RealMatrix weight;
|
||||
|
||||
|
@ -39,7 +39,7 @@ public class LeastSquaresBuilder {
|
|||
* @return a new {@link LeastSquaresProblem}.
|
||||
*/
|
||||
public LeastSquaresProblem build() {
|
||||
return LeastSquaresFactory.create(model, jacobian, target, start, weight, checker, maxEvaluations, maxIterations);
|
||||
return LeastSquaresFactory.create(model, target, start, weight, checker, maxEvaluations, maxIterations);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -90,22 +90,34 @@ public class LeastSquaresBuilder {
|
|||
/**
|
||||
* Configure the model function.
|
||||
*
|
||||
* @param model the model function
|
||||
* @param value the model function value
|
||||
* @param jacobian the Jacobian of {@code value}
|
||||
* @return this
|
||||
*/
|
||||
public LeastSquaresBuilder model(final MultivariateVectorFunction model) {
|
||||
public LeastSquaresBuilder model(final MultivariateVectorFunction value,
|
||||
final MultivariateMatrixFunction jacobian) {
|
||||
return model(LeastSquaresFactory.model(value, jacobian));
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure the model function.
|
||||
*
|
||||
* @param model the model function value and Jacobian
|
||||
* @return this
|
||||
*/
|
||||
public LeastSquaresBuilder model(final MultivariateJacobianFunction model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure the Jacobian function.
|
||||
* Configure the observed data.
|
||||
*
|
||||
* @param jacobian the Jacobian function
|
||||
* @param target the observed data.
|
||||
* @return this
|
||||
*/
|
||||
public LeastSquaresBuilder jacobian(final MultivariateMatrixFunction jacobian) {
|
||||
this.jacobian = jacobian;
|
||||
public LeastSquaresBuilder target(final RealVector target) {
|
||||
this.target = target;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -116,7 +128,17 @@ public class LeastSquaresBuilder {
|
|||
* @return this
|
||||
*/
|
||||
public LeastSquaresBuilder target(final double[] target) {
|
||||
this.target = target;
|
||||
return target(new ArrayRealVector(target, false));
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure the initial guess.
|
||||
*
|
||||
* @param start the initial guess.
|
||||
* @return this
|
||||
*/
|
||||
public LeastSquaresBuilder start(final RealVector start) {
|
||||
this.start = start;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -127,8 +149,7 @@ public class LeastSquaresBuilder {
|
|||
* @return this
|
||||
*/
|
||||
public LeastSquaresBuilder start(final double[] start) {
|
||||
this.start = start;
|
||||
return this;
|
||||
return start(new ArrayRealVector(start, false));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -22,12 +22,15 @@ import org.apache.commons.math3.exception.ConvergenceException;
|
|||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
|
||||
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.linear.BlockRealMatrix;
|
||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
import org.junit.Assert;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -59,8 +62,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
|||
final double[] weights = new double[c.getN()];
|
||||
Arrays.fill(weights, 1.0);
|
||||
return base()
|
||||
.model(c.getModelFunction())
|
||||
.jacobian(c.getModelFunctionJacobian())
|
||||
.model(c.getModelFunction(), c.getModelFunctionJacobian())
|
||||
.target(new double[c.getN()])
|
||||
.weight(new DiagonalMatrix(weights));
|
||||
}
|
||||
|
@ -71,8 +73,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
|||
final double[] weights = new double[dataset.getNumObservations()];
|
||||
Arrays.fill(weights, 1.0);
|
||||
return base()
|
||||
.model(problem.getModelFunction())
|
||||
.jacobian(problem.getModelFunctionJacobian())
|
||||
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||
.target(dataset.getData()[1])
|
||||
.weight(new DiagonalMatrix(weights))
|
||||
.start(dataset.getStartingPoint(0));
|
||||
|
@ -133,24 +134,22 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
|||
.target(new double[]{1})
|
||||
.weight(new DiagonalMatrix(new double[]{1}))
|
||||
.start(new double[]{3})
|
||||
.model(
|
||||
new MultivariateVectorFunction() {
|
||||
public double[] value(double[] point) {
|
||||
return new double[]{
|
||||
FastMath.pow(point[0], 4)
|
||||
};
|
||||
.model(new MultivariateJacobianFunction() {
|
||||
public Pair<RealVector, RealMatrix> value(final RealVector point) {
|
||||
return new Pair<RealVector, RealMatrix>(
|
||||
new ArrayRealVector(
|
||||
new double[]{
|
||||
FastMath.pow(point.getEntry(0), 4)
|
||||
},
|
||||
false),
|
||||
new Array2DRowRealMatrix(
|
||||
new double[][]{
|
||||
{0.25 * FastMath.pow(point.getEntry(0), 3)}
|
||||
},
|
||||
false)
|
||||
);
|
||||
}
|
||||
}
|
||||
)
|
||||
.jacobian(
|
||||
new MultivariateMatrixFunction() {
|
||||
public double[][] value(double[] point) {
|
||||
return new double[][]{
|
||||
{0.25 * FastMath.pow(point[0], 3)}
|
||||
};
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
.build();
|
||||
|
||||
Optimum optimum = optimizer.optimize(lsp);
|
||||
|
@ -554,8 +553,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
|||
final double[] weights = new double[target.length];
|
||||
Arrays.fill(weights, 1.0);
|
||||
return base()
|
||||
.model(getModelFunction())
|
||||
.jacobian(getModelFunctionJacobian())
|
||||
.model(getModelFunction(), getModelFunctionJacobian())
|
||||
.target(target)
|
||||
.weight(new DiagonalMatrix(weights))
|
||||
.start(new double[factors.getColumnDimension()]);
|
||||
|
|
|
@ -44,8 +44,7 @@ public class EvaluationTest {
|
|||
Arrays.fill(weights, 1d);
|
||||
|
||||
return new LeastSquaresBuilder()
|
||||
.model(problem.getModelFunction())
|
||||
.jacobian(problem.getModelFunctionJacobian())
|
||||
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||
.target(observed)
|
||||
.weight(new DiagonalMatrix(weights))
|
||||
.start(start);
|
||||
|
|
|
@ -288,8 +288,7 @@ public class EvaluationTestValidation {
|
|||
|
||||
LeastSquaresBuilder builder(StraightLineProblem problem){
|
||||
return new LeastSquaresBuilder()
|
||||
.model(problem.getModelFunction())
|
||||
.jacobian(problem.getModelFunctionJacobian())
|
||||
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||
.target(problem.target())
|
||||
.weight(new DiagonalMatrix(problem.weight()))
|
||||
//unused start point to avoid NPE
|
||||
|
|
|
@ -49,14 +49,12 @@ public class LevenbergMarquardtOptimizerTest
|
|||
|
||||
public LeastSquaresBuilder builder(BevingtonProblem problem){
|
||||
return base()
|
||||
.model(problem.getModelFunction())
|
||||
.jacobian(problem.getModelFunctionJacobian());
|
||||
.model(problem.getModelFunction(), problem.getModelFunctionJacobian());
|
||||
}
|
||||
|
||||
public LeastSquaresBuilder builder(CircleProblem problem){
|
||||
return base()
|
||||
.model(problem.getModelFunction())
|
||||
.jacobian(problem.getModelFunctionJacobian())
|
||||
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||
.target(problem.target())
|
||||
.weight(new DiagonalMatrix(problem.weight()));
|
||||
}
|
||||
|
|
|
@ -507,8 +507,7 @@ public class MinpackTest {
|
|||
LeastSquaresProblem problem = new LeastSquaresBuilder()
|
||||
.maxEvaluations(400 * (function.getN() + 1))
|
||||
.maxIterations(2000)
|
||||
.model(function.getModelFunction())
|
||||
.jacobian(function.getModelFunctionJacobian())
|
||||
.model(function.getModelFunction(), function.getModelFunctionJacobian())
|
||||
.target(function.getTarget())
|
||||
.weight(new DiagonalMatrix(function.getWeight()))
|
||||
.start(function.getStartPoint())
|
||||
|
|
Loading…
Reference in New Issue