Use Real{Vector,Matrix} in LeastSquares interfaces

Covered all of the interfaces in the leastsquares package to use RealVector and
RealMatrix instead of double[] and double[][]. This reduced some duplicated code.
For example Evaluation.computeResiduals() was a complete duplication of
RealVector.subtract(). It also presents a consistent interface and allows data
encapsulation.

Lastly, this change enables [math] to "eat our own dog food." It enables the
linear package to be used in the implementation of the optimization algorithms.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569354 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2014-02-18 14:32:54 +00:00
parent a7a380f934
commit 0079828734
15 changed files with 202 additions and 171 deletions

View File

@ -5,6 +5,7 @@ import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
/** /**
@ -31,7 +32,7 @@ abstract class AbstractEvaluation implements Evaluation {
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[][] computeCovariances(double threshold) { public RealMatrix computeCovariances(double threshold) {
// Set up the Jacobian. // Set up the Jacobian.
final RealMatrix j = this.computeJacobian(); final RealMatrix j = this.computeJacobian();
@ -41,16 +42,16 @@ abstract class AbstractEvaluation implements Evaluation {
// Compute the covariances matrix. // Compute the covariances matrix.
final DecompositionSolver solver final DecompositionSolver solver
= new QRDecomposition(jTj, threshold).getSolver(); = new QRDecomposition(jTj, threshold).getSolver();
return solver.getInverse().getData(); return solver.getInverse();
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] computeSigma(double covarianceSingularityThreshold) { public RealVector computeSigma(double covarianceSingularityThreshold) {
final double[][] cov = this.computeCovariances(covarianceSingularityThreshold); final RealMatrix cov = this.computeCovariances(covarianceSingularityThreshold);
final int nC = cov.length; final int nC = cov.getColumnDimension();
final double[] sig = new double[nC]; final RealVector sig = new ArrayRealVector(nC);
for (int i = 0; i < nC; ++i) { for (int i = 0; i < nC; ++i) {
sig[i] = FastMath.sqrt(cov[i][i]); sig.setEntry(i, FastMath.sqrt(cov.getEntry(i,i)));
} }
return sig; return sig;
} }

View File

@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/** /**
* Applies a dense weight matrix to an evaluation. * Applies a dense weight matrix to an evaluation.
@ -37,19 +38,19 @@ class DenseWeightedEvaluation extends AbstractEvaluation {
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] computeResiduals() { public RealVector computeResiduals() {
return this.weightSqrt.operate(this.unweighted.computeResiduals()); return this.weightSqrt.operate(this.unweighted.computeResiduals());
} }
/* delegate */ /* delegate */
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] getPoint() { public RealVector getPoint() {
return unweighted.getPoint(); return unweighted.getPoint();
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] computeValue() { public RealVector computeValue() {
return unweighted.computeValue(); return unweighted.computeValue();
} }
} }

View File

@ -26,6 +26,7 @@ import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException; import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.util.Incrementor; import org.apache.commons.math3.util.Incrementor;
@ -133,7 +134,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
final int nR = lsp.getObservationSize(); // Number of observed data. final int nR = lsp.getObservationSize(); // Number of observed data.
final int nC = lsp.getParameterSize(); final int nC = lsp.getParameterSize();
final double[] currentPoint = lsp.getStart(); final RealVector currentPoint = lsp.getStart();
// iterate until convergence is reached // iterate until convergence is reached
Evaluation current = null; Evaluation current = null;
@ -145,7 +146,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
// Value of the objective function at "currentPoint". // Value of the objective function at "currentPoint".
evaluationCounter.incrementCount(); evaluationCounter.incrementCount();
current = lsp.evaluate(currentPoint); current = lsp.evaluate(currentPoint);
final double[] currentResiduals = current.computeResiduals(); final RealVector currentResiduals = current.computeResiduals();
final RealMatrix weightedJacobian = current.computeJacobian(); final RealMatrix weightedJacobian = current.computeJacobian();
// Check convergence. // Check convergence.
@ -164,7 +165,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
for (int i = 0; i < nR; ++i) { for (int i = 0; i < nR; ++i) {
final double[] grad = weightedJacobian.getRow(i); final double[] grad = weightedJacobian.getRow(i);
final double residual = currentResiduals[i]; final double residual = currentResiduals.getEntry(i);
// compute the normal equation // compute the normal equation
//residual is already weighted //residual is already weighted
@ -186,10 +187,10 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
// solve the linearized least squares problem // solve the linearized least squares problem
RealMatrix mA = new BlockRealMatrix(a); RealMatrix mA = new BlockRealMatrix(a);
DecompositionSolver solver = this.decomposition.getSolver(mA); DecompositionSolver solver = this.decomposition.getSolver(mA);
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray(); final RealVector dX = solver.solve(new ArrayRealVector(b, false));
// update the estimated parameters // update the estimated parameters
for (int i = 0; i < nC; ++i) { for (int i = 0; i < nC; ++i) {
currentPoint[i] += dX[i]; currentPoint.setEntry(i, currentPoint.getEntry(i) + dX.getEntry(i));
} }
} catch (SingularMatrixException e) { } catch (SingularMatrixException e) {
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM); throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);

View File

@ -1,5 +1,6 @@
package org.apache.commons.math3.fitting.leastsquares; package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.util.Incrementor; import org.apache.commons.math3.util.Incrementor;
@ -23,7 +24,7 @@ public class LeastSquaresAdapter implements LeastSquaresProblem {
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] getStart() { public RealVector getStart() {
return problem.getStart(); return problem.getStart();
} }
@ -37,8 +38,9 @@ public class LeastSquaresAdapter implements LeastSquaresProblem {
return problem.getParameterSize(); return problem.getParameterSize();
} }
/** {@inheritDoc} */ /** {@inheritDoc}
public Evaluation evaluate(final double[] point) { * @param point*/
public Evaluation evaluate(final RealVector point) {
return problem.evaluate(point); return problem.evaluate(point);
} }

View File

@ -39,8 +39,8 @@ public class LeastSquaresFactory {
* @return the specified General Least Squares problem. * @return the specified General Least Squares problem.
*/ */
public static LeastSquaresProblem create(final MultivariateJacobianFunction model, public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
final double[] observed, final RealVector observed,
final double[] start, final RealVector start,
final ConvergenceChecker<Evaluation> checker, final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations, final int maxEvaluations,
final int maxIterations) { final int maxIterations) {
@ -54,6 +54,34 @@ public class LeastSquaresFactory {
); );
} }
/**
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
* from the given elements. There will be no weights applied (Identity weights).
*
* @param model the model function. Produces the computed values.
* @param observed the observed (target) values
* @param start the initial guess.
* @param checker convergence checker
* @param maxEvaluations the maximum number of times to evaluate the model
* @param maxIterations the maximum number to times to iterate in the algorithm
* @return the specified General Least Squares problem.
*/
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
final double[] observed,
final double[] start,
final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations,
final int maxIterations) {
return create(
model,
new ArrayRealVector(observed, false),
new ArrayRealVector(start, false),
checker,
maxEvaluations,
maxIterations
);
}
/** /**
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem} * Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
* from the given elements. There will be no weights applied (Identity weights). * from the given elements. There will be no weights applied (Identity weights).
@ -132,7 +160,7 @@ public class LeastSquaresFactory {
final RealMatrix weightSquareRoot = squareRoot(weights); final RealMatrix weightSquareRoot = squareRoot(weights);
return new LeastSquaresAdapter(problem) { return new LeastSquaresAdapter(problem) {
@Override @Override
public Evaluation evaluate(final double[] point) { public Evaluation evaluate(final RealVector point) {
return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot); return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot);
} }
}; };
@ -154,7 +182,7 @@ public class LeastSquaresFactory {
/** /**
* Count the evaluations of a particular problem. The {@code counter} will be * Count the evaluations of a particular problem. The {@code counter} will be
* incremented every time {@link LeastSquaresProblem#evaluate(double[])} is called on * incremented every time {@link LeastSquaresProblem#evaluate(RealVector)} is called on
* the <em>returned</em> problem. * the <em>returned</em> problem.
* *
* @param problem the problem to track. * @param problem the problem to track.
@ -165,7 +193,7 @@ public class LeastSquaresFactory {
final Incrementor counter) { final Incrementor counter) {
return new LeastSquaresAdapter(problem) { return new LeastSquaresAdapter(problem) {
public Evaluation evaluate(final double[] point) { public Evaluation evaluate(final RealVector point) {
counter.incrementCount(); counter.incrementCount();
return super.evaluate(point); return super.evaluate(point);
} }
@ -192,12 +220,12 @@ public class LeastSquaresFactory {
return checker.converged( return checker.converged(
iteration, iteration,
new PointVectorValuePair( new PointVectorValuePair(
previous.getPoint(), previous.getPoint().toArray(),
previous.computeValue(), previous.computeValue().toArray(),
false), false),
new PointVectorValuePair( new PointVectorValuePair(
current.getPoint(), current.getPoint().toArray(),
current.computeValue(), current.computeValue().toArray(),
false) false)
); );
} }
@ -237,11 +265,13 @@ public class LeastSquaresFactory {
final MultivariateMatrixFunction jacobian final MultivariateMatrixFunction jacobian
) { ) {
return new MultivariateJacobianFunction() { return new MultivariateJacobianFunction() {
public Pair<RealVector, RealMatrix> value(final double[] point) { public Pair<RealVector, RealMatrix> value(final RealVector point) {
//evaluate and use Real* interfaces without copying //TODO get array from RealVector without copying?
final double[] pointArray = point.toArray();
//evaluate and return data without copying
return new Pair<RealVector, RealMatrix>( return new Pair<RealVector, RealMatrix>(
new ArrayRealVector(value.value(point), false), new ArrayRealVector(value.value(pointArray), false),
new Array2DRowRealMatrix(jacobian.value(point), false)); new Array2DRowRealMatrix(jacobian.value(pointArray), false));
} }
}; };
} }

View File

@ -3,6 +3,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/** /**
* The data necessary to define a non-linear least squares problem. Includes the observed * The data necessary to define a non-linear least squares problem. Includes the observed
@ -19,7 +20,7 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* *
* @return the initial guess values. * @return the initial guess values.
*/ */
double[] getStart(); RealVector getStart();
/** /**
* Get the number of observations (rows in the Jacobian) in this problem. * Get the number of observations (rows in the Jacobian) in this problem.
@ -38,13 +39,14 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
/** /**
* Evaluate the model at the specified point. * Evaluate the model at the specified point.
* *
*
* @param point the parameter values. * @param point the parameter values.
* @return the model's value and derivative at the given point. * @return the model's value and derivative at the given point.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the maximal number of evaluations (of the model vector function) is * if the maximal number of evaluations (of the model vector function) is
* exceeded. * exceeded.
*/ */
Evaluation evaluate(double[] point); Evaluation evaluate(RealVector point);
/** /**
* An evaluation of a {@link LeastSquaresProblem} at a particular point. This class * An evaluation of a {@link LeastSquaresProblem} at a particular point. This class
@ -59,12 +61,13 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* way for the caller to specify that the result of this computation should be * way for the caller to specify that the result of this computation should be
* considered meaningless, and thus trigger an exception. * considered meaningless, and thus trigger an exception.
* *
*
* @param threshold Singularity threshold. * @param threshold Singularity threshold.
* @return the covariance matrix. * @return the covariance matrix.
* @throws org.apache.commons.math3.linear.SingularMatrixException * @throws org.apache.commons.math3.linear.SingularMatrixException
* if the covariance matrix cannot be computed (singular problem). * if the covariance matrix cannot be computed (singular problem).
*/ */
double[][] computeCovariances(double threshold); RealMatrix computeCovariances(double threshold);
/** /**
* Computes an estimate of the standard deviation of the parameters. The returned * Computes an estimate of the standard deviation of the parameters. The returned
@ -72,13 +75,14 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* matrix, {@code sd(a[i]) ~= sqrt(C[i][i])}, where {@code a[i]} is the optimized * matrix, {@code sd(a[i]) ~= sqrt(C[i][i])}, where {@code a[i]} is the optimized
* value of the {@code i}-th parameter, and {@code C} is the covariance matrix. * value of the {@code i}-th parameter, and {@code C} is the covariance matrix.
* *
*
* @param covarianceSingularityThreshold Singularity threshold (see {@link * @param covarianceSingularityThreshold Singularity threshold (see {@link
* #computeCovariances(double) computeCovariances}). * #computeCovariances(double) computeCovariances}).
* @return an estimate of the standard deviation of the optimized parameters * @return an estimate of the standard deviation of the optimized parameters
* @throws org.apache.commons.math3.linear.SingularMatrixException * @throws org.apache.commons.math3.linear.SingularMatrixException
* if the covariance matrix cannot be computed. * if the covariance matrix cannot be computed.
*/ */
double[] computeSigma(double covarianceSingularityThreshold); RealVector computeSigma(double covarianceSingularityThreshold);
/** /**
* Computes the normalized cost. It is the square-root of the sum of squared of * Computes the normalized cost. It is the square-root of the sum of squared of
@ -93,7 +97,7 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* *
* @return the objective function value at the specified point. * @return the objective function value at the specified point.
*/ */
double[] computeValue(); RealVector computeValue();
/** /**
* Computes the weighted Jacobian matrix. * Computes the weighted Jacobian matrix.
@ -121,13 +125,13 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* @return the weighted residuals: W<sup>1/2</sup> K. * @return the weighted residuals: W<sup>1/2</sup> K.
* @throws DimensionMismatchException if the residuals have the wrong length. * @throws DimensionMismatchException if the residuals have the wrong length.
*/ */
double[] computeResiduals(); RealVector computeResiduals();
/** /**
* Get the abscissa (independent variables) of this evaluation. * Get the abscissa (independent variables) of this evaluation.
* *
* @return the point provided to {@link #evaluate(double[])}. * @return the point provided to {@link #evaluate(RealVector)}.
*/ */
double[] getPoint(); RealVector getPoint();
} }
} }

View File

@ -16,7 +16,6 @@
*/ */
package org.apache.commons.math3.fitting.leastsquares; package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.RealVector;
@ -36,15 +35,15 @@ class LeastSquaresProblemImpl
implements LeastSquaresProblem { implements LeastSquaresProblem {
/** Target values for the model function at optimum. */ /** Target values for the model function at optimum. */
private double[] target; private RealVector target;
/** Model function. */ /** Model function. */
private MultivariateJacobianFunction model; private MultivariateJacobianFunction model;
/** Initial guess. */ /** Initial guess. */
private double[] start; private RealVector start;
LeastSquaresProblemImpl(final MultivariateJacobianFunction model, LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
final double[] target, final RealVector target,
final double[] start, final RealVector start,
final ConvergenceChecker<Evaluation> checker, final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations, final int maxEvaluations,
final int maxIterations) { final int maxIterations) {
@ -55,18 +54,18 @@ class LeastSquaresProblemImpl
} }
public int getObservationSize() { public int getObservationSize() {
return target.length; return target.getDimension();
} }
public int getParameterSize() { public int getParameterSize() {
return start.length; return start.getDimension();
} }
public double[] getStart() { public RealVector getStart() {
return start == null ? null : start.clone(); return start == null ? null : start.copy();
} }
public Evaluation evaluate(final double[] point) { public Evaluation evaluate(final RealVector point) {
//evaluate value and jacobian in one function call //evaluate value and jacobian in one function call
final Pair<RealVector, RealMatrix> value = this.model.value(point); final Pair<RealVector, RealMatrix> value = this.model.value(point);
return new UnweightedEvaluation( return new UnweightedEvaluation(
@ -84,19 +83,19 @@ class LeastSquaresProblemImpl
private static class UnweightedEvaluation extends AbstractEvaluation { private static class UnweightedEvaluation extends AbstractEvaluation {
/** the point of evaluation */ /** the point of evaluation */
private final double[] point; private final RealVector point;
/** value at point */ /** value at point */
private final RealVector values; private final RealVector values;
/** deriviative at point */ /** deriviative at point */
private final RealMatrix jacobian; private final RealMatrix jacobian;
/** reference to the observed values */ /** reference to the observed values */
private final double[] target; private final RealVector target;
private UnweightedEvaluation(final RealVector values, private UnweightedEvaluation(final RealVector values,
final RealMatrix jacobian, final RealMatrix jacobian,
final double[] target, final RealVector target,
final double[] point) { final RealVector point) {
super(target.length); super(target.getDimension());
this.values = values; this.values = values;
this.jacobian = jacobian; this.jacobian = jacobian;
this.target = target; this.target = target;
@ -104,31 +103,20 @@ class LeastSquaresProblemImpl
} }
public double[] computeValue() { public RealVector computeValue() {
return this.values.toArray(); return this.values;
} }
public RealMatrix computeJacobian() { public RealMatrix computeJacobian() {
return this.jacobian; return this.jacobian;
} }
public double[] getPoint() { public RealVector getPoint() {
return this.point; return this.point;
} }
public double[] computeResiduals() { public RealVector computeResiduals() {
final double[] objectiveValue = this.computeValue(); return target.subtract(this.computeValue());
if (objectiveValue.length != target.length) {
throw new DimensionMismatchException(target.length,
objectiveValue.length);
}
final double[] residuals = new double[target.length];
for (int i = 0; i < target.length; i++) {
residuals[i] = target[i] - objectiveValue[i];
}
return residuals;
} }
} }

View File

@ -19,6 +19,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import java.util.Arrays; import java.util.Arrays;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.exception.ConvergenceException; import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.exception.util.LocalizedFormats;
@ -297,7 +298,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
//pull in relevant data from the problem as locals //pull in relevant data from the problem as locals
final int nR = problem.getObservationSize(); // Number of observed data. final int nR = problem.getObservationSize(); // Number of observed data.
final int nC = problem.getParameterSize(); // Number of parameters. final int nC = problem.getParameterSize(); // Number of parameters.
final double[] currentPoint = problem.getStart(); final double[] currentPoint = problem.getStart().toArray();
//counters //counters
final Incrementor iterationCounter = problem.getIterationCounter(); final Incrementor iterationCounter = problem.getIterationCounter();
final Incrementor evaluationCounter = problem.getEvaluationCounter(); final Incrementor evaluationCounter = problem.getEvaluationCounter();
@ -327,8 +328,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
// Evaluate the function at the starting point and calculate its norm. // Evaluate the function at the starting point and calculate its norm.
evaluationCounter.incrementCount(); evaluationCounter.incrementCount();
//value will be reassigned in the loop //value will be reassigned in the loop
Evaluation current = problem.evaluate(currentPoint); Evaluation current = problem.evaluate(new ArrayRealVector(currentPoint, false));
double[] currentResiduals = current.computeResiduals(); double[] currentResiduals = current.computeResiduals().toArray();
double currentCost = current.computeCost(); double currentCost = current.computeCost();
// Outer loop. // Outer loop.
@ -444,8 +445,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
// Evaluate the function at x + p and calculate its norm. // Evaluate the function at x + p and calculate its norm.
evaluationCounter.incrementCount(); evaluationCounter.incrementCount();
current = problem.evaluate(currentPoint); current = problem.evaluate(new ArrayRealVector(currentPoint,false));
currentResiduals = current.computeResiduals(); currentResiduals = current.computeResiduals().toArray();
currentCost = current.computeCost(); currentCost = current.computeCost();
// compute the scaled actual reduction // compute the scaled actual reduction

View File

@ -18,6 +18,6 @@ public interface MultivariateJacobianFunction {
* @param point the abscissae * @param point the abscissae
* @return the values and their Jacobian of this vector valued function. * @return the values and their Jacobian of this vector valued function.
*/ */
Pair<RealVector, RealMatrix> value(double[] point); Pair<RealVector, RealMatrix> value(RealVector point);
} }

View File

@ -3,6 +3,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/** /**
* A pedantic implementation of {@link Optimum}. * A pedantic implementation of {@link Optimum}.
@ -44,12 +45,12 @@ class OptimumImpl implements Optimum {
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[][] computeCovariances(double threshold) { public RealMatrix computeCovariances(double threshold) {
return value.computeCovariances(threshold); return value.computeCovariances(threshold);
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] computeSigma(double covarianceSingularityThreshold) { public RealVector computeSigma(double covarianceSingularityThreshold) {
return value.computeSigma(covarianceSingularityThreshold); return value.computeSigma(covarianceSingularityThreshold);
} }
@ -59,7 +60,7 @@ class OptimumImpl implements Optimum {
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] computeValue() { public RealVector computeValue() {
return value.computeValue(); return value.computeValue();
} }
@ -74,12 +75,12 @@ class OptimumImpl implements Optimum {
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] computeResiduals() { public RealVector computeResiduals() {
return value.computeResiduals(); return value.computeResiduals();
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public double[] getPoint() { public RealVector getPoint() {
return value.getPoint(); return value.getPoint();
} }
} }

View File

@ -25,6 +25,7 @@ import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.BlockRealMatrix; import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.SimpleVectorValueChecker; import org.apache.commons.math3.optim.SimpleVectorValueChecker;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.junit.Assert; import org.junit.Assert;
@ -44,6 +45,9 @@ import java.util.Arrays;
*/ */
public abstract class AbstractLeastSquaresOptimizerAbstractTest { public abstract class AbstractLeastSquaresOptimizerAbstractTest {
/** default absolute tolerance of comparisons */
public static final double TOl = 1e-10;
public LeastSquaresBuilder base() { public LeastSquaresBuilder base() {
return new LeastSquaresBuilder() return new LeastSquaresBuilder()
.checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6)) .checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
@ -78,6 +82,19 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Assert.fail("Expected Exception from: " + optimizer.toString()); Assert.fail("Expected Exception from: " + optimizer.toString());
} }
/**
* Check the value of a vector.
* @param tolerance the absolute tolerance of comparisons
* @param actual the vector to test
* @param expected the expected values
*/
public void assertEquals(double tolerance, RealVector actual, double... expected){
for (int i = 0; i < expected.length; i++) {
Assert.assertEquals(expected[i], actual.getEntry(i), tolerance);
}
Assert.assertEquals(expected.length, actual.getDimension());
}
/** /**
* @return the default number of allowed iterations (which will be used when not * @return the default number of allowed iterations (which will be used when not
* specified otherwise). * specified otherwise).
@ -150,9 +167,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(ls); Optimum optimum = optimizer.optimize(ls);
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(1.5, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), 1.5);
Assert.assertEquals(3.0, optimum.computeValue()[0], 1e-10); Assert.assertEquals(3.0, optimum.computeValue().getEntry(0), TOl);
} }
public void testQRColumnsPermutation(LeastSquaresOptimizer optimizer) { public void testQRColumnsPermutation(LeastSquaresOptimizer optimizer) {
@ -162,12 +179,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build()); Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(7, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), 7, 3);
Assert.assertEquals(3, optimum.getPoint()[1], 1e-10); assertEquals(TOl, optimum.computeValue(), 4, 6, 1);
Assert.assertEquals(4, optimum.computeValue()[0], 1e-10);
Assert.assertEquals(6, optimum.computeValue()[1], 1e-10);
Assert.assertEquals(1, optimum.computeValue()[2], 1e-10);
} }
public void testNoDependency(LeastSquaresOptimizer optimizer) { public void testNoDependency(LeastSquaresOptimizer optimizer) {
@ -182,9 +196,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build()); Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
for (int i = 0; i < problem.target.length; ++i) { for (int i = 0; i < problem.target.length; ++i) {
Assert.assertEquals(0.55 * i, optimum.getPoint()[i], 1e-10); Assert.assertEquals(0.55 * i, optimum.getPoint().getEntry(i), TOl);
} }
} }
@ -197,10 +211,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build()); Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(1, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), 1, 2, 3);
Assert.assertEquals(2, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(3, optimum.getPoint()[2], 1e-10);
} }
public void testTwoSets(LeastSquaresOptimizer optimizer) { public void testTwoSets(LeastSquaresOptimizer optimizer) {
@ -216,13 +228,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build()); Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(3, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), 3, 4, -1, -2, 1 + epsilon, 1 - epsilon);
Assert.assertEquals(4, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(-1, optimum.getPoint()[2], 1e-10);
Assert.assertEquals(-2, optimum.getPoint()[3], 1e-10);
Assert.assertEquals(1 + epsilon, optimum.getPoint()[4], 1e-10);
Assert.assertEquals(1 - epsilon, optimum.getPoint()[5], 1e-10);
} }
public void testNonInvertible(LeastSquaresOptimizer optimizer) throws Exception { public void testNonInvertible(LeastSquaresOptimizer optimizer) throws Exception {
@ -253,11 +260,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer Optimum optimum = optimizer
.optimize(problem1.getBuilder().start(start).build()); .optimize(problem1.getBuilder().start(start).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(1, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), 1, 1, 1, 1);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[2], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[3], 1e-10);
LinearProblem problem2 = new LinearProblem(new double[][]{ LinearProblem problem2 = new LinearProblem(new double[][]{
{10.00, 7.00, 8.10, 7.20}, {10.00, 7.00, 8.10, 7.20},
@ -268,11 +272,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
optimum = optimizer.optimize(problem2.getBuilder().start(start).build()); optimum = optimizer.optimize(problem2.getBuilder().start(start).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(-81, optimum.getPoint()[0], 1e-8); assertEquals(1e-8, optimum.getPoint(), -81, 137, -34, 22);
Assert.assertEquals(137, optimum.getPoint()[1], 1e-8);
Assert.assertEquals(-34, optimum.getPoint()[2], 1e-8);
Assert.assertEquals(22, optimum.getPoint()[3], 1e-8);
} }
public void testMoreEstimatedParametersSimple(LeastSquaresOptimizer optimizer) { public void testMoreEstimatedParametersSimple(LeastSquaresOptimizer optimizer) {
@ -285,7 +286,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer Optimum optimum = optimizer
.optimize(problem.getBuilder().start(new double[]{7, 6, 5, 4}).build()); .optimize(problem.getBuilder().start(new double[]{7, 6, 5, 4}).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
} }
public void testMoreEstimatedParametersUnsorted(LeastSquaresOptimizer optimizer) { public void testMoreEstimatedParametersUnsorted(LeastSquaresOptimizer optimizer) {
@ -300,11 +301,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize( Optimum optimum = optimizer.optimize(
problem.getBuilder().start(new double[]{2, 2, 2, 2, 2, 2}).build()); problem.getBuilder().start(new double[]{2, 2, 2, 2, 2, 2}).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(3, optimum.getPoint()[2], 1e-10); //TODO the first two elements of point were not previously checked
Assert.assertEquals(4, optimum.getPoint()[3], 1e-10); assertEquals(TOl, optimum.getPoint(), 2, 1, 3, 4, 5, 6);
Assert.assertEquals(5, optimum.getPoint()[4], 1e-10);
Assert.assertEquals(6, optimum.getPoint()[5], 1e-10);
} }
public void testRedundantEquations(LeastSquaresOptimizer optimizer) { public void testRedundantEquations(LeastSquaresOptimizer optimizer) {
@ -317,9 +316,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer Optimum optimum = optimizer
.optimize(problem.getBuilder().start(new double[]{1, 1}).build()); .optimize(problem.getBuilder().start(new double[]{1, 1}).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(2, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), 2, 1);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
} }
public void testInconsistentEquations(LeastSquaresOptimizer optimizer) { public void testInconsistentEquations(LeastSquaresOptimizer optimizer) {
@ -346,9 +344,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
//TODO why is this part here? hasn't it been tested already? //TODO why is this part here? hasn't it been tested already?
Optimum optimum = optimizer.optimize(problem.getBuilder().build()); Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(-1, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), -1, 1);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
//TODO move to builder test //TODO move to builder test
optimizer.optimize( optimizer.optimize(
@ -368,9 +365,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build()); Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10); Assert.assertEquals(0, optimum.computeRMS(), TOl);
Assert.assertEquals(-1, optimum.getPoint()[0], 1e-10); assertEquals(TOl, optimum.getPoint(), -1, 1);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
//TODO move to builder test //TODO move to builder test
optimizer.optimize( optimizer.optimize(
@ -400,14 +396,14 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Assert.assertTrue(optimum.getEvaluations() < 10); Assert.assertTrue(optimum.getEvaluations() < 10);
double rms = optimum.computeRMS(); double rms = optimum.computeRMS();
Assert.assertEquals(1.768262623567235, FastMath.sqrt(circle.getN()) * rms, 1e-10); Assert.assertEquals(1.768262623567235, FastMath.sqrt(circle.getN()) * rms, TOl);
Vector2D center = new Vector2D(optimum.getPoint()[0], optimum.getPoint()[1]); Vector2D center = new Vector2D(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
Assert.assertEquals(69.96016176931406, circle.getRadius(center), 1e-6); Assert.assertEquals(69.96016176931406, circle.getRadius(center), 1e-6);
Assert.assertEquals(96.07590211815305, center.getX(), 1e-6); Assert.assertEquals(96.07590211815305, center.getX(), 1e-6);
Assert.assertEquals(48.13516790438953, center.getY(), 1e-6); Assert.assertEquals(48.13516790438953, center.getY(), 1e-6);
double[][] cov = optimum.computeCovariances(1e-14); double[][] cov = optimum.computeCovariances(1e-14).getData();
Assert.assertEquals(1.839, cov[0][0], 0.001); Assert.assertEquals(1.839, cov[0][0], 0.001);
Assert.assertEquals(0.731, cov[0][1], 0.001); Assert.assertEquals(0.731, cov[0][1], 0.001);
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14); Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
@ -425,7 +421,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
optimum = optimizer.optimize( optimum = optimizer.optimize(
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build()); builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
cov = optimum.computeCovariances(1e-14); cov = optimum.computeCovariances(1e-14).getData();
Assert.assertEquals(0.0016, cov[0][0], 0.001); Assert.assertEquals(0.0016, cov[0][0], 0.001);
Assert.assertEquals(3.2e-7, cov[0][1], 1e-9); Assert.assertEquals(3.2e-7, cov[0][1], 1e-9);
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14); Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
@ -444,7 +440,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(builder(circle).weight(new DiagonalMatrix(weights)).start(start).build()); Optimum optimum = optimizer.optimize(builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
Vector2D center = new Vector2D(optimum.getPoint()[0], optimum.getPoint()[1]); Vector2D center = new Vector2D(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
Assert.assertTrue(optimum.getEvaluations() < 25); Assert.assertTrue(optimum.getEvaluations() < 25);
Assert.assertEquals(0.043, optimum.computeRMS(), 1e-3); Assert.assertEquals(0.043, optimum.computeRMS(), 1e-3);
Assert.assertEquals(0.292235, circle.getRadius(center), 1e-6); Assert.assertEquals(0.292235, circle.getRadius(center), 1e-6);
@ -465,8 +461,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize( Optimum optimum = optimizer.optimize(
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build()); builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
Assert.assertEquals(-0.1517383071957963, optimum.getPoint()[0], 1e-6); assertEquals(1e-6, optimum.getPoint(), -0.1517383071957963, 0.2074999736353867);
Assert.assertEquals(0.2074999736353867, optimum.getPoint()[1], 1e-6);
Assert.assertEquals(0.04268731682389561, optimum.computeRMS(), 1e-8); Assert.assertEquals(0.04268731682389561, optimum.computeRMS(), 1e-8);
} }
@ -509,12 +504,12 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final Optimum optimum = optimizer.optimize(builder(dataset).build()); final Optimum optimum = optimizer.optimize(builder(dataset).build());
final double[] actual = optimum.getPoint(); final RealVector actual = optimum.getPoint();
for (int i = 0; i < actual.length; i++) { for (int i = 0; i < actual.getDimension(); i++) {
double expected = dataset.getParameter(i); double expected = dataset.getParameter(i);
double delta = FastMath.abs(errParams * expected); double delta = FastMath.abs(errParams * expected);
Assert.assertEquals(dataset.getName() + ", param #" + i, Assert.assertEquals(dataset.getName() + ", param #" + i,
expected, actual[i], delta); expected, actual.getEntry(i), delta);
} }
} }

View File

@ -15,6 +15,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -87,10 +88,10 @@ public class EvaluationTest {
final Evaluation evaluation = lsp.evaluate(lsp.getStart()); final Evaluation evaluation = lsp.evaluate(lsp.getStart());
final double cost = evaluation.computeCost(); final double cost = evaluation.computeCost();
final double[] sig = evaluation.computeSigma(1e-14); final RealVector sig = evaluation.computeSigma(1e-14);
final int dof = lsp.getObservationSize() - lsp.getParameterSize(); final int dof = lsp.getObservationSize() - lsp.getParameterSize();
for (int i = 0; i < sig.length; i++) { for (int i = 0; i < sig.getDimension(); i++) {
final double actual = FastMath.sqrt(cost * cost / dof) * sig[i]; final double actual = FastMath.sqrt(cost * cost / dof) * sig.getEntry(i);
Assert.assertEquals(dataset.getName() + ", parameter #" + i, Assert.assertEquals(dataset.getName() + ", parameter #" + i,
expected[i], actual, 1e-6 * expected[i]); expected[i], actual, 1e-6 * expected[i]);
} }

View File

@ -13,17 +13,18 @@
*/ */
package org.apache.commons.math3.fitting.leastsquares; package org.apache.commons.math3.fitting.leastsquares;
import java.util.Arrays; import org.apache.commons.math3.linear.ArrayRealVector;
import java.util.List;
import java.util.ArrayList;
import java.awt.geom.Point2D;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.stat.descriptive.StatisticalSummary;
import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.descriptive.StatisticalSummary;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.junit.Test;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.List;
/** /**
* This class demonstrates the main functionality of the * This class demonstrates the main functionality of the
@ -95,7 +96,7 @@ public class EvaluationTestValidation {
sigmaEstimate[i] = new SummaryStatistics(); sigmaEstimate[i] = new SummaryStatistics();
} }
final double[] init = { slope, offset }; final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
// Monte-Carlo (generates many sets of observations). // Monte-Carlo (generates many sets of observations).
final int mcRepeat = MONTE_CARLO_RUNS; final int mcRepeat = MONTE_CARLO_RUNS;
@ -117,12 +118,12 @@ public class EvaluationTestValidation {
// covariance matrix). // covariance matrix).
final LeastSquaresProblem lsp = builder(problem).build(); final LeastSquaresProblem lsp = builder(problem).build();
final double[] sigma = lsp.evaluate(init).computeSigma(1e-14); final RealVector sigma = lsp.evaluate(init).computeSigma(1e-14);
// Accumulate statistics. // Accumulate statistics.
for (int i = 0; i < numParams; i++) { for (int i = 0; i < numParams; i++) {
paramsFoundByDirectSolution[i].addValue(regress[i]); paramsFoundByDirectSolution[i].addValue(regress[i]);
sigmaEstimate[i].addValue(sigma[i]); sigmaEstimate[i].addValue(sigma.getEntry(i));
} }
// Next Monte-Carlo. // Next Monte-Carlo.
@ -138,7 +139,7 @@ public class EvaluationTestValidation {
StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary(); StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary();
System.out.printf(" %+.6e %+.6e %+.6e\n", System.out.printf(" %+.6e %+.6e %+.6e\n",
init[i], init.getEntry(i),
s.getMean(), s.getMean(),
s.getStandardDeviation()); s.getStandardDeviation());
@ -212,7 +213,7 @@ public class EvaluationTestValidation {
} }
// Direct solution (using simple regression). // Direct solution (using simple regression).
final double[] regress = problem.solve(); final RealVector regress = new ArrayRealVector(problem.solve(), false);
// Dummy optimizer (to compute the chi-square). // Dummy optimizer (to compute the chi-square).
final LeastSquaresProblem lsp = builder(problem).build(); final LeastSquaresProblem lsp = builder(problem).build();
@ -221,7 +222,7 @@ public class EvaluationTestValidation {
// Get chi-square of the best parameters set for the given set of // Get chi-square of the best parameters set for the given set of
// observations. // observations.
final double bestChi2N = getChi2N(lsp, regress); final double bestChi2N = getChi2N(lsp, regress);
final double[] sigma = lsp.evaluate(regress).computeSigma(1e-14); final RealVector sigma = lsp.evaluate(regress).computeSigma(1e-14);
// Monte-Carlo (generates a grid of parameters). // Monte-Carlo (generates a grid of parameters).
final int mcRepeat = MONTE_CARLO_RUNS; final int mcRepeat = MONTE_CARLO_RUNS;
@ -233,8 +234,8 @@ public class EvaluationTestValidation {
// Index 2 = normalized chi2 // Index 2 = normalized chi2
final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize); final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize);
final double slopeRange = 10 * sigma[0]; final double slopeRange = 10 * sigma.getEntry(0);
final double offsetRange = 10 * sigma[1]; final double offsetRange = 10 * sigma.getEntry(1);
final double minSlope = slope - 0.5 * slopeRange; final double minSlope = slope - 0.5 * slopeRange;
final double minOffset = offset - 0.5 * offsetRange; final double minOffset = offset - 0.5 * offsetRange;
final double deltaSlope = slopeRange/ gridSize; final double deltaSlope = slopeRange/ gridSize;
@ -243,7 +244,8 @@ public class EvaluationTestValidation {
final double s = minSlope + i * deltaSlope; final double s = minSlope + i * deltaSlope;
for (int j = 0; j < gridSize; j++) { for (int j = 0; j < gridSize; j++) {
final double o = minOffset + j * deltaOffset; final double o = minOffset + j * deltaOffset;
final double chi2N = getChi2N(lsp, new double[] {s, o}); final double chi2N = getChi2N(lsp,
new ArrayRealVector(new double[] {s, o}, false));
paramsAndChi2.add(new double[] {s, o, chi2N}); paramsAndChi2.add(new double[] {s, o, chi2N});
} }
@ -260,7 +262,7 @@ public class EvaluationTestValidation {
final String lineFmt = "%+.10e %+.10e %.8e\n"; final String lineFmt = "%+.10e %+.10e %.8e\n";
// Point with smallest chi-square. // Point with smallest chi-square.
System.out.printf(lineFmt, regress[0], regress[1], bestChi2N); System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
System.out.println(); // Empty line. System.out.println(); // Empty line.
// Points within the confidence interval. // Points within the confidence interval.
@ -280,7 +282,7 @@ public class EvaluationTestValidation {
} }
System.out.println(); // Empty line. System.out.println(); // Empty line.
System.out.println("# sigma=" + Arrays.toString(sigma)); System.out.println("# sigma=" + sigma.toString());
System.out.println("# " + numLarger + " sets filtered out"); System.out.println("# " + numLarger + " sets filtered out");
} }
@ -289,15 +291,17 @@ public class EvaluationTestValidation {
.model(problem.getModelFunction()) .model(problem.getModelFunction())
.jacobian(problem.getModelFunctionJacobian()) .jacobian(problem.getModelFunctionJacobian())
.target(problem.target()) .target(problem.target())
.weight(new DiagonalMatrix(problem.weight())); .weight(new DiagonalMatrix(problem.weight()))
//unused start point to avoid NPE
.start(new double[2]);
} }
/** /**
* @return the normalized chi-square. * @return the normalized chi-square.
*/ */
private double getChi2N(LeastSquaresProblem lsp, private double getChi2N(LeastSquaresProblem lsp,
double[] params) { RealVector params) {
final double cost = lsp.evaluate(params).computeCost(); final double cost = lsp.evaluate(params).computeCost();
return cost * cost / (lsp.getObservationSize() - params.length); return cost * cost / (lsp.getObservationSize() - params.getDimension());
} }
} }

View File

@ -24,6 +24,8 @@ import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D; import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException; import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision; import org.apache.commons.math3.util.Precision;
@ -201,10 +203,10 @@ public class LevenbergMarquardtOptimizerTest
.build() .build()
); );
final double[] solution = optimum.getPoint(); final RealVector solution = optimum.getPoint();
final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 }; final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };
final double[][] covarMatrix = optimum.computeCovariances(1e-14); final RealMatrix covarMatrix = optimum.computeCovariances(1e-14);
final double[][] expectedCovarMatrix = { final double[][] expectedCovarMatrix = {
{ 3.38, -3.69, 27.98, -2.34, -49.24 }, { 3.38, -3.69, 27.98, -2.34, -49.24 },
{ -3.69, 2492.26, 81.89, -69.21, -8.9 }, { -3.69, 2492.26, 81.89, -69.21, -8.9 },
@ -218,7 +220,7 @@ public class LevenbergMarquardtOptimizerTest
// Check that the computed solution is within the reference error range. // Check that the computed solution is within the reference error range.
for (int i = 0; i < numParams; i++) { for (int i = 0; i < numParams; i++) {
final double error = FastMath.sqrt(expectedCovarMatrix[i][i]); final double error = FastMath.sqrt(expectedCovarMatrix[i][i]);
Assert.assertEquals("Parameter " + i, expectedSolution[i], solution[i], error); Assert.assertEquals("Parameter " + i, expectedSolution[i], solution.getEntry(i), error);
} }
// Check that each entry of the computed covariance matrix is within 10% // Check that each entry of the computed covariance matrix is within 10%
@ -227,7 +229,7 @@ public class LevenbergMarquardtOptimizerTest
for (int j = 0; j < numParams; j++) { for (int j = 0; j < numParams; j++) {
Assert.assertEquals("Covariance matrix [" + i + "][" + j + "]", Assert.assertEquals("Covariance matrix [" + i + "][" + j + "]",
expectedCovarMatrix[i][j], expectedCovarMatrix[i][j],
covarMatrix[i][j], covarMatrix.getEntry(i, j),
FastMath.abs(0.1 * expectedCovarMatrix[i][j])); FastMath.abs(0.1 * expectedCovarMatrix[i][j]));
} }
} }
@ -258,10 +260,10 @@ public class LevenbergMarquardtOptimizerTest
final Optimum optimum = optimizer.optimize( final Optimum optimum = optimizer.optimize(
builder(circle).maxIterations(50).start(init).build()); builder(circle).maxIterations(50).start(init).build());
final double[] paramFound = optimum.getPoint(); final double[] paramFound = optimum.getPoint().toArray();
// Retrieve errors estimation. // Retrieve errors estimation.
final double[] asymptoticStandardErrorFound = optimum.computeSigma(1e-14); final double[] asymptoticStandardErrorFound = optimum.computeSigma(1e-14).toArray();
// Check that the parameters are found within the assumed error bars. // Check that the parameters are found within the assumed error bars.
Assert.assertEquals(xCenter, paramFound[0], asymptoticStandardErrorFound[0]); Assert.assertEquals(xCenter, paramFound[0], asymptoticStandardErrorFound[0]);

View File

@ -518,7 +518,7 @@ public class MinpackTest {
final Optimum optimum = optimizer.optimize(problem); final Optimum optimum = optimizer.optimize(problem);
Assert.assertFalse(exceptionExpected); Assert.assertFalse(exceptionExpected);
function.checkTheoreticalMinCost(optimum.computeRMS()); function.checkTheoreticalMinCost(optimum.computeRMS());
function.checkTheoreticalMinParams(optimum.getPoint()); function.checkTheoreticalMinParams(optimum.getPoint().toArray());
} catch (TooManyEvaluationsException e) { } catch (TooManyEvaluationsException e) {
Assert.assertTrue(exceptionExpected); Assert.assertTrue(exceptionExpected);
} }