Clean up LeastSquaresBuilder

Provide methods for using old and new interfaces. Data is stored internally
using the new interfaces now.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569356 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2014-02-18 14:33:08 +00:00
parent 4dcae270e6
commit 8916830e8a
6 changed files with 63 additions and 49 deletions

View File

@ -3,7 +3,9 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointVectorValuePair;
@ -22,13 +24,11 @@ public class LeastSquaresBuilder {
/** convergence checker */
private ConvergenceChecker<Evaluation> checker;
/** model function */
private MultivariateVectorFunction model;
/** Jacobian function */
private MultivariateMatrixFunction jacobian;
private MultivariateJacobianFunction model;
/** observed values */
private double[] target;
private RealVector target;
/** initial guess */
private double[] start;
private RealVector start;
/** weight matrix */
private RealMatrix weight;
@ -39,7 +39,7 @@ public class LeastSquaresBuilder {
* @return a new {@link LeastSquaresProblem}.
*/
public LeastSquaresProblem build() {
return LeastSquaresFactory.create(model, jacobian, target, start, weight, checker, maxEvaluations, maxIterations);
return LeastSquaresFactory.create(model, target, start, weight, checker, maxEvaluations, maxIterations);
}
/**
@ -90,22 +90,34 @@ public class LeastSquaresBuilder {
/**
* Configure the model function.
*
* @param model the model function
* @param value the model function value
* @param jacobian the Jacobian of {@code value}
* @return this
*/
public LeastSquaresBuilder model(final MultivariateVectorFunction model) {
public LeastSquaresBuilder model(final MultivariateVectorFunction value,
final MultivariateMatrixFunction jacobian) {
return model(LeastSquaresFactory.model(value, jacobian));
}
/**
* Configure the model function.
*
* @param model the model function value and Jacobian
* @return this
*/
public LeastSquaresBuilder model(final MultivariateJacobianFunction model) {
this.model = model;
return this;
}
/**
* Configure the Jacobian function.
* Configure the observed data.
*
* @param jacobian the Jacobian function
* @param target the observed data.
* @return this
*/
public LeastSquaresBuilder jacobian(final MultivariateMatrixFunction jacobian) {
this.jacobian = jacobian;
public LeastSquaresBuilder target(final RealVector target) {
this.target = target;
return this;
}
@ -116,7 +128,17 @@ public class LeastSquaresBuilder {
* @return this
*/
public LeastSquaresBuilder target(final double[] target) {
this.target = target;
return target(new ArrayRealVector(target, false));
}
/**
* Configure the initial guess.
*
* @param start the initial guess.
* @return this
*/
public LeastSquaresBuilder start(final RealVector start) {
this.start = start;
return this;
}
@ -127,8 +149,7 @@ public class LeastSquaresBuilder {
* @return this
*/
public LeastSquaresBuilder start(final double[] start) {
this.start = start;
return this;
return start(new ArrayRealVector(start, false));
}
/**

View File

@ -22,12 +22,15 @@ import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.junit.Assert;
import java.io.IOException;
@ -59,8 +62,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final double[] weights = new double[c.getN()];
Arrays.fill(weights, 1.0);
return base()
.model(c.getModelFunction())
.jacobian(c.getModelFunctionJacobian())
.model(c.getModelFunction(), c.getModelFunctionJacobian())
.target(new double[c.getN()])
.weight(new DiagonalMatrix(weights));
}
@ -71,8 +73,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final double[] weights = new double[dataset.getNumObservations()];
Arrays.fill(weights, 1.0);
return base()
.model(problem.getModelFunction())
.jacobian(problem.getModelFunctionJacobian())
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.target(dataset.getData()[1])
.weight(new DiagonalMatrix(weights))
.start(dataset.getStartingPoint(0));
@ -133,24 +134,22 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
.target(new double[]{1})
.weight(new DiagonalMatrix(new double[]{1}))
.start(new double[]{3})
.model(
new MultivariateVectorFunction() {
public double[] value(double[] point) {
return new double[]{
FastMath.pow(point[0], 4)
};
.model(new MultivariateJacobianFunction() {
public Pair<RealVector, RealMatrix> value(final RealVector point) {
return new Pair<RealVector, RealMatrix>(
new ArrayRealVector(
new double[]{
FastMath.pow(point.getEntry(0), 4)
},
false),
new Array2DRowRealMatrix(
new double[][]{
{0.25 * FastMath.pow(point.getEntry(0), 3)}
},
false)
);
}
}
)
.jacobian(
new MultivariateMatrixFunction() {
public double[][] value(double[] point) {
return new double[][]{
{0.25 * FastMath.pow(point[0], 3)}
};
}
}
)
})
.build();
Optimum optimum = optimizer.optimize(lsp);
@ -554,8 +553,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final double[] weights = new double[target.length];
Arrays.fill(weights, 1.0);
return base()
.model(getModelFunction())
.jacobian(getModelFunctionJacobian())
.model(getModelFunction(), getModelFunctionJacobian())
.target(target)
.weight(new DiagonalMatrix(weights))
.start(new double[factors.getColumnDimension()]);

View File

@ -44,8 +44,7 @@ public class EvaluationTest {
Arrays.fill(weights, 1d);
return new LeastSquaresBuilder()
.model(problem.getModelFunction())
.jacobian(problem.getModelFunctionJacobian())
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.target(observed)
.weight(new DiagonalMatrix(weights))
.start(start);

View File

@ -288,8 +288,7 @@ public class EvaluationTestValidation {
LeastSquaresBuilder builder(StraightLineProblem problem){
return new LeastSquaresBuilder()
.model(problem.getModelFunction())
.jacobian(problem.getModelFunctionJacobian())
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.target(problem.target())
.weight(new DiagonalMatrix(problem.weight()))
//unused start point to avoid NPE

View File

@ -49,14 +49,12 @@ public class LevenbergMarquardtOptimizerTest
public LeastSquaresBuilder builder(BevingtonProblem problem){
return base()
.model(problem.getModelFunction())
.jacobian(problem.getModelFunctionJacobian());
.model(problem.getModelFunction(), problem.getModelFunctionJacobian());
}
public LeastSquaresBuilder builder(CircleProblem problem){
return base()
.model(problem.getModelFunction())
.jacobian(problem.getModelFunctionJacobian())
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.target(problem.target())
.weight(new DiagonalMatrix(problem.weight()));
}

View File

@ -507,8 +507,7 @@ public class MinpackTest {
LeastSquaresProblem problem = new LeastSquaresBuilder()
.maxEvaluations(400 * (function.getN() + 1))
.maxIterations(2000)
.model(function.getModelFunction())
.jacobian(function.getModelFunctionJacobian())
.model(function.getModelFunction(), function.getModelFunctionJacobian())
.target(function.getTarget())
.weight(new DiagonalMatrix(function.getWeight()))
.start(function.getStartPoint())