Use Real{Vector,Matrix} in LeastSquares interfaces

Covered all of the interfaces in the leastsquares package to use RealVector and
RealMatrix instead of double[] and double[][]. This reduced some duplicated code.
For example Evaluation.computeResiduals() was a complete duplication of
RealVector.subtract(). It also presents a consistent interface and allows data
encapsulation.

Lastly, this change enables [math] to "eat our own dog food." It enables the
linear package to be used in the implementation of the optimization algorithms.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569354 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2014-02-18 14:32:54 +00:00
parent a7a380f934
commit 0079828734
15 changed files with 202 additions and 171 deletions

View File

@ -5,6 +5,7 @@ import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.FastMath;
/**
@ -31,7 +32,7 @@ abstract class AbstractEvaluation implements Evaluation {
}
/** {@inheritDoc} */
public double[][] computeCovariances(double threshold) {
public RealMatrix computeCovariances(double threshold) {
// Set up the Jacobian.
final RealMatrix j = this.computeJacobian();
@ -41,16 +42,16 @@ abstract class AbstractEvaluation implements Evaluation {
// Compute the covariances matrix.
final DecompositionSolver solver
= new QRDecomposition(jTj, threshold).getSolver();
return solver.getInverse().getData();
return solver.getInverse();
}
/** {@inheritDoc} */
public double[] computeSigma(double covarianceSingularityThreshold) {
final double[][] cov = this.computeCovariances(covarianceSingularityThreshold);
final int nC = cov.length;
final double[] sig = new double[nC];
public RealVector computeSigma(double covarianceSingularityThreshold) {
final RealMatrix cov = this.computeCovariances(covarianceSingularityThreshold);
final int nC = cov.getColumnDimension();
final RealVector sig = new ArrayRealVector(nC);
for (int i = 0; i < nC; ++i) {
sig[i] = FastMath.sqrt(cov[i][i]);
sig.setEntry(i, FastMath.sqrt(cov.getEntry(i,i)));
}
return sig;
}

View File

@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/**
* Applies a dense weight matrix to an evaluation.
@ -37,19 +38,19 @@ class DenseWeightedEvaluation extends AbstractEvaluation {
}
/** {@inheritDoc} */
public double[] computeResiduals() {
public RealVector computeResiduals() {
return this.weightSqrt.operate(this.unweighted.computeResiduals());
}
/* delegate */
/** {@inheritDoc} */
public double[] getPoint() {
public RealVector getPoint() {
return unweighted.getPoint();
}
/** {@inheritDoc} */
public double[] computeValue() {
public RealVector computeValue() {
return unweighted.computeValue();
}
}

View File

@ -26,6 +26,7 @@ import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.util.Incrementor;
@ -133,7 +134,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
final int nR = lsp.getObservationSize(); // Number of observed data.
final int nC = lsp.getParameterSize();
final double[] currentPoint = lsp.getStart();
final RealVector currentPoint = lsp.getStart();
// iterate until convergence is reached
Evaluation current = null;
@ -145,7 +146,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
// Value of the objective function at "currentPoint".
evaluationCounter.incrementCount();
current = lsp.evaluate(currentPoint);
final double[] currentResiduals = current.computeResiduals();
final RealVector currentResiduals = current.computeResiduals();
final RealMatrix weightedJacobian = current.computeJacobian();
// Check convergence.
@ -164,7 +165,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
for (int i = 0; i < nR; ++i) {
final double[] grad = weightedJacobian.getRow(i);
final double residual = currentResiduals[i];
final double residual = currentResiduals.getEntry(i);
// compute the normal equation
//residual is already weighted
@ -186,10 +187,10 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
// solve the linearized least squares problem
RealMatrix mA = new BlockRealMatrix(a);
DecompositionSolver solver = this.decomposition.getSolver(mA);
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
final RealVector dX = solver.solve(new ArrayRealVector(b, false));
// update the estimated parameters
for (int i = 0; i < nC; ++i) {
currentPoint[i] += dX[i];
currentPoint.setEntry(i, currentPoint.getEntry(i) + dX.getEntry(i));
}
} catch (SingularMatrixException e) {
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);

View File

@ -1,5 +1,6 @@
package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.util.Incrementor;
@ -23,7 +24,7 @@ public class LeastSquaresAdapter implements LeastSquaresProblem {
}
/** {@inheritDoc} */
public double[] getStart() {
public RealVector getStart() {
return problem.getStart();
}
@ -37,8 +38,9 @@ public class LeastSquaresAdapter implements LeastSquaresProblem {
return problem.getParameterSize();
}
/** {@inheritDoc} */
public Evaluation evaluate(final double[] point) {
/** {@inheritDoc}
* @param point*/
public Evaluation evaluate(final RealVector point) {
return problem.evaluate(point);
}

View File

@ -26,6 +26,34 @@ public class LeastSquaresFactory {
private LeastSquaresFactory() {
}
/**
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
* from the given elements. There will be no weights applied (Identity weights).
*
* @param model the model function. Produces the computed values.
* @param observed the observed (target) values
* @param start the initial guess.
* @param checker convergence checker
* @param maxEvaluations the maximum number of times to evaluate the model
* @param maxIterations the maximum number to times to iterate in the algorithm
* @return the specified General Least Squares problem.
*/
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
final RealVector observed,
final RealVector start,
final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations,
final int maxIterations) {
return new LeastSquaresProblemImpl(
model,
observed,
start,
checker,
maxEvaluations,
maxIterations
);
}
/**
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
* from the given elements. There will be no weights applied (Identity weights).
@ -44,10 +72,10 @@ public class LeastSquaresFactory {
final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations,
final int maxIterations) {
return new LeastSquaresProblemImpl(
return create(
model,
observed,
start,
new ArrayRealVector(observed, false),
new ArrayRealVector(start, false),
checker,
maxEvaluations,
maxIterations
@ -132,7 +160,7 @@ public class LeastSquaresFactory {
final RealMatrix weightSquareRoot = squareRoot(weights);
return new LeastSquaresAdapter(problem) {
@Override
public Evaluation evaluate(final double[] point) {
public Evaluation evaluate(final RealVector point) {
return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot);
}
};
@ -154,7 +182,7 @@ public class LeastSquaresFactory {
/**
* Count the evaluations of a particular problem. The {@code counter} will be
* incremented every time {@link LeastSquaresProblem#evaluate(double[])} is called on
* incremented every time {@link LeastSquaresProblem#evaluate(RealVector)} is called on
* the <em>returned</em> problem.
*
* @param problem the problem to track.
@ -165,7 +193,7 @@ public class LeastSquaresFactory {
final Incrementor counter) {
return new LeastSquaresAdapter(problem) {
public Evaluation evaluate(final double[] point) {
public Evaluation evaluate(final RealVector point) {
counter.incrementCount();
return super.evaluate(point);
}
@ -192,12 +220,12 @@ public class LeastSquaresFactory {
return checker.converged(
iteration,
new PointVectorValuePair(
previous.getPoint(),
previous.computeValue(),
previous.getPoint().toArray(),
previous.computeValue().toArray(),
false),
new PointVectorValuePair(
current.getPoint(),
current.computeValue(),
current.getPoint().toArray(),
current.computeValue().toArray(),
false)
);
}
@ -237,11 +265,13 @@ public class LeastSquaresFactory {
final MultivariateMatrixFunction jacobian
) {
return new MultivariateJacobianFunction() {
public Pair<RealVector, RealMatrix> value(final double[] point) {
//evaluate and use Real* interfaces without copying
public Pair<RealVector, RealMatrix> value(final RealVector point) {
//TODO get array from RealVector without copying?
final double[] pointArray = point.toArray();
//evaluate and return data without copying
return new Pair<RealVector, RealMatrix>(
new ArrayRealVector(value.value(point), false),
new Array2DRowRealMatrix(jacobian.value(point), false));
new ArrayRealVector(value.value(pointArray), false),
new Array2DRowRealMatrix(jacobian.value(pointArray), false));
}
};
}

View File

@ -3,6 +3,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/**
* The data necessary to define a non-linear least squares problem. Includes the observed
@ -19,7 +20,7 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
*
* @return the initial guess values.
*/
double[] getStart();
RealVector getStart();
/**
* Get the number of observations (rows in the Jacobian) in this problem.
@ -38,13 +39,14 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
/**
* Evaluate the model at the specified point.
*
*
* @param point the parameter values.
* @return the model's value and derivative at the given point.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the maximal number of evaluations (of the model vector function) is
* exceeded.
*/
Evaluation evaluate(double[] point);
Evaluation evaluate(RealVector point);
/**
* An evaluation of a {@link LeastSquaresProblem} at a particular point. This class
@ -59,12 +61,13 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* way for the caller to specify that the result of this computation should be
* considered meaningless, and thus trigger an exception.
*
*
* @param threshold Singularity threshold.
* @return the covariance matrix.
* @throws org.apache.commons.math3.linear.SingularMatrixException
* if the covariance matrix cannot be computed (singular problem).
*/
double[][] computeCovariances(double threshold);
RealMatrix computeCovariances(double threshold);
/**
* Computes an estimate of the standard deviation of the parameters. The returned
@ -72,13 +75,14 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* matrix, {@code sd(a[i]) ~= sqrt(C[i][i])}, where {@code a[i]} is the optimized
* value of the {@code i}-th parameter, and {@code C} is the covariance matrix.
*
*
* @param covarianceSingularityThreshold Singularity threshold (see {@link
* #computeCovariances(double) computeCovariances}).
* @return an estimate of the standard deviation of the optimized parameters
* @throws org.apache.commons.math3.linear.SingularMatrixException
* if the covariance matrix cannot be computed.
*/
double[] computeSigma(double covarianceSingularityThreshold);
RealVector computeSigma(double covarianceSingularityThreshold);
/**
* Computes the normalized cost. It is the square-root of the sum of squared of
@ -93,7 +97,7 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
*
* @return the objective function value at the specified point.
*/
double[] computeValue();
RealVector computeValue();
/**
* Computes the weighted Jacobian matrix.
@ -121,13 +125,13 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
* @return the weighted residuals: W<sup>1/2</sup> K.
* @throws DimensionMismatchException if the residuals have the wrong length.
*/
double[] computeResiduals();
RealVector computeResiduals();
/**
* Get the abscissa (independent variables) of this evaluation.
*
* @return the point provided to {@link #evaluate(double[])}.
* @return the point provided to {@link #evaluate(RealVector)}.
*/
double[] getPoint();
RealVector getPoint();
}
}

View File

@ -16,7 +16,6 @@
*/
package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
@ -36,15 +35,15 @@ class LeastSquaresProblemImpl
implements LeastSquaresProblem {
/** Target values for the model function at optimum. */
private double[] target;
private RealVector target;
/** Model function. */
private MultivariateJacobianFunction model;
/** Initial guess. */
private double[] start;
private RealVector start;
LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
final double[] target,
final double[] start,
final RealVector target,
final RealVector start,
final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations,
final int maxIterations) {
@ -55,18 +54,18 @@ class LeastSquaresProblemImpl
}
public int getObservationSize() {
return target.length;
return target.getDimension();
}
public int getParameterSize() {
return start.length;
return start.getDimension();
}
public double[] getStart() {
return start == null ? null : start.clone();
public RealVector getStart() {
return start == null ? null : start.copy();
}
public Evaluation evaluate(final double[] point) {
public Evaluation evaluate(final RealVector point) {
//evaluate value and jacobian in one function call
final Pair<RealVector, RealMatrix> value = this.model.value(point);
return new UnweightedEvaluation(
@ -84,19 +83,19 @@ class LeastSquaresProblemImpl
private static class UnweightedEvaluation extends AbstractEvaluation {
/** the point of evaluation */
private final double[] point;
private final RealVector point;
/** value at point */
private final RealVector values;
/** deriviative at point */
private final RealMatrix jacobian;
/** reference to the observed values */
private final double[] target;
private final RealVector target;
private UnweightedEvaluation(final RealVector values,
final RealMatrix jacobian,
final double[] target,
final double[] point) {
super(target.length);
final RealVector target,
final RealVector point) {
super(target.getDimension());
this.values = values;
this.jacobian = jacobian;
this.target = target;
@ -104,31 +103,20 @@ class LeastSquaresProblemImpl
}
public double[] computeValue() {
return this.values.toArray();
public RealVector computeValue() {
return this.values;
}
public RealMatrix computeJacobian() {
return this.jacobian;
}
public double[] getPoint() {
public RealVector getPoint() {
return this.point;
}
public double[] computeResiduals() {
final double[] objectiveValue = this.computeValue();
if (objectiveValue.length != target.length) {
throw new DimensionMismatchException(target.length,
objectiveValue.length);
}
final double[] residuals = new double[target.length];
for (int i = 0; i < target.length; i++) {
residuals[i] = target[i] - objectiveValue[i];
}
return residuals;
public RealVector computeResiduals() {
return target.subtract(this.computeValue());
}
}

View File

@ -19,6 +19,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import java.util.Arrays;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
@ -297,7 +298,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
//pull in relevant data from the problem as locals
final int nR = problem.getObservationSize(); // Number of observed data.
final int nC = problem.getParameterSize(); // Number of parameters.
final double[] currentPoint = problem.getStart();
final double[] currentPoint = problem.getStart().toArray();
//counters
final Incrementor iterationCounter = problem.getIterationCounter();
final Incrementor evaluationCounter = problem.getEvaluationCounter();
@ -327,8 +328,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
// Evaluate the function at the starting point and calculate its norm.
evaluationCounter.incrementCount();
//value will be reassigned in the loop
Evaluation current = problem.evaluate(currentPoint);
double[] currentResiduals = current.computeResiduals();
Evaluation current = problem.evaluate(new ArrayRealVector(currentPoint, false));
double[] currentResiduals = current.computeResiduals().toArray();
double currentCost = current.computeCost();
// Outer loop.
@ -444,8 +445,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
// Evaluate the function at x + p and calculate its norm.
evaluationCounter.incrementCount();
current = problem.evaluate(currentPoint);
currentResiduals = current.computeResiduals();
current = problem.evaluate(new ArrayRealVector(currentPoint,false));
currentResiduals = current.computeResiduals().toArray();
currentCost = current.computeCost();
// compute the scaled actual reduction

View File

@ -18,6 +18,6 @@ public interface MultivariateJacobianFunction {
* @param point the abscissae
* @return the values and their Jacobian of this vector valued function.
*/
Pair<RealVector, RealMatrix> value(double[] point);
Pair<RealVector, RealMatrix> value(RealVector point);
}

View File

@ -3,6 +3,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/**
* A pedantic implementation of {@link Optimum}.
@ -44,12 +45,12 @@ class OptimumImpl implements Optimum {
}
/** {@inheritDoc} */
public double[][] computeCovariances(double threshold) {
public RealMatrix computeCovariances(double threshold) {
return value.computeCovariances(threshold);
}
/** {@inheritDoc} */
public double[] computeSigma(double covarianceSingularityThreshold) {
public RealVector computeSigma(double covarianceSingularityThreshold) {
return value.computeSigma(covarianceSingularityThreshold);
}
@ -59,7 +60,7 @@ class OptimumImpl implements Optimum {
}
/** {@inheritDoc} */
public double[] computeValue() {
public RealVector computeValue() {
return value.computeValue();
}
@ -74,12 +75,12 @@ class OptimumImpl implements Optimum {
}
/** {@inheritDoc} */
public double[] computeResiduals() {
public RealVector computeResiduals() {
return value.computeResiduals();
}
/** {@inheritDoc} */
public double[] getPoint() {
public RealVector getPoint() {
return value.getPoint();
}
}

View File

@ -25,6 +25,7 @@ import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
@ -44,6 +45,9 @@ import java.util.Arrays;
*/
public abstract class AbstractLeastSquaresOptimizerAbstractTest {
/** default absolute tolerance of comparisons */
public static final double TOl = 1e-10;
public LeastSquaresBuilder base() {
return new LeastSquaresBuilder()
.checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
@ -78,6 +82,19 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Assert.fail("Expected Exception from: " + optimizer.toString());
}
/**
* Check the value of a vector.
* @param tolerance the absolute tolerance of comparisons
* @param actual the vector to test
* @param expected the expected values
*/
public void assertEquals(double tolerance, RealVector actual, double... expected){
for (int i = 0; i < expected.length; i++) {
Assert.assertEquals(expected[i], actual.getEntry(i), tolerance);
}
Assert.assertEquals(expected.length, actual.getDimension());
}
/**
* @return the default number of allowed iterations (which will be used when not
* specified otherwise).
@ -150,9 +167,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(ls);
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(1.5, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(3.0, optimum.computeValue()[0], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), 1.5);
Assert.assertEquals(3.0, optimum.computeValue().getEntry(0), TOl);
}
public void testQRColumnsPermutation(LeastSquaresOptimizer optimizer) {
@ -162,12 +179,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(7, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(3, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(4, optimum.computeValue()[0], 1e-10);
Assert.assertEquals(6, optimum.computeValue()[1], 1e-10);
Assert.assertEquals(1, optimum.computeValue()[2], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), 7, 3);
assertEquals(TOl, optimum.computeValue(), 4, 6, 1);
}
public void testNoDependency(LeastSquaresOptimizer optimizer) {
@ -182,9 +196,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
for (int i = 0; i < problem.target.length; ++i) {
Assert.assertEquals(0.55 * i, optimum.getPoint()[i], 1e-10);
Assert.assertEquals(0.55 * i, optimum.getPoint().getEntry(i), TOl);
}
}
@ -197,10 +211,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(1, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(2, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(3, optimum.getPoint()[2], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), 1, 2, 3);
}
public void testTwoSets(LeastSquaresOptimizer optimizer) {
@ -216,13 +228,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(3, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(4, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(-1, optimum.getPoint()[2], 1e-10);
Assert.assertEquals(-2, optimum.getPoint()[3], 1e-10);
Assert.assertEquals(1 + epsilon, optimum.getPoint()[4], 1e-10);
Assert.assertEquals(1 - epsilon, optimum.getPoint()[5], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), 3, 4, -1, -2, 1 + epsilon, 1 - epsilon);
}
public void testNonInvertible(LeastSquaresOptimizer optimizer) throws Exception {
@ -253,11 +260,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer
.optimize(problem1.getBuilder().start(start).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(1, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[2], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[3], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), 1, 1, 1, 1);
LinearProblem problem2 = new LinearProblem(new double[][]{
{10.00, 7.00, 8.10, 7.20},
@ -268,11 +272,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
optimum = optimizer.optimize(problem2.getBuilder().start(start).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(-81, optimum.getPoint()[0], 1e-8);
Assert.assertEquals(137, optimum.getPoint()[1], 1e-8);
Assert.assertEquals(-34, optimum.getPoint()[2], 1e-8);
Assert.assertEquals(22, optimum.getPoint()[3], 1e-8);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(1e-8, optimum.getPoint(), -81, 137, -34, 22);
}
public void testMoreEstimatedParametersSimple(LeastSquaresOptimizer optimizer) {
@ -285,7 +286,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer
.optimize(problem.getBuilder().start(new double[]{7, 6, 5, 4}).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
}
public void testMoreEstimatedParametersUnsorted(LeastSquaresOptimizer optimizer) {
@ -300,11 +301,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(
problem.getBuilder().start(new double[]{2, 2, 2, 2, 2, 2}).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(3, optimum.getPoint()[2], 1e-10);
Assert.assertEquals(4, optimum.getPoint()[3], 1e-10);
Assert.assertEquals(5, optimum.getPoint()[4], 1e-10);
Assert.assertEquals(6, optimum.getPoint()[5], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
//TODO the first two elements of point were not previously checked
assertEquals(TOl, optimum.getPoint(), 2, 1, 3, 4, 5, 6);
}
public void testRedundantEquations(LeastSquaresOptimizer optimizer) {
@ -317,9 +316,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer
.optimize(problem.getBuilder().start(new double[]{1, 1}).build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(2, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), 2, 1);
}
public void testInconsistentEquations(LeastSquaresOptimizer optimizer) {
@ -346,9 +344,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
//TODO why is this part here? hasn't it been tested already?
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(-1, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), -1, 1);
//TODO move to builder test
optimizer.optimize(
@ -368,9 +365,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
Assert.assertEquals(-1, optimum.getPoint()[0], 1e-10);
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
Assert.assertEquals(0, optimum.computeRMS(), TOl);
assertEquals(TOl, optimum.getPoint(), -1, 1);
//TODO move to builder test
optimizer.optimize(
@ -400,14 +396,14 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Assert.assertTrue(optimum.getEvaluations() < 10);
double rms = optimum.computeRMS();
Assert.assertEquals(1.768262623567235, FastMath.sqrt(circle.getN()) * rms, 1e-10);
Assert.assertEquals(1.768262623567235, FastMath.sqrt(circle.getN()) * rms, TOl);
Vector2D center = new Vector2D(optimum.getPoint()[0], optimum.getPoint()[1]);
Vector2D center = new Vector2D(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
Assert.assertEquals(69.96016176931406, circle.getRadius(center), 1e-6);
Assert.assertEquals(96.07590211815305, center.getX(), 1e-6);
Assert.assertEquals(48.13516790438953, center.getY(), 1e-6);
double[][] cov = optimum.computeCovariances(1e-14);
double[][] cov = optimum.computeCovariances(1e-14).getData();
Assert.assertEquals(1.839, cov[0][0], 0.001);
Assert.assertEquals(0.731, cov[0][1], 0.001);
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
@ -425,7 +421,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
optimum = optimizer.optimize(
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
cov = optimum.computeCovariances(1e-14);
cov = optimum.computeCovariances(1e-14).getData();
Assert.assertEquals(0.0016, cov[0][0], 0.001);
Assert.assertEquals(3.2e-7, cov[0][1], 1e-9);
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
@ -444,7 +440,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
Vector2D center = new Vector2D(optimum.getPoint()[0], optimum.getPoint()[1]);
Vector2D center = new Vector2D(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
Assert.assertTrue(optimum.getEvaluations() < 25);
Assert.assertEquals(0.043, optimum.computeRMS(), 1e-3);
Assert.assertEquals(0.292235, circle.getRadius(center), 1e-6);
@ -465,8 +461,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
Optimum optimum = optimizer.optimize(
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
Assert.assertEquals(-0.1517383071957963, optimum.getPoint()[0], 1e-6);
Assert.assertEquals(0.2074999736353867, optimum.getPoint()[1], 1e-6);
assertEquals(1e-6, optimum.getPoint(), -0.1517383071957963, 0.2074999736353867);
Assert.assertEquals(0.04268731682389561, optimum.computeRMS(), 1e-8);
}
@ -509,12 +504,12 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
final Optimum optimum = optimizer.optimize(builder(dataset).build());
final double[] actual = optimum.getPoint();
for (int i = 0; i < actual.length; i++) {
final RealVector actual = optimum.getPoint();
for (int i = 0; i < actual.getDimension(); i++) {
double expected = dataset.getParameter(i);
double delta = FastMath.abs(errParams * expected);
Assert.assertEquals(dataset.getName() + ", param #" + i,
expected, actual[i], delta);
expected, actual.getEntry(i), delta);
}
}

View File

@ -15,6 +15,7 @@ package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
@ -87,10 +88,10 @@ public class EvaluationTest {
final Evaluation evaluation = lsp.evaluate(lsp.getStart());
final double cost = evaluation.computeCost();
final double[] sig = evaluation.computeSigma(1e-14);
final RealVector sig = evaluation.computeSigma(1e-14);
final int dof = lsp.getObservationSize() - lsp.getParameterSize();
for (int i = 0; i < sig.length; i++) {
final double actual = FastMath.sqrt(cost * cost / dof) * sig[i];
for (int i = 0; i < sig.getDimension(); i++) {
final double actual = FastMath.sqrt(cost * cost / dof) * sig.getEntry(i);
Assert.assertEquals(dataset.getName() + ", parameter #" + i,
expected[i], actual, 1e-6 * expected[i]);
}

View File

@ -13,17 +13,18 @@
*/
package org.apache.commons.math3.fitting.leastsquares;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.awt.geom.Point2D;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.stat.descriptive.StatisticalSummary;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.descriptive.StatisticalSummary;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.util.FastMath;
import org.junit.Test;
import org.junit.Assert;
import org.junit.Test;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.List;
/**
* This class demonstrates the main functionality of the
@ -95,7 +96,7 @@ public class EvaluationTestValidation {
sigmaEstimate[i] = new SummaryStatistics();
}
final double[] init = { slope, offset };
final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
// Monte-Carlo (generates many sets of observations).
final int mcRepeat = MONTE_CARLO_RUNS;
@ -117,12 +118,12 @@ public class EvaluationTestValidation {
// covariance matrix).
final LeastSquaresProblem lsp = builder(problem).build();
final double[] sigma = lsp.evaluate(init).computeSigma(1e-14);
final RealVector sigma = lsp.evaluate(init).computeSigma(1e-14);
// Accumulate statistics.
for (int i = 0; i < numParams; i++) {
paramsFoundByDirectSolution[i].addValue(regress[i]);
sigmaEstimate[i].addValue(sigma[i]);
sigmaEstimate[i].addValue(sigma.getEntry(i));
}
// Next Monte-Carlo.
@ -138,7 +139,7 @@ public class EvaluationTestValidation {
StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary();
System.out.printf(" %+.6e %+.6e %+.6e\n",
init[i],
init.getEntry(i),
s.getMean(),
s.getStandardDeviation());
@ -212,7 +213,7 @@ public class EvaluationTestValidation {
}
// Direct solution (using simple regression).
final double[] regress = problem.solve();
final RealVector regress = new ArrayRealVector(problem.solve(), false);
// Dummy optimizer (to compute the chi-square).
final LeastSquaresProblem lsp = builder(problem).build();
@ -221,7 +222,7 @@ public class EvaluationTestValidation {
// Get chi-square of the best parameters set for the given set of
// observations.
final double bestChi2N = getChi2N(lsp, regress);
final double[] sigma = lsp.evaluate(regress).computeSigma(1e-14);
final RealVector sigma = lsp.evaluate(regress).computeSigma(1e-14);
// Monte-Carlo (generates a grid of parameters).
final int mcRepeat = MONTE_CARLO_RUNS;
@ -233,8 +234,8 @@ public class EvaluationTestValidation {
// Index 2 = normalized chi2
final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize);
final double slopeRange = 10 * sigma[0];
final double offsetRange = 10 * sigma[1];
final double slopeRange = 10 * sigma.getEntry(0);
final double offsetRange = 10 * sigma.getEntry(1);
final double minSlope = slope - 0.5 * slopeRange;
final double minOffset = offset - 0.5 * offsetRange;
final double deltaSlope = slopeRange/ gridSize;
@ -243,7 +244,8 @@ public class EvaluationTestValidation {
final double s = minSlope + i * deltaSlope;
for (int j = 0; j < gridSize; j++) {
final double o = minOffset + j * deltaOffset;
final double chi2N = getChi2N(lsp, new double[] {s, o});
final double chi2N = getChi2N(lsp,
new ArrayRealVector(new double[] {s, o}, false));
paramsAndChi2.add(new double[] {s, o, chi2N});
}
@ -260,7 +262,7 @@ public class EvaluationTestValidation {
final String lineFmt = "%+.10e %+.10e %.8e\n";
// Point with smallest chi-square.
System.out.printf(lineFmt, regress[0], regress[1], bestChi2N);
System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
System.out.println(); // Empty line.
// Points within the confidence interval.
@ -280,7 +282,7 @@ public class EvaluationTestValidation {
}
System.out.println(); // Empty line.
System.out.println("# sigma=" + Arrays.toString(sigma));
System.out.println("# sigma=" + sigma.toString());
System.out.println("# " + numLarger + " sets filtered out");
}
@ -289,15 +291,17 @@ public class EvaluationTestValidation {
.model(problem.getModelFunction())
.jacobian(problem.getModelFunctionJacobian())
.target(problem.target())
.weight(new DiagonalMatrix(problem.weight()));
.weight(new DiagonalMatrix(problem.weight()))
//unused start point to avoid NPE
.start(new double[2]);
}
/**
* @return the normalized chi-square.
*/
private double getChi2N(LeastSquaresProblem lsp,
double[] params) {
RealVector params) {
final double cost = lsp.evaluate(params).computeCost();
return cost * cost / (lsp.getObservationSize() - params.length);
return cost * cost / (lsp.getObservationSize() - params.getDimension());
}
}

View File

@ -24,6 +24,8 @@ import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
@ -201,10 +203,10 @@ public class LevenbergMarquardtOptimizerTest
.build()
);
final double[] solution = optimum.getPoint();
final RealVector solution = optimum.getPoint();
final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };
final double[][] covarMatrix = optimum.computeCovariances(1e-14);
final RealMatrix covarMatrix = optimum.computeCovariances(1e-14);
final double[][] expectedCovarMatrix = {
{ 3.38, -3.69, 27.98, -2.34, -49.24 },
{ -3.69, 2492.26, 81.89, -69.21, -8.9 },
@ -218,7 +220,7 @@ public class LevenbergMarquardtOptimizerTest
// Check that the computed solution is within the reference error range.
for (int i = 0; i < numParams; i++) {
final double error = FastMath.sqrt(expectedCovarMatrix[i][i]);
Assert.assertEquals("Parameter " + i, expectedSolution[i], solution[i], error);
Assert.assertEquals("Parameter " + i, expectedSolution[i], solution.getEntry(i), error);
}
// Check that each entry of the computed covariance matrix is within 10%
@ -227,7 +229,7 @@ public class LevenbergMarquardtOptimizerTest
for (int j = 0; j < numParams; j++) {
Assert.assertEquals("Covariance matrix [" + i + "][" + j + "]",
expectedCovarMatrix[i][j],
covarMatrix[i][j],
covarMatrix.getEntry(i, j),
FastMath.abs(0.1 * expectedCovarMatrix[i][j]));
}
}
@ -258,10 +260,10 @@ public class LevenbergMarquardtOptimizerTest
final Optimum optimum = optimizer.optimize(
builder(circle).maxIterations(50).start(init).build());
final double[] paramFound = optimum.getPoint();
final double[] paramFound = optimum.getPoint().toArray();
// Retrieve errors estimation.
final double[] asymptoticStandardErrorFound = optimum.computeSigma(1e-14);
final double[] asymptoticStandardErrorFound = optimum.computeSigma(1e-14).toArray();
// Check that the parameters are found within the assumed error bars.
Assert.assertEquals(xCenter, paramFound[0], asymptoticStandardErrorFound[0]);

View File

@ -518,7 +518,7 @@ public class MinpackTest {
final Optimum optimum = optimizer.optimize(problem);
Assert.assertFalse(exceptionExpected);
function.checkTheoreticalMinCost(optimum.computeRMS());
function.checkTheoreticalMinParams(optimum.getPoint());
function.checkTheoreticalMinParams(optimum.getPoint().toArray());
} catch (TooManyEvaluationsException e) {
Assert.assertTrue(exceptionExpected);
}