Added general methods to guess errors on estimated parameters

JIRA: MATH-176

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@613474 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2008-01-19 22:48:16 +00:00
parent a1dcceb840
commit 376834acf1
8 changed files with 495 additions and 354 deletions

View File

@ -87,9 +87,17 @@ public class MessagesResources_fr
{ "Conversion Exception in Transformation: {0}", { "Conversion Exception in Transformation: {0}",
"Exception de conversion dans une transformation : {0}" }, "Exception de conversion dans une transformation : {0}" },
// org.apache.commons.math.estimation.AbstractEstimator
{ "maximal number of evaluations exceeded ({0})",
"nombre maximal d''\u00e9valuations d\u00e9pass\u00e9 ({0})" },
{ "unable to compute covariances: singular problem",
"impossible de calculer les covariances : probl\u00e8me singulier"},
{ "no degrees of freedom ({0} measurements, {1} parameters)",
"aucun degr\u00e9 de libert\u00e9 ({0} mesures, {1} param\u00e8tres)" },
// org.apache.commons.math.estimation.GaussNewtonEstimator // org.apache.commons.math.estimation.GaussNewtonEstimator
{ "unable to converge in {0} iterations", { "unable to solve: singular problem",
"pas de convergence apr\u00e8s {0} it\u00e9rations" }, "r\u00e9solution impossible : probl\u00e8me singulier" },
// org.apache.commons.math.estimation.LevenbergMarquardtEstimator // org.apache.commons.math.estimation.LevenbergMarquardtEstimator
{ "cost relative tolerance is too small ({0}), no further reduction in the sum of squares is possible", { "cost relative tolerance is too small ({0}), no further reduction in the sum of squares is possible",
@ -98,8 +106,6 @@ public class MessagesResources_fr
"trop petite tol\u00e9rance relative sur les param\u00e8tres ({0}), aucune am\u00e9lioration de la solution approximative n''est possible" }, "trop petite tol\u00e9rance relative sur les param\u00e8tres ({0}), aucune am\u00e9lioration de la solution approximative n''est possible" },
{ "orthogonality tolerance is too small ({0}), solution is orthogonal to the jacobian", { "orthogonality tolerance is too small ({0}), solution is orthogonal to the jacobian",
"trop petite tol\u00e9rance sur l''orthogonalit\u00e9 ({0}), la solution est orthogonale \u00e0 la jacobienne" }, "trop petite tol\u00e9rance sur l''orthogonalit\u00e9 ({0}), la solution est orthogonale \u00e0 la jacobienne" },
{ "maximal number of evaluations exceeded ({0})",
"nombre maximal d''\u00e9valuations d\u00e9pass\u00e9 ({0})" },
// org.apache.commons.math.geometry.CardanEulerSingularityException // org.apache.commons.math.geometry.CardanEulerSingularityException
{ "Cardan angles singularity", { "Cardan angles singularity",

View File

@ -0,0 +1,274 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.estimation;
import java.util.Arrays;
import org.apache.commons.math.linear.InvalidMatrixException;
import org.apache.commons.math.linear.RealMatrixImpl;
/**
 * Base class for implementing least squares estimators.
 *
 * <p>It holds the state shared by concrete estimators (measurements,
 * unbound parameters, jacobian, residuals, cost and evaluation counters)
 * and provides the problem-independent services of the {@link Estimator}
 * interface: RMS, chi-square, covariances and parameters errors guessing.</p>
 *
 * <p>The jacobian, the residuals and the covariances only concern the
 * <em>unbound</em> parameters of the problem, since these are the only
 * parameters retrieved by {@link #initializeEstimate(EstimationProblem)}.</p>
 */
public abstract class AbstractEstimator implements Estimator {

    /**
     * Build an abstract estimator for least squares problems.
     */
    protected AbstractEstimator() {
    }

    /**
     * Set the maximal number of cost evaluations allowed.
     *
     * @param maxCostEval maximal number of cost evaluations allowed
     * @see #estimate
     */
    public void setMaxCostEval(int maxCostEval) {
        this.maxCostEval = maxCostEval;
    }

    /**
     * Get the number of cost evaluations.
     *
     * @return number of cost evaluations
     */
    public int getCostEvaluations() {
        return costEvaluations;
    }

    /**
     * Get the number of jacobian evaluations.
     *
     * @return number of jacobian evaluations
     */
    public int getJacobianEvaluations() {
        return jacobianEvaluations;
    }

    /**
     * Update the jacobian matrix.
     *
     * <p>The matrix is stored row by row in {@link #jacobian}: the entry for
     * measurement <code>i</code> and unbound parameter <code>j</code> is at
     * index <code>i * cols + j</code> and holds the partial derivative
     * multiplied by <code>-sqrt(weight)</code>.</p>
     */
    protected void updateJacobian() {
        ++jacobianEvaluations;
        Arrays.fill(jacobian, 0);
        for (int i = 0, index = 0; i < rows; i++) {
            WeightedMeasurement wm = measurements[i];
            double factor = -Math.sqrt(wm.getWeight());
            for (int j = 0; j < cols; ++j) {
                jacobian[index++] = factor * wm.getPartial(parameters[j]);
            }
        }
    }

    /**
     * Update the residuals array and cost function value.
     *
     * <p>Each entry of {@link #residuals} is set to
     * <code>sqrt(weight) * residual</code> and {@link #cost} is set to the
     * square root of the sum of <code>weight * residual<sup>2</sup></code>.</p>
     *
     * @exception EstimationException if the number of cost evaluations
     * exceeds the maximum allowed
     */
    protected void updateResidualsAndCost()
        throws EstimationException {

        // the counter is incremented even when the limit is exceeded
        if (++costEvaluations > maxCostEval) {
            throw new EstimationException("maximal number of evaluations exceeded ({0})",
                                          new String[] {
                                              Integer.toString(maxCostEval)
                                          });
        }

        cost = 0;
        for (int i = 0; i < rows; i++) {
            final WeightedMeasurement wm = measurements[i];
            final double weight   = wm.getWeight();
            final double residual = wm.getResidual();
            residuals[i] = Math.sqrt(weight) * residual;
            cost += weight * residual * residual;
        }
        cost = Math.sqrt(cost);

    }

    /**
     * Get the Root Mean Square value, i.e. the root of the arithmetic
     * mean of the square of all weighted residuals. This is related to the
     * criterion that is minimized by the estimator as follows: if
     * <em>c</em> is the criterion, and <em>n</em> is the number of
     * measurements, then the RMS is <em>sqrt (c/n)</em>.
     *
     * @param problem estimation problem
     * @return RMS value
     */
    public double getRMS(EstimationProblem problem) {
        WeightedMeasurement[] wm = problem.getMeasurements();
        double criterion = 0;
        for (int i = 0; i < wm.length; ++i) {
            double residual = wm[i].getResidual();
            criterion += wm[i].getWeight() * residual * residual;
        }
        return Math.sqrt(criterion / wm.length);
    }

    /**
     * Get the Chi-Square value.
     *
     * <p>The value is the sum over all measurements of
     * <code>residual<sup>2</sup> / weight</code>.</p>
     *
     * @param problem estimation problem
     * @return chi-square value
     */
    public double getChiSquare(EstimationProblem problem) {
        // NOTE(review): the squared residual is divided by the weight here,
        // whereas getRMS and updateResidualsAndCost multiply by it; confirm
        // which weight convention (variance or inverse variance) is intended
        WeightedMeasurement[] wm = problem.getMeasurements();
        double chiSquare = 0;
        for (int i = 0; i < wm.length; ++i) {
            double residual = wm[i].getResidual();
            chiSquare += residual * residual / wm[i].getWeight();
        }
        return chiSquare;
    }

    /**
     * Get the covariance matrix of the unbound estimated parameters.
     *
     * <p>The jacobian is recomputed from the measurements and parameters
     * recorded by the last call to
     * {@link #initializeEstimate(EstimationProblem)}, so this method expects
     * an estimation to have been run on the same problem beforehand. As the
     * jacobian only has one column per <em>unbound</em> parameter, the
     * returned matrix is square with one row and one column per unbound
     * parameter; bound parameters are not represented.</p>
     *
     * @param problem estimation problem
     * @return covariance matrix
     * @exception EstimationException if the covariance matrix
     * cannot be computed (singular problem)
     */
    public double[][] getCovariances(EstimationProblem problem)
        throws EstimationException {

        // set up the jacobian
        updateJacobian();

        // compute transpose(J).J, avoiding building big intermediate matrices;
        // the row stride must be the number of unbound parameters because this
        // is how the jacobian array is allocated and filled
        final int nRows = problem.getMeasurements().length;
        final int nCols = problem.getUnboundParameters().length;
        final int max   = nCols * nRows;
        double[][] jTj = new double[nCols][nCols];
        for (int i = 0; i < nCols; ++i) {
            for (int j = i; j < nCols; ++j) {
                double sum = 0;
                for (int k = 0; k < max; k += nCols) {
                    sum += jacobian[k + i] * jacobian[k + j];
                }
                // the matrix is symmetric, fill both triangles at once
                jTj[i][j] = sum;
                jTj[j][i] = sum;
            }
        }

        try {
            // compute the covariances matrix
            return new RealMatrixImpl(jTj).inverse().getData();
        } catch (InvalidMatrixException ime) {
            throw new EstimationException("unable to compute covariances: singular problem",
                                          new Object[0]);
        }

    }

    /**
     * Guess the errors in unbound estimated parameters.
     * <p>Guessing is covariance-based, it only gives rough order of magnitude.</p>
     *
     * <p>The returned array has one entry per unbound parameter, in the same
     * order as the rows of the matrix returned by
     * {@link #getCovariances(EstimationProblem)}.</p>
     *
     * @param problem estimation problem
     * @return errors in estimated parameters
     * @exception EstimationException if the covariances matrix cannot be computed
     * or the number of degrees of freedom is not positive (number of measurements
     * less than or equal to number of unbound parameters)
     */
    public double[] guessParametersErrors(EstimationProblem problem)
        throws EstimationException {

        // only unbound parameters are estimated, hence only they consume
        // degrees of freedom and only they appear in the covariance matrix
        final int m = problem.getMeasurements().length;
        final int p = problem.getUnboundParameters().length;
        if (m <= p) {
            throw new EstimationException("no degrees of freedom ({0} measurements, {1} parameters)",
                                          new Object[] { new Integer(m), new Integer(p)});
        }

        double[] errors = new double[p];
        final double c = Math.sqrt(getChiSquare(problem) / (m - p));
        double[][] covar = getCovariances(problem);
        for (int i = 0; i < errors.length; ++i) {
            errors[i] = Math.sqrt(covar[i][i]) * c;
        }
        return errors;

    }

    /**
     * Initialization of the common parts of the estimation.
     * <p>This method <em>must</em> be called at the start
     * of the {@link #estimate(EstimationProblem) estimate}
     * method.</p>
     *
     * <p>It resets the evaluation counters, records the measurements and the
     * unbound parameters of the problem, allocates the jacobian and residuals
     * arrays accordingly and sets the cost to positive infinity.</p>
     *
     * @param problem estimation problem to solve
     */
    protected void initializeEstimate(EstimationProblem problem) {

        // reset counters
        costEvaluations     = 0;
        jacobianEvaluations = 0;

        // retrieve the equations and the parameters
        measurements = problem.getMeasurements();
        parameters   = problem.getUnboundParameters();

        // arrays shared with the other private methods
        rows      = measurements.length;
        cols      = parameters.length;
        jacobian  = new double[rows * cols];
        residuals = new double[rows];

        cost = Double.POSITIVE_INFINITY;

    }

    /**
     * Solve an estimation problem.
     *
     * @param problem estimation problem to solve
     * @exception EstimationException if the problem cannot be solved
     */
    public abstract void estimate(EstimationProblem problem)
        throws EstimationException;

    /** Array of measurements. */
    protected WeightedMeasurement[] measurements;

    /** Array of (unbound) parameters. */
    protected EstimatedParameter[] parameters;

    /**
     * Jacobian matrix, stored row by row (one row per measurement, one
     * column per unbound parameter).
     * <p>This matrix is in canonical form just after the calls to
     * {@link #updateJacobian()}, but may be modified by the solver
     * in the derived class (the {@link LevenbergMarquardtEstimator
     * Levenberg-Marquardt estimator} does this).</p>
     */
    protected double[] jacobian;

    /** Number of columns of the jacobian matrix. */
    protected int cols;

    /** Number of rows of the jacobian matrix. */
    protected int rows;

    /** Residuals array.
     * <p>This array is in canonical form just after the calls to
     * {@link #updateResidualsAndCost()}, but may be modified by the solver
     * in the derived class (the {@link LevenbergMarquardtEstimator
     * Levenberg-Marquardt estimator} does this).</p>
     */
    protected double[] residuals;

    /** Cost value (square root of the weighted sum of the squared residuals). */
    protected double cost;

    /** Maximal allowed number of cost evaluations. */
    protected int maxCostEval;

    /** Number of cost evaluations. */
    protected int costEvaluations;

    /** Number of jacobian evaluations. */
    protected int jacobianEvaluations;

}

View File

@ -30,7 +30,7 @@ public class EstimationException
extends MathException { extends MathException {
/** Serializable version identifier. */ /** Serializable version identifier. */
private static final long serialVersionUID = -7414806622114810487L; private static final long serialVersionUID = -573038581493881337L;
/** /**
* Simple constructor. * Simple constructor.
@ -38,17 +38,8 @@ extends MathException {
* @param specifier format specifier (to be translated) * @param specifier format specifier (to be translated)
* @param parts to insert in the format (no translation) * @param parts to insert in the format (no translation)
*/ */
public EstimationException(String specifier, String[] parts) { public EstimationException(String specifier, Object[] parts) {
super(specifier, parts); super(specifier, parts);
} }
/**
* Simple constructor.
* Build an exception from a cause
* @param cause cause of this exception
*/
public EstimationException(Throwable cause) {
super(cause);
}
} }

View File

@ -60,10 +60,31 @@ public interface Estimator {
* criterion that is minimized by the estimator as follows: if * criterion that is minimized by the estimator as follows: if
* <em>c</em> is the criterion, and <em>n</em> is the number of * <em>c</em> is the criterion, and <em>n</em> is the number of
* measurements, then the RMS is <em>sqrt (c/n)</em>. * measurements, then the RMS is <em>sqrt (c/n)</em>.
* @see #guessParametersErrors(EstimationProblem)
* *
* @param problem estimation problem * @param problem estimation problem
* @return RMS value * @return RMS value
*/ */
public double getRMS(EstimationProblem problem); public double getRMS(EstimationProblem problem);
/**
* Get the covariance matrix of estimated parameters.
* @param problem estimation problem
* @return covariance matrix
* @exception EstimationException if the covariance matrix
* cannot be computed (singular problem)
*/
public double[][] getCovariances(EstimationProblem problem)
throws EstimationException;
/**
* Guess the errors in estimated parameters.
* @see #getRMS(EstimationProblem)
* @param problem estimation problem
* @return errors in estimated parameters
* @exception EstimationException if the error cannot be guessed
*/
public double[] guessParametersErrors(EstimationProblem problem)
throws EstimationException;
} }

View File

@ -34,206 +34,145 @@ import org.apache.commons.math.linear.RealMatrixImpl;
* *
*/ */
public class GaussNewtonEstimator public class GaussNewtonEstimator extends AbstractEstimator implements Serializable {
implements Estimator, Serializable {
/** /**
* Simple constructor. * Simple constructor.
* *
* <p>This constructor builds an estimator and stores its convergence * <p>This constructor builds an estimator and stores its convergence
* characteristics.</p> * characteristics.</p>
* *
* <p>An estimator is considered to have converged whenever either * <p>An estimator is considered to have converged whenever either
* the criterion goes below a physical threshold under which * the criterion goes below a physical threshold under which
* improvements are considered useless or when the algorithm is * improvements are considered useless or when the algorithm is
* unable to improve it (even if it is still high). The first * unable to improve it (even if it is still high). The first
* condition that is met stops the iterations.</p> * condition that is met stops the iterations.</p>
* *
* <p>The fact an estimator has converged does not mean that the * <p>The fact an estimator has converged does not mean that the
* model accurately fits the measurements. It only means no better * model accurately fits the measurements. It only means no better
* solution can be found, it does not mean this one is good. Such an * solution can be found, it does not mean this one is good. Such an
* analysis is left to the caller.</p> * analysis is left to the caller.</p>
* *
* <p>If neither conditions are fulfilled before a given number of * <p>If neither conditions are fulfilled before a given number of
* iterations, the algorithm is considered to have failed and an * iterations, the algorithm is considered to have failed and an
* {@link EstimationException} is thrown.</p> * {@link EstimationException} is thrown.</p>
* *
* @param maxIterations maximum number of iterations allowed * @param maxCostEval maximal number of cost evaluations allowed
* @param convergence criterion threshold below which we do not need * @param convergence criterion threshold below which we do not need
* to improve the criterion anymore * to improve the criterion anymore
* @param steadyStateThreshold steady state detection threshold, the * @param steadyStateThreshold steady state detection threshold, the
* problem has converged has reached a steady state if * problem has converged has reached a steady state if
* <code>Math.abs (Jn - Jn-1) < Jn * convergence</code>, where * <code>Math.abs (Jn - Jn-1) < Jn * convergence</code>, where
* <code>Jn</code> and <code>Jn-1</code> are the current and * <code>Jn</code> and <code>Jn-1</code> are the current and
* preceding criterion value (square sum of the weighted residuals * preceding criterion value (square sum of the weighted residuals
* of considered measurements). * of considered measurements).
*/ */
public GaussNewtonEstimator(int maxIterations, public GaussNewtonEstimator(int maxCostEval,
double convergence, double convergence,
double steadyStateThreshold) { double steadyStateThreshold) {
this.maxIterations = maxIterations; setMaxCostEval(maxCostEval);
this.steadyStateThreshold = steadyStateThreshold; this.steadyStateThreshold = steadyStateThreshold;
this.convergence = convergence; this.convergence = convergence;
} }
/** /**
* Solve an estimation problem using a least squares criterion. * Solve an estimation problem using a least squares criterion.
* *
* <p>This method set the unbound parameters of the given problem * <p>This method set the unbound parameters of the given problem
* starting from their current values through several iterations. At * starting from their current values through several iterations. At
* each step, the unbound parameters are changed in order to * each step, the unbound parameters are changed in order to
* minimize a weighted least square criterion based on the * minimize a weighted least square criterion based on the
* measurements of the problem.</p> * measurements of the problem.</p>
* *
* <p>The iterations are stopped either when the criterion goes * <p>The iterations are stopped either when the criterion goes
* below a physical threshold under which improvement are considered * below a physical threshold under which improvement are considered
* useless or when the algorithm is unable to improve it (even if it * useless or when the algorithm is unable to improve it (even if it
* is still high). The first condition that is met stops the * is still high). The first condition that is met stops the
* iterations. If the convergence it nos reached before the maximum * iterations. If the convergence it nos reached before the maximum
* number of iterations, an {@link EstimationException} is * number of iterations, an {@link EstimationException} is
* thrown.</p> * thrown.</p>
* *
* @param problem estimation problem to solve * @param problem estimation problem to solve
* @exception EstimationException if the problem cannot be solved * @exception EstimationException if the problem cannot be solved
* *
* @see EstimationProblem * @see EstimationProblem
* *
*/ */
public void estimate(EstimationProblem problem) public void estimate(EstimationProblem problem)
throws EstimationException {
int iterations = 0;
double previous = 0.0;
double current = 0.0;
// iterate until convergence is reached
do {
if (++iterations > maxIterations) {
throw new EstimationException ("unable to converge in {0} iterations",
new String[] {
Integer.toString(maxIterations)
});
}
// perform one iteration
linearEstimate(problem);
previous = current;
current = evaluateCriterion(problem);
} while ((iterations < 2)
|| (Math.abs(previous - current) > (current * steadyStateThreshold)
&& (Math.abs(current) > convergence)));
}
/**
* Estimate the solution of a linear least square problem.
*
* <p>The Gauss-Newton algorithm is iterative. Each iteration
* consists in solving a linearized least square problem. Several
* iterations are needed for general problems since the
* linearization is only an approximation of the problem
* behaviour. However, for linear problems one iteration is enough
* to get the solution. This method is provided in the public
* interface in order to handle more efficiently these linear
* problems.</p>
*
* @param problem estimation problem to solve
* @exception EstimationException if the problem cannot be solved
*
*/
public void linearEstimate(EstimationProblem problem)
throws EstimationException { throws EstimationException {
EstimatedParameter[] parameters = problem.getUnboundParameters(); initializeEstimate(problem);
WeightedMeasurement[] measurements = problem.getMeasurements();
// build the linear problem // work matrices
RealMatrix b = new RealMatrixImpl(parameters.length, 1); double[] grad = new double[parameters.length];
RealMatrix a = new RealMatrixImpl(parameters.length, parameters.length); RealMatrixImpl bDecrement = new RealMatrixImpl(parameters.length, 1);
double[] grad = new double[parameters.length]; double[][] bDecrementData = bDecrement.getDataRef();
RealMatrixImpl bDecrement = new RealMatrixImpl(parameters.length, 1); RealMatrixImpl wGradGradT = new RealMatrixImpl(parameters.length, parameters.length);
double[][] bDecrementData = bDecrement.getDataRef(); double[][] wggData = wGradGradT.getDataRef();
RealMatrixImpl wGradGradT = new RealMatrixImpl(parameters.length, parameters.length);
double[][] wggData = wGradGradT.getDataRef();
for (int i = 0; i < measurements.length; ++i) {
if (! measurements [i].isIgnored()) {
double weight = measurements[i].getWeight(); // iterate until convergence is reached
double residual = measurements[i].getResidual(); double previous = Double.POSITIVE_INFINITY;
do {
// compute the normal equation // build the linear problem
for (int j = 0; j < parameters.length; ++j) { ++jacobianEvaluations;
grad[j] = measurements[i].getPartial(parameters[j]); RealMatrix b = new RealMatrixImpl(parameters.length, 1);
bDecrementData[j][0] = weight * residual * grad[j]; RealMatrix a = new RealMatrixImpl(parameters.length, parameters.length);
} for (int i = 0; i < measurements.length; ++i) {
if (! measurements [i].isIgnored()) {
double weight = measurements[i].getWeight();
double residual = measurements[i].getResidual();
// compute the normal equation
for (int j = 0; j < parameters.length; ++j) {
grad[j] = measurements[i].getPartial(parameters[j]);
bDecrementData[j][0] = weight * residual * grad[j];
}
// build the contribution matrix for measurement i
for (int k = 0; k < parameters.length; ++k) {
double[] wggRow = wggData[k];
double gk = grad[k];
for (int l = 0; l < parameters.length; ++l) {
wggRow[l] = weight * gk * grad[l];
}
}
// update the matrices
a = a.add(wGradGradT);
b = b.add(bDecrement);
// build the contribution matrix for measurement i
for (int k = 0; k < parameters.length; ++k) {
double[] wggRow = wggData[k];
double gk = grad[k];
for (int l = 0; l < parameters.length; ++l) {
wggRow[l] = weight * gk * grad[l];
} }
} }
// update the matrices try {
a = a.add(wGradGradT);
b = b.add(bDecrement); // solve the linearized least squares problem
RealMatrix dX = a.solve(b);
// update the estimated parameters
for (int i = 0; i < parameters.length; ++i) {
parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i, 0));
}
} catch(InvalidMatrixException e) {
throw new EstimationException("unable to solve: singular problem", new Object[0]);
}
previous = cost;
updateResidualsAndCost();
} while ((getCostEvaluations() < 2)
|| (Math.abs(previous - cost) > (cost * steadyStateThreshold)
&& (Math.abs(cost) > convergence)));
}
} }
try { private double steadyStateThreshold;
private double convergence;
// solve the linearized least squares problem private static final long serialVersionUID = 5485001826076289109L;
RealMatrix dX = a.solve(b);
// update the estimated parameters
for (int i = 0; i < parameters.length; ++i) {
parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i, 0));
}
} catch(InvalidMatrixException e) {
throw new EstimationException(e);
}
}
private double evaluateCriterion(EstimationProblem problem) {
double criterion = 0.0;
WeightedMeasurement[] measurements = problem.getMeasurements();
for (int i = 0; i < measurements.length; ++i) {
double residual = measurements[i].getResidual();
criterion += measurements[i].getWeight() * residual * residual;
}
return criterion;
}
/**
* Get the Root Mean Square value.
* Get the Root Mean Square value, i.e. the root of the arithmetic
* mean of the square of all weighted residuals. This is related to the
* criterion that is minimized by the estimator as follows: if
* <em>c</em> if the criterion, and <em>n</em> is the number of
* measurements, then the RMS is <em>sqrt (c/n)</em>.
* @param problem estimation problem
* @return RMS value
*/
public double getRMS(EstimationProblem problem) {
double criterion = evaluateCriterion(problem);
int n = problem.getMeasurements().length;
return Math.sqrt(criterion / n);
}
private int maxIterations;
private double steadyStateThreshold;
private double convergence;
private static final long serialVersionUID = -7606628156644194170L;
} }

View File

@ -19,6 +19,7 @@ package org.apache.commons.math.estimation;
import java.io.Serializable; import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
/** /**
* This class solves a least squares problem. * This class solves a least squares problem.
* *
@ -92,7 +93,7 @@ import java.util.Arrays;
* @author Kenneth E. Hillstrom (original fortran) * @author Kenneth E. Hillstrom (original fortran)
* @author Jorge J. More (original fortran) * @author Jorge J. More (original fortran)
*/ */
public class LevenbergMarquardtEstimator implements Serializable, Estimator { public class LevenbergMarquardtEstimator extends AbstractEstimator implements Serializable {
/** /**
* Build an estimator for least squares problems. * Build an estimator for least squares problems.
@ -107,12 +108,16 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
* </p> * </p>
*/ */
public LevenbergMarquardtEstimator() { public LevenbergMarquardtEstimator() {
// set up the superclass with a default max cost evaluations setting
setMaxCostEval(1000);
// default values for the tuning parameters // default values for the tuning parameters
setInitialStepBoundFactor(100.0); setInitialStepBoundFactor(100.0);
setMaxCostEval(1000);
setCostRelativeTolerance(1.0e-10); setCostRelativeTolerance(1.0e-10);
setParRelativeTolerance(1.0e-10); setParRelativeTolerance(1.0e-10);
setOrthoTolerance(1.0e-10); setOrthoTolerance(1.0e-10);
} }
/** /**
@ -128,16 +133,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
this.initialStepBoundFactor = initialStepBoundFactor; this.initialStepBoundFactor = initialStepBoundFactor;
} }
/**
* Set the maximal number of cost evaluations.
*
* @param maxCostEval maximal number of cost evaluations
* @see #estimate
*/
public void setMaxCostEval(int maxCostEval) {
this.maxCostEval = maxCostEval;
}
/** /**
* Set the desired relative error in the sum of squares. * Set the desired relative error in the sum of squares.
* *
@ -170,75 +165,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
this.orthoTolerance = orthoTolerance; this.orthoTolerance = orthoTolerance;
} }
/**
* Get the number of cost evaluations.
*
* @return number of cost evaluations
* */
public int getCostEvaluations() {
return costEvaluations;
}
/**
* Get the number of jacobian evaluations.
*
* @return number of jacobian evaluations
* */
public int getJacobianEvaluations() {
return jacobianEvaluations;
}
/**
* Update the jacobian matrix.
*/
private void updateJacobian() {
++jacobianEvaluations;
Arrays.fill(jacobian, 0);
for (int i = 0, index = 0; i < rows; i++) {
WeightedMeasurement wm = measurements[i];
double factor = -Math.sqrt(wm.getWeight());
for (int j = 0; j < cols; ++j) {
jacobian[index++] = factor * wm.getPartial(parameters[j]);
}
}
}
/**
* Update the residuals array and cost function value.
*/
private void updateResidualsAndCost() {
++costEvaluations;
cost = 0;
for (int i = 0, index = 0; i < rows; i++, index += cols) {
WeightedMeasurement wm = measurements[i];
double residual = wm.getResidual();
residuals[i] = Math.sqrt(wm.getWeight()) * residual;
cost += wm.getWeight() * residual * residual;
}
cost = Math.sqrt(cost);
}
/**
* Get the Root Mean Square value.
* Get the Root Mean Square value, i.e. the root of the arithmetic
* mean of the square of all weighted residuals. This is related to the
* criterion that is minimized by the estimator as follows: if
* <em>c</em> if the criterion, and <em>n</em> is the number of
* measurements, then the RMS is <em>sqrt (c/n)</em>.
*
* @param problem estimation problem
* @return RMS value
*/
public double getRMS(EstimationProblem problem) {
WeightedMeasurement[] wm = problem.getMeasurements();
double criterion = 0;
for (int i = 0; i < wm.length; ++i) {
double residual = wm[i].getResidual();
criterion += wm[i].getWeight() * residual * residual;
}
return Math.sqrt(criterion / wm.length);
}
/** /**
* Solve an estimation problem using the Levenberg-Marquardt algorithm. * Solve an estimation problem using the Levenberg-Marquardt algorithm.
* <p>The algorithm used is a modified Levenberg-Marquardt one, based * <p>The algorithm used is a modified Levenberg-Marquardt one, based
@ -263,7 +189,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
* reached with the specified algorithm settings or if there are more variables * reached with the specified algorithm settings or if there are more variables
* than equations * than equations
* @see #setInitialStepBoundFactor * @see #setInitialStepBoundFactor
* @see #setMaxCostEval
* @see #setCostRelativeTolerance * @see #setCostRelativeTolerance
* @see #setParRelativeTolerance * @see #setParRelativeTolerance
* @see #setOrthoTolerance * @see #setOrthoTolerance
@ -271,21 +196,15 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
public void estimate(EstimationProblem problem) public void estimate(EstimationProblem problem)
throws EstimationException { throws EstimationException {
// retrieve the equations and the parameters initializeEstimate(problem);
measurements = problem.getMeasurements();
parameters = problem.getUnboundParameters();
// arrays shared with the other private methods // arrays shared with the other private methods
rows = measurements.length;
cols = parameters.length;
solvedCols = Math.min(rows, cols); solvedCols = Math.min(rows, cols);
jacobian = new double[rows * cols];
diagR = new double[cols]; diagR = new double[cols];
jacNorm = new double[cols]; jacNorm = new double[cols];
beta = new double[cols]; beta = new double[cols];
permutation = new int[cols]; permutation = new int[cols];
lmDir = new double[cols]; lmDir = new double[cols];
residuals = new double[rows];
// local variables // local variables
double delta = 0, xNorm = 0; double delta = 0, xNorm = 0;
@ -300,11 +219,9 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
updateResidualsAndCost(); updateResidualsAndCost();
// outer loop // outer loop
lmPar = 0; lmPar = 0;
costEvaluations = 0;
jacobianEvaluations = 0;
boolean firstIteration = true; boolean firstIteration = true;
while (costEvaluations < maxCostEval) { while (true) {
// compute the Q.R. decomposition of the jacobian matrix // compute the Q.R. decomposition of the jacobian matrix
updateJacobian(); updateJacobian();
@ -477,42 +394,28 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
// tests for termination and stringent tolerances // tests for termination and stringent tolerances
// (2.2204e-16 is the machine epsilon for IEEE754) // (2.2204e-16 is the machine epsilon for IEEE754)
if (costEvaluations >= maxCostEval) {
break;
}
if ((Math.abs(actRed) <= 2.2204e-16) if ((Math.abs(actRed) <= 2.2204e-16)
&& (preRed <= 2.2204e-16) && (preRed <= 2.2204e-16)
&& (ratio <= 2.0)) { && (ratio <= 2.0)) {
throw new EstimationException("cost relative tolerance is too small ({0})," throw new EstimationException("cost relative tolerance is too small ({0}),"
+ " no further reduction in the" + " no further reduction in the"
+ " sum of squares is possible", + " sum of squares is possible",
new String[] { new Object[] { new Double(costRelativeTolerance) });
Double.toString(costRelativeTolerance)
});
} else if (delta <= 2.2204e-16 * xNorm) { } else if (delta <= 2.2204e-16 * xNorm) {
throw new EstimationException("parameters relative tolerance is too small" throw new EstimationException("parameters relative tolerance is too small"
+ " ({0}), no further improvement in" + " ({0}), no further improvement in"
+ " the approximate solution is possible", + " the approximate solution is possible",
new String[] { new Object[] { new Double(parRelativeTolerance) });
Double.toString(parRelativeTolerance)
});
} else if (maxCosine <= 2.2204e-16) { } else if (maxCosine <= 2.2204e-16) {
throw new EstimationException("orthogonality tolerance is too small ({0})," throw new EstimationException("orthogonality tolerance is too small ({0}),"
+ " solution is orthogonal to the jacobian", + " solution is orthogonal to the jacobian",
new String[] { new Object[] { new Double(orthoTolerance) });
Double.toString(orthoTolerance)
});
} }
} }
} }
throw new EstimationException("maximal number of evaluations exceeded ({0})",
new String[] {
Integer.toString(maxCostEval)
});
} }
/** /**
@ -919,29 +822,9 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
} }
} }
/** Array of measurements. */
private WeightedMeasurement[] measurements;
/** Array of parameters. */
private EstimatedParameter[] parameters;
/**
* Jacobian matrix.
* <p>Depending on the computation phase, this matrix is either in
* canonical form (just after the calls to updateJacobian) or in
* Q.R. decomposed form (after calls to qrDecomposition)</p>
*/
private double[] jacobian;
/** Number of columns of the jacobian matrix. */
private int cols;
/** Number of solved variables. */ /** Number of solved variables. */
private int solvedCols; private int solvedCols;
/** Number of rows of the jacobian matrix. */
private int rows;
/** Diagonal elements of the R matrix in the Q.R. decomposition. */ /** Diagonal elements of the R matrix in the Q.R. decomposition. */
private double[] diagR; private double[] diagR;
@ -963,28 +846,9 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
/** Parameters evolution direction associated with lmPar. */ /** Parameters evolution direction associated with lmPar. */
private double[] lmDir; private double[] lmDir;
/** Residuals array.
* <p>Depending on the computation phase, this array is either in
* canonical form (just after the calls to updateResiduals) or
* premultiplied by Qt (just after the calls to qTy)</p>
*/
private double[] residuals;
/** Cost value (square root of the sum of the squared residuals). */
private double cost;
/** Positive input variable used in determining the initial step bound. */ /** Positive input variable used in determining the initial step bound. */
private double initialStepBoundFactor; private double initialStepBoundFactor;
/** Maximal number of cost evaluations. */
private int maxCostEval;
/** Number of cost evaluations. */
private int costEvaluations;
/** Number of jacobian evaluations. */
private int jacobianEvaluations;
/** Desired relative error in the sum of squares. */ /** Desired relative error in the sum of squares. */
private double costRelativeTolerance; private double costRelativeTolerance;
@ -995,6 +859,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
* and the columns of the jacobian. */ * and the columns of the jacobian. */
private double orthoTolerance; private double orthoTolerance;
private static final long serialVersionUID = 5387476316105068340L; private static final long serialVersionUID = -5705952631533171019L;
} }

View File

@ -110,6 +110,14 @@ public class LevenbergMarquardtEstimatorTest
LevenbergMarquardtEstimator estimator = new LevenbergMarquardtEstimator(); LevenbergMarquardtEstimator estimator = new LevenbergMarquardtEstimator();
estimator.estimate(problem); estimator.estimate(problem);
assertEquals(0, estimator.getRMS(problem), 1.0e-10); assertEquals(0, estimator.getRMS(problem), 1.0e-10);
try {
estimator.guessParametersErrors(problem);
fail("an exception should have been thrown");
} catch (EstimationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
assertEquals(1.5, assertEquals(1.5,
problem.getUnboundParameters()[0].getEstimate(), problem.getUnboundParameters()[0].getEstimate(),
1.0e-10); 1.0e-10);
@ -267,7 +275,15 @@ public class LevenbergMarquardtEstimatorTest
estimator.estimate(problem); estimator.estimate(problem);
assertTrue(estimator.getRMS(problem) < initialCost); assertTrue(estimator.getRMS(problem) < initialCost);
assertTrue(Math.sqrt(m.length) * estimator.getRMS(problem) > 0.6); assertTrue(Math.sqrt(m.length) * estimator.getRMS(problem) > 0.6);
double dJ0 = 2 * (m[0].getResidual() * m[0].getPartial(p[0]) try {
estimator.getCovariances(problem);
fail("an exception should have been thrown");
} catch (EstimationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
double dJ0 = 2 * (m[0].getResidual() * m[0].getPartial(p[0])
+ m[1].getResidual() * m[1].getPartial(p[0]) + m[1].getResidual() * m[1].getPartial(p[0])
+ m[2].getResidual() * m[2].getPartial(p[0])); + m[2].getResidual() * m[2].getPartial(p[0]));
double dJ1 = 2 * (m[0].getResidual() * m[0].getPartial(p[1]) double dJ1 = 2 * (m[0].getResidual() * m[0].getPartial(p[1])
@ -496,7 +512,34 @@ public class LevenbergMarquardtEstimatorTest
assertEquals(69.96016176931406, circle.getRadius(), 1.0e-10); assertEquals(69.96016176931406, circle.getRadius(), 1.0e-10);
assertEquals(96.07590211815305, circle.getX(), 1.0e-10); assertEquals(96.07590211815305, circle.getX(), 1.0e-10);
assertEquals(48.13516790438953, circle.getY(), 1.0e-10); assertEquals(48.13516790438953, circle.getY(), 1.0e-10);
} double[][] cov = estimator.getCovariances(circle);
assertEquals(1.839, cov[0][0], 0.001);
assertEquals(0.731, cov[0][1], 0.001);
assertEquals(cov[0][1], cov[1][0], 1.0e-14);
assertEquals(0.786, cov[1][1], 0.001);
double[] errors = estimator.guessParametersErrors(circle);
assertEquals(1.384, errors[0], 0.001);
assertEquals(0.905, errors[1], 0.001);
// add perfect measurements and check errors are reduced
double cx = circle.getX();
double cy = circle.getY();
double r = circle.getRadius();
for (double d= 0; d < 2 * Math.PI; d += 0.01) {
circle.addPoint(cx + r * Math.cos(d), cy + r * Math.sin(d));
}
estimator = new LevenbergMarquardtEstimator();
estimator.estimate(circle);
cov = estimator.getCovariances(circle);
assertEquals(0.004, cov[0][0], 0.001);
assertEquals(6.40e-7, cov[0][1], 1.0e-9);
assertEquals(cov[0][1], cov[1][0], 1.0e-14);
assertEquals(0.003, cov[1][1], 0.001);
errors = estimator.guessParametersErrors(circle);
assertEquals(0.004, errors[0], 0.001);
assertEquals(0.004, errors[1], 0.001);
}
public void testCircleFittingBadInit() throws EstimationException { public void testCircleFittingBadInit() throws EstimationException {
Circle circle = new Circle(-12, -12); Circle circle = new Circle(-12, -12);

View File

@ -124,6 +124,9 @@ Commons Math Release Notes</title>
<action dev="luc" type="fix" issue="MATH-164"> <action dev="luc" type="fix" issue="MATH-164">
Handle multiplication of Complex numbers with infinite parts specially. Handle multiplication of Complex numbers with infinite parts specially.
</action> </action>
<action dev="luc" type="update" issue="MATH-176" due-to="Kazuhiro Koshino">
Add methods to guess the errors on estimated parameters in least-squares estimators.
</action>
</release> </release>
<release version="1.1" date="2005-12-17" <release version="1.1" date="2005-12-17"
description="This is a maintenance release containing bug fixes and enhancements. description="This is a maintenance release containing bug fixes and enhancements.