diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java index 647bc4d2f..42545ac68 100644 --- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java +++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java @@ -3,7 +3,9 @@ package org.apache.commons.math3.fitting.leastsquares; import org.apache.commons.math3.analysis.MultivariateMatrixFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; +import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.PointVectorValuePair; @@ -22,13 +24,11 @@ public class LeastSquaresBuilder { /** convergence checker */ private ConvergenceChecker checker; /** model function */ - private MultivariateVectorFunction model; - /** Jacobian function */ - private MultivariateMatrixFunction jacobian; + private MultivariateJacobianFunction model; /** observed values */ - private double[] target; + private RealVector target; /** initial guess */ - private double[] start; + private RealVector start; /** weight matrix */ private RealMatrix weight; @@ -39,7 +39,7 @@ public class LeastSquaresBuilder { * @return a new {@link LeastSquaresProblem}. */ public LeastSquaresProblem build() { - return LeastSquaresFactory.create(model, jacobian, target, start, weight, checker, maxEvaluations, maxIterations); + return LeastSquaresFactory.create(model, target, start, weight, checker, maxEvaluations, maxIterations); } /** @@ -90,22 +90,34 @@ public class LeastSquaresBuilder { /** * Configure the model function. * - * @param model the model function + * @param value the model function value + * @param jacobian the Jacobian of {@code value} * @return this */ - public LeastSquaresBuilder model(final MultivariateVectorFunction model) { + public LeastSquaresBuilder model(final MultivariateVectorFunction value, + final MultivariateMatrixFunction jacobian) { + return model(LeastSquaresFactory.model(value, jacobian)); + } + + /** + * Configure the model function. + * + * @param model the model function value and Jacobian + * @return this + */ + public LeastSquaresBuilder model(final MultivariateJacobianFunction model) { this.model = model; return this; } /** - * Configure the Jacobian function. + * Configure the observed data. * - * @param jacobian the Jacobian function + * @param target the observed data. * @return this */ - public LeastSquaresBuilder jacobian(final MultivariateMatrixFunction jacobian) { - this.jacobian = jacobian; + public LeastSquaresBuilder target(final RealVector target) { + this.target = target; return this; } @@ -116,7 +128,17 @@ public class LeastSquaresBuilder { * @return this */ public LeastSquaresBuilder target(final double[] target) { - this.target = target; + return target(new ArrayRealVector(target, false)); + } + + /** + * Configure the initial guess. + * + * @param start the initial guess. + * @return this + */ + public LeastSquaresBuilder start(final RealVector start) { + this.start = start; return this; } @@ -127,8 +149,7 @@ public class LeastSquaresBuilder { * @return this */ public LeastSquaresBuilder start(final double[] start) { - this.start = start; - return this; + return start(new ArrayRealVector(start, false)); } /** diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java index 55d4d3a73..7baf5e94a 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java @@ -22,12 +22,15 @@ import org.apache.commons.math3.exception.ConvergenceException; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum; import org.apache.commons.math3.geometry.euclidean.twod.Vector2D; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.BlockRealMatrix; import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.optim.SimpleVectorValueChecker; import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Pair; import org.junit.Assert; import java.io.IOException; @@ -59,8 +62,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest { final double[] weights = new double[c.getN()]; Arrays.fill(weights, 1.0); return base() - .model(c.getModelFunction()) - .jacobian(c.getModelFunctionJacobian()) + .model(c.getModelFunction(), c.getModelFunctionJacobian()) .target(new double[c.getN()]) .weight(new DiagonalMatrix(weights)); } @@ -71,8 +73,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest { final double[] weights = new double[dataset.getNumObservations()]; Arrays.fill(weights, 1.0); return base() - .model(problem.getModelFunction()) - .jacobian(problem.getModelFunctionJacobian()) + .model(problem.getModelFunction(), problem.getModelFunctionJacobian()) .target(dataset.getData()[1]) .weight(new DiagonalMatrix(weights)) .start(dataset.getStartingPoint(0)); @@ -133,24 +134,22 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest { .target(new double[]{1}) .weight(new DiagonalMatrix(new double[]{1})) .start(new double[]{3}) - .model( - new MultivariateVectorFunction() { - public double[] value(double[] point) { - return new double[]{ - FastMath.pow(point[0], 4) - }; - } - } - ) - .jacobian( - new MultivariateMatrixFunction() { - public double[][] value(double[] point) { - return new double[][]{ - {0.25 * FastMath.pow(point[0], 3)} - }; - } - } - ) + .model(new MultivariateJacobianFunction() { + public Pair value(final RealVector point) { + return new Pair( + new ArrayRealVector( + new double[]{ + FastMath.pow(point.getEntry(0), 4) + }, + false), + new Array2DRowRealMatrix( + new double[][]{ + {0.25 * FastMath.pow(point.getEntry(0), 3)} + }, + false) + ); + } + }) .build(); Optimum optimum = optimizer.optimize(lsp); @@ -554,8 +553,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest { final double[] weights = new double[target.length]; Arrays.fill(weights, 1.0); return base() - .model(getModelFunction()) - .jacobian(getModelFunctionJacobian()) + .model(getModelFunction(), getModelFunctionJacobian()) .target(target) .weight(new DiagonalMatrix(weights)) .start(new double[factors.getColumnDimension()]); diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java index f0150283a..d59243455 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java @@ -44,8 +44,7 @@ public class EvaluationTest { Arrays.fill(weights, 1d); return new LeastSquaresBuilder() - .model(problem.getModelFunction()) - .jacobian(problem.getModelFunctionJacobian()) + .model(problem.getModelFunction(), problem.getModelFunctionJacobian()) .target(observed) .weight(new DiagonalMatrix(weights)) .start(start); diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTestValidation.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTestValidation.java index 181e7a1b6..6779c2340 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTestValidation.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTestValidation.java @@ -288,8 +288,7 @@ public class EvaluationTestValidation { LeastSquaresBuilder builder(StraightLineProblem problem){ return new LeastSquaresBuilder() - .model(problem.getModelFunction()) - .jacobian(problem.getModelFunctionJacobian()) + .model(problem.getModelFunction(), problem.getModelFunctionJacobian()) .target(problem.target()) .weight(new DiagonalMatrix(problem.weight())) //unused start point to avoid NPE diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java index f07fb6759..10d345761 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java @@ -49,14 +49,12 @@ public class LevenbergMarquardtOptimizerTest public LeastSquaresBuilder builder(BevingtonProblem problem){ return base() - .model(problem.getModelFunction()) - .jacobian(problem.getModelFunctionJacobian()); + .model(problem.getModelFunction(), problem.getModelFunctionJacobian()); } public LeastSquaresBuilder builder(CircleProblem problem){ return base() - .model(problem.getModelFunction()) - .jacobian(problem.getModelFunctionJacobian()) + .model(problem.getModelFunction(), problem.getModelFunctionJacobian()) .target(problem.target()) .weight(new DiagonalMatrix(problem.weight())); } diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/MinpackTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/MinpackTest.java index 2ef8b5f24..c99925844 100644 --- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/MinpackTest.java +++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/MinpackTest.java @@ -507,8 +507,7 @@ public class MinpackTest { LeastSquaresProblem problem = new LeastSquaresBuilder() .maxEvaluations(400 * (function.getN() + 1)) .maxIterations(2000) - .model(function.getModelFunction()) - .jacobian(function.getModelFunctionJacobian()) + .model(function.getModelFunction(), function.getModelFunctionJacobian()) .target(function.getTarget()) .weight(new DiagonalMatrix(function.getWeight())) .start(function.getStartPoint())