Clean up LeastSquaresBuilder

Provide methods for using old and new interfaces. Data is stored internally
using the new interfaces now.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569356 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2014-02-18 14:33:08 +00:00
parent 4dcae270e6
commit 8916830e8a
6 changed files with 63 additions and 49 deletions

View File

@ -3,7 +3,9 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.analysis.MultivariateMatrixFunction; import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointVectorValuePair; import org.apache.commons.math3.optim.PointVectorValuePair;
@ -22,13 +24,11 @@ public class LeastSquaresBuilder {
/** convergence checker */ /** convergence checker */
private ConvergenceChecker<Evaluation> checker; private ConvergenceChecker<Evaluation> checker;
/** model function */ /** model function */
private MultivariateVectorFunction model; private MultivariateJacobianFunction model;
/** Jacobian function */
private MultivariateMatrixFunction jacobian;
/** observed values */ /** observed values */
private double[] target; private RealVector target;
/** initial guess */ /** initial guess */
private double[] start; private RealVector start;
/** weight matrix */ /** weight matrix */
private RealMatrix weight; private RealMatrix weight;
@ -39,7 +39,7 @@ public class LeastSquaresBuilder {
* @return a new {@link LeastSquaresProblem}. * @return a new {@link LeastSquaresProblem}.
*/ */
public LeastSquaresProblem build() { public LeastSquaresProblem build() {
return LeastSquaresFactory.create(model, jacobian, target, start, weight, checker, maxEvaluations, maxIterations); return LeastSquaresFactory.create(model, target, start, weight, checker, maxEvaluations, maxIterations);
} }
/** /**
@ -90,22 +90,34 @@ public class LeastSquaresBuilder {
/** /**
* Configure the model function. * Configure the model function.
* *
* @param model the model function * @param value the model function value
* @param jacobian the Jacobian of {@code value}
* @return this * @return this
*/ */
public LeastSquaresBuilder model(final MultivariateVectorFunction model) { public LeastSquaresBuilder model(final MultivariateVectorFunction value,
final MultivariateMatrixFunction jacobian) {
return model(LeastSquaresFactory.model(value, jacobian));
}
/**
* Configure the model function.
*
* @param model the model function value and Jacobian
* @return this
*/
public LeastSquaresBuilder model(final MultivariateJacobianFunction model) {
this.model = model; this.model = model;
return this; return this;
} }
/** /**
* Configure the Jacobian function. * Configure the observed data.
* *
* @param jacobian the Jacobian function * @param target the observed data.
* @return this * @return this
*/ */
public LeastSquaresBuilder jacobian(final MultivariateMatrixFunction jacobian) { public LeastSquaresBuilder target(final RealVector target) {
this.jacobian = jacobian; this.target = target;
return this; return this;
} }
@ -116,7 +128,17 @@ public class LeastSquaresBuilder {
* @return this * @return this
*/ */
public LeastSquaresBuilder target(final double[] target) { public LeastSquaresBuilder target(final double[] target) {
this.target = target; return target(new ArrayRealVector(target, false));
}
/**
* Configure the initial guess.
*
* @param start the initial guess.
* @return this
*/
public LeastSquaresBuilder start(final RealVector start) {
this.start = start;
return this; return this;
} }
@ -127,8 +149,7 @@ public class LeastSquaresBuilder {
* @return this * @return this
*/ */
public LeastSquaresBuilder start(final double[] start) { public LeastSquaresBuilder start(final double[] start) {
this.start = start; return start(new ArrayRealVector(start, false));
return this;
} }
/** /**

View File

@ -22,12 +22,15 @@ import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D; import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.BlockRealMatrix; import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.SimpleVectorValueChecker; import org.apache.commons.math3.optim.SimpleVectorValueChecker;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.junit.Assert; import org.junit.Assert;
import java.io.IOException; import java.io.IOException;
@ -59,8 +62,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final double[] weights = new double[c.getN()]; final double[] weights = new double[c.getN()];
Arrays.fill(weights, 1.0); Arrays.fill(weights, 1.0);
return base() return base()
.model(c.getModelFunction()) .model(c.getModelFunction(), c.getModelFunctionJacobian())
.jacobian(c.getModelFunctionJacobian())
.target(new double[c.getN()]) .target(new double[c.getN()])
.weight(new DiagonalMatrix(weights)); .weight(new DiagonalMatrix(weights));
} }
@ -71,8 +73,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final double[] weights = new double[dataset.getNumObservations()]; final double[] weights = new double[dataset.getNumObservations()];
Arrays.fill(weights, 1.0); Arrays.fill(weights, 1.0);
return base() return base()
.model(problem.getModelFunction()) .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.jacobian(problem.getModelFunctionJacobian())
.target(dataset.getData()[1]) .target(dataset.getData()[1])
.weight(new DiagonalMatrix(weights)) .weight(new DiagonalMatrix(weights))
.start(dataset.getStartingPoint(0)); .start(dataset.getStartingPoint(0));
@ -133,24 +134,22 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
.target(new double[]{1}) .target(new double[]{1})
.weight(new DiagonalMatrix(new double[]{1})) .weight(new DiagonalMatrix(new double[]{1}))
.start(new double[]{3}) .start(new double[]{3})
.model( .model(new MultivariateJacobianFunction() {
new MultivariateVectorFunction() { public Pair<RealVector, RealMatrix> value(final RealVector point) {
public double[] value(double[] point) { return new Pair<RealVector, RealMatrix>(
return new double[]{ new ArrayRealVector(
FastMath.pow(point[0], 4) new double[]{
}; FastMath.pow(point.getEntry(0), 4)
} },
} false),
) new Array2DRowRealMatrix(
.jacobian( new double[][]{
new MultivariateMatrixFunction() { {0.25 * FastMath.pow(point.getEntry(0), 3)}
public double[][] value(double[] point) { },
return new double[][]{ false)
{0.25 * FastMath.pow(point[0], 3)} );
}; }
} })
}
)
.build(); .build();
Optimum optimum = optimizer.optimize(lsp); Optimum optimum = optimizer.optimize(lsp);
@ -554,8 +553,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final double[] weights = new double[target.length]; final double[] weights = new double[target.length];
Arrays.fill(weights, 1.0); Arrays.fill(weights, 1.0);
return base() return base()
.model(getModelFunction()) .model(getModelFunction(), getModelFunctionJacobian())
.jacobian(getModelFunctionJacobian())
.target(target) .target(target)
.weight(new DiagonalMatrix(weights)) .weight(new DiagonalMatrix(weights))
.start(new double[factors.getColumnDimension()]); .start(new double[factors.getColumnDimension()]);

View File

@ -44,8 +44,7 @@ public class EvaluationTest {
Arrays.fill(weights, 1d); Arrays.fill(weights, 1d);
return new LeastSquaresBuilder() return new LeastSquaresBuilder()
.model(problem.getModelFunction()) .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.jacobian(problem.getModelFunctionJacobian())
.target(observed) .target(observed)
.weight(new DiagonalMatrix(weights)) .weight(new DiagonalMatrix(weights))
.start(start); .start(start);

View File

@ -288,8 +288,7 @@ public class EvaluationTestValidation {
LeastSquaresBuilder builder(StraightLineProblem problem){ LeastSquaresBuilder builder(StraightLineProblem problem){
return new LeastSquaresBuilder() return new LeastSquaresBuilder()
.model(problem.getModelFunction()) .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.jacobian(problem.getModelFunctionJacobian())
.target(problem.target()) .target(problem.target())
.weight(new DiagonalMatrix(problem.weight())) .weight(new DiagonalMatrix(problem.weight()))
//unused start point to avoid NPE //unused start point to avoid NPE

View File

@ -49,14 +49,12 @@ public class LevenbergMarquardtOptimizerTest
public LeastSquaresBuilder builder(BevingtonProblem problem){ public LeastSquaresBuilder builder(BevingtonProblem problem){
return base() return base()
.model(problem.getModelFunction()) .model(problem.getModelFunction(), problem.getModelFunctionJacobian());
.jacobian(problem.getModelFunctionJacobian());
} }
public LeastSquaresBuilder builder(CircleProblem problem){ public LeastSquaresBuilder builder(CircleProblem problem){
return base() return base()
.model(problem.getModelFunction()) .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
.jacobian(problem.getModelFunctionJacobian())
.target(problem.target()) .target(problem.target())
.weight(new DiagonalMatrix(problem.weight())); .weight(new DiagonalMatrix(problem.weight()));
} }

View File

@ -507,8 +507,7 @@ public class MinpackTest {
LeastSquaresProblem problem = new LeastSquaresBuilder() LeastSquaresProblem problem = new LeastSquaresBuilder()
.maxEvaluations(400 * (function.getN() + 1)) .maxEvaluations(400 * (function.getN() + 1))
.maxIterations(2000) .maxIterations(2000)
.model(function.getModelFunction()) .model(function.getModelFunction(), function.getModelFunctionJacobian())
.jacobian(function.getModelFunctionJacobian())
.target(function.getTarget()) .target(function.getTarget())
.weight(new DiagonalMatrix(function.getWeight())) .weight(new DiagonalMatrix(function.getWeight()))
.start(function.getStartPoint()) .start(function.getStartPoint())