Clean up LeastSquaresBuilder
Provide methods for using old and new interfaces. Data is stored internally using the new interfaces now. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569356 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
4dcae270e6
commit
8916830e8a
|
@ -3,7 +3,9 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||||
|
|
||||||
|
@ -22,13 +24,11 @@ public class LeastSquaresBuilder {
|
||||||
/** convergence checker */
|
/** convergence checker */
|
||||||
private ConvergenceChecker<Evaluation> checker;
|
private ConvergenceChecker<Evaluation> checker;
|
||||||
/** model function */
|
/** model function */
|
||||||
private MultivariateVectorFunction model;
|
private MultivariateJacobianFunction model;
|
||||||
/** Jacobian function */
|
|
||||||
private MultivariateMatrixFunction jacobian;
|
|
||||||
/** observed values */
|
/** observed values */
|
||||||
private double[] target;
|
private RealVector target;
|
||||||
/** initial guess */
|
/** initial guess */
|
||||||
private double[] start;
|
private RealVector start;
|
||||||
/** weight matrix */
|
/** weight matrix */
|
||||||
private RealMatrix weight;
|
private RealMatrix weight;
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ public class LeastSquaresBuilder {
|
||||||
* @return a new {@link LeastSquaresProblem}.
|
* @return a new {@link LeastSquaresProblem}.
|
||||||
*/
|
*/
|
||||||
public LeastSquaresProblem build() {
|
public LeastSquaresProblem build() {
|
||||||
return LeastSquaresFactory.create(model, jacobian, target, start, weight, checker, maxEvaluations, maxIterations);
|
return LeastSquaresFactory.create(model, target, start, weight, checker, maxEvaluations, maxIterations);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -90,22 +90,34 @@ public class LeastSquaresBuilder {
|
||||||
/**
|
/**
|
||||||
* Configure the model function.
|
* Configure the model function.
|
||||||
*
|
*
|
||||||
* @param model the model function
|
* @param value the model function value
|
||||||
|
* @param jacobian the Jacobian of {@code value}
|
||||||
* @return this
|
* @return this
|
||||||
*/
|
*/
|
||||||
public LeastSquaresBuilder model(final MultivariateVectorFunction model) {
|
public LeastSquaresBuilder model(final MultivariateVectorFunction value,
|
||||||
|
final MultivariateMatrixFunction jacobian) {
|
||||||
|
return model(LeastSquaresFactory.model(value, jacobian));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure the model function.
|
||||||
|
*
|
||||||
|
* @param model the model function value and Jacobian
|
||||||
|
* @return this
|
||||||
|
*/
|
||||||
|
public LeastSquaresBuilder model(final MultivariateJacobianFunction model) {
|
||||||
this.model = model;
|
this.model = model;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Configure the Jacobian function.
|
* Configure the observed data.
|
||||||
*
|
*
|
||||||
* @param jacobian the Jacobian function
|
* @param target the observed data.
|
||||||
* @return this
|
* @return this
|
||||||
*/
|
*/
|
||||||
public LeastSquaresBuilder jacobian(final MultivariateMatrixFunction jacobian) {
|
public LeastSquaresBuilder target(final RealVector target) {
|
||||||
this.jacobian = jacobian;
|
this.target = target;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +128,17 @@ public class LeastSquaresBuilder {
|
||||||
* @return this
|
* @return this
|
||||||
*/
|
*/
|
||||||
public LeastSquaresBuilder target(final double[] target) {
|
public LeastSquaresBuilder target(final double[] target) {
|
||||||
this.target = target;
|
return target(new ArrayRealVector(target, false));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure the initial guess.
|
||||||
|
*
|
||||||
|
* @param start the initial guess.
|
||||||
|
* @return this
|
||||||
|
*/
|
||||||
|
public LeastSquaresBuilder start(final RealVector start) {
|
||||||
|
this.start = start;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,8 +149,7 @@ public class LeastSquaresBuilder {
|
||||||
* @return this
|
* @return this
|
||||||
*/
|
*/
|
||||||
public LeastSquaresBuilder start(final double[] start) {
|
public LeastSquaresBuilder start(final double[] start) {
|
||||||
this.start = start;
|
return start(new ArrayRealVector(start, false));
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -22,12 +22,15 @@ import org.apache.commons.math3.exception.ConvergenceException;
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
|
||||||
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
||||||
|
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import org.apache.commons.math3.linear.BlockRealMatrix;
|
import org.apache.commons.math3.linear.BlockRealMatrix;
|
||||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.linear.RealVector;
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
|
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
|
import org.apache.commons.math3.util.Pair;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -59,8 +62,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
final double[] weights = new double[c.getN()];
|
final double[] weights = new double[c.getN()];
|
||||||
Arrays.fill(weights, 1.0);
|
Arrays.fill(weights, 1.0);
|
||||||
return base()
|
return base()
|
||||||
.model(c.getModelFunction())
|
.model(c.getModelFunction(), c.getModelFunctionJacobian())
|
||||||
.jacobian(c.getModelFunctionJacobian())
|
|
||||||
.target(new double[c.getN()])
|
.target(new double[c.getN()])
|
||||||
.weight(new DiagonalMatrix(weights));
|
.weight(new DiagonalMatrix(weights));
|
||||||
}
|
}
|
||||||
|
@ -71,8 +73,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
final double[] weights = new double[dataset.getNumObservations()];
|
final double[] weights = new double[dataset.getNumObservations()];
|
||||||
Arrays.fill(weights, 1.0);
|
Arrays.fill(weights, 1.0);
|
||||||
return base()
|
return base()
|
||||||
.model(problem.getModelFunction())
|
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||||
.jacobian(problem.getModelFunctionJacobian())
|
|
||||||
.target(dataset.getData()[1])
|
.target(dataset.getData()[1])
|
||||||
.weight(new DiagonalMatrix(weights))
|
.weight(new DiagonalMatrix(weights))
|
||||||
.start(dataset.getStartingPoint(0));
|
.start(dataset.getStartingPoint(0));
|
||||||
|
@ -133,24 +134,22 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
.target(new double[]{1})
|
.target(new double[]{1})
|
||||||
.weight(new DiagonalMatrix(new double[]{1}))
|
.weight(new DiagonalMatrix(new double[]{1}))
|
||||||
.start(new double[]{3})
|
.start(new double[]{3})
|
||||||
.model(
|
.model(new MultivariateJacobianFunction() {
|
||||||
new MultivariateVectorFunction() {
|
public Pair<RealVector, RealMatrix> value(final RealVector point) {
|
||||||
public double[] value(double[] point) {
|
return new Pair<RealVector, RealMatrix>(
|
||||||
return new double[]{
|
new ArrayRealVector(
|
||||||
FastMath.pow(point[0], 4)
|
new double[]{
|
||||||
};
|
FastMath.pow(point.getEntry(0), 4)
|
||||||
}
|
},
|
||||||
}
|
false),
|
||||||
)
|
new Array2DRowRealMatrix(
|
||||||
.jacobian(
|
new double[][]{
|
||||||
new MultivariateMatrixFunction() {
|
{0.25 * FastMath.pow(point.getEntry(0), 3)}
|
||||||
public double[][] value(double[] point) {
|
},
|
||||||
return new double[][]{
|
false)
|
||||||
{0.25 * FastMath.pow(point[0], 3)}
|
);
|
||||||
};
|
}
|
||||||
}
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(lsp);
|
Optimum optimum = optimizer.optimize(lsp);
|
||||||
|
@ -554,8 +553,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
final double[] weights = new double[target.length];
|
final double[] weights = new double[target.length];
|
||||||
Arrays.fill(weights, 1.0);
|
Arrays.fill(weights, 1.0);
|
||||||
return base()
|
return base()
|
||||||
.model(getModelFunction())
|
.model(getModelFunction(), getModelFunctionJacobian())
|
||||||
.jacobian(getModelFunctionJacobian())
|
|
||||||
.target(target)
|
.target(target)
|
||||||
.weight(new DiagonalMatrix(weights))
|
.weight(new DiagonalMatrix(weights))
|
||||||
.start(new double[factors.getColumnDimension()]);
|
.start(new double[factors.getColumnDimension()]);
|
||||||
|
|
|
@ -44,8 +44,7 @@ public class EvaluationTest {
|
||||||
Arrays.fill(weights, 1d);
|
Arrays.fill(weights, 1d);
|
||||||
|
|
||||||
return new LeastSquaresBuilder()
|
return new LeastSquaresBuilder()
|
||||||
.model(problem.getModelFunction())
|
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||||
.jacobian(problem.getModelFunctionJacobian())
|
|
||||||
.target(observed)
|
.target(observed)
|
||||||
.weight(new DiagonalMatrix(weights))
|
.weight(new DiagonalMatrix(weights))
|
||||||
.start(start);
|
.start(start);
|
||||||
|
|
|
@ -288,8 +288,7 @@ public class EvaluationTestValidation {
|
||||||
|
|
||||||
LeastSquaresBuilder builder(StraightLineProblem problem){
|
LeastSquaresBuilder builder(StraightLineProblem problem){
|
||||||
return new LeastSquaresBuilder()
|
return new LeastSquaresBuilder()
|
||||||
.model(problem.getModelFunction())
|
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||||
.jacobian(problem.getModelFunctionJacobian())
|
|
||||||
.target(problem.target())
|
.target(problem.target())
|
||||||
.weight(new DiagonalMatrix(problem.weight()))
|
.weight(new DiagonalMatrix(problem.weight()))
|
||||||
//unused start point to avoid NPE
|
//unused start point to avoid NPE
|
||||||
|
|
|
@ -49,14 +49,12 @@ public class LevenbergMarquardtOptimizerTest
|
||||||
|
|
||||||
public LeastSquaresBuilder builder(BevingtonProblem problem){
|
public LeastSquaresBuilder builder(BevingtonProblem problem){
|
||||||
return base()
|
return base()
|
||||||
.model(problem.getModelFunction())
|
.model(problem.getModelFunction(), problem.getModelFunctionJacobian());
|
||||||
.jacobian(problem.getModelFunctionJacobian());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public LeastSquaresBuilder builder(CircleProblem problem){
|
public LeastSquaresBuilder builder(CircleProblem problem){
|
||||||
return base()
|
return base()
|
||||||
.model(problem.getModelFunction())
|
.model(problem.getModelFunction(), problem.getModelFunctionJacobian())
|
||||||
.jacobian(problem.getModelFunctionJacobian())
|
|
||||||
.target(problem.target())
|
.target(problem.target())
|
||||||
.weight(new DiagonalMatrix(problem.weight()));
|
.weight(new DiagonalMatrix(problem.weight()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -507,8 +507,7 @@ public class MinpackTest {
|
||||||
LeastSquaresProblem problem = new LeastSquaresBuilder()
|
LeastSquaresProblem problem = new LeastSquaresBuilder()
|
||||||
.maxEvaluations(400 * (function.getN() + 1))
|
.maxEvaluations(400 * (function.getN() + 1))
|
||||||
.maxIterations(2000)
|
.maxIterations(2000)
|
||||||
.model(function.getModelFunction())
|
.model(function.getModelFunction(), function.getModelFunctionJacobian())
|
||||||
.jacobian(function.getModelFunctionJacobian())
|
|
||||||
.target(function.getTarget())
|
.target(function.getTarget())
|
||||||
.weight(new DiagonalMatrix(function.getWeight()))
|
.weight(new DiagonalMatrix(function.getWeight()))
|
||||||
.start(function.getStartPoint())
|
.start(function.getStartPoint())
|
||||||
|
|
Loading…
Reference in New Issue