From 376834acf1c807d231a40b76fffe7d92ed3b4c17 Mon Sep 17 00:00:00 2001
From: Luc Maisonobe
Date: Sat, 19 Jan 2008 22:48:16 +0000
Subject: [PATCH] Added general methods to guess errors on estimated parameters
JIRA: MATH-176
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@613474 13f79535-47bb-0310-9956-ffa450edef68
---
.../commons/math/MessagesResources_fr.java | 14 +-
.../math/estimation/AbstractEstimator.java | 274 +++++++++++++++
.../math/estimation/EstimationException.java | 13 +-
.../commons/math/estimation/Estimator.java | 23 +-
.../math/estimation/GaussNewtonEstimator.java | 311 +++++++-----------
.../LevenbergMarquardtEstimator.java | 164 +--------
.../LevenbergMarquardtEstimatorTest.java | 47 ++-
xdocs/changes.xml | 3 +
8 files changed, 495 insertions(+), 354 deletions(-)
create mode 100644 src/java/org/apache/commons/math/estimation/AbstractEstimator.java
diff --git a/src/java/org/apache/commons/math/MessagesResources_fr.java b/src/java/org/apache/commons/math/MessagesResources_fr.java
index 4344c9537..c75ad70ac 100644
--- a/src/java/org/apache/commons/math/MessagesResources_fr.java
+++ b/src/java/org/apache/commons/math/MessagesResources_fr.java
@@ -87,9 +87,17 @@ public class MessagesResources_fr
{ "Conversion Exception in Transformation: {0}",
"Exception de conversion dans une transformation : {0}" },
+ // org.apache.commons.math.estimation.AbstractEstimator
+ { "maximal number of evaluations exceeded ({0})",
+ "nombre maximal d''\u00e9valuations d\u00e9pass\u00e9 ({0})" },
+ { "unable to compute covariances: singular problem",
+ "impossible de calculer les covariances : probl\u00e8me singulier"},
+ { "no degrees of freedom ({0} measurements, {1} parameters)",
+ "aucun degr\u00e9 de libert\u00e9 ({0} mesures, {1} param\u00e8tres)" },
+
// org.apache.commons.math.estimation.GaussNewtonEstimator
- { "unable to converge in {0} iterations",
- "pas de convergence apr\u00e8s {0} it\u00e9rations" },
+ { "unable to solve: singular problem",
+ "r\u00e9solution impossible : probl\u00e8me singulier" },
// org.apache.commons.math.estimation.LevenbergMarquardtEstimator
{ "cost relative tolerance is too small ({0}), no further reduction in the sum of squares is possible",
@@ -98,8 +106,6 @@ public class MessagesResources_fr
"trop petite tol\u00e9rance relative sur les param\u00e8tres ({0}), aucune am\u00e9lioration de la solution approximative n''est possible" },
{ "orthogonality tolerance is too small ({0}), solution is orthogonal to the jacobian",
"trop petite tol\u00e9rance sur l''orthogonalit\u00e9 ({0}), la solution est orthogonale \u00e0 la jacobienne" },
- { "maximal number of evaluations exceeded ({0})",
- "nombre maximal d''\u00e9valuations d\u00e9pass\u00e9 ({0})" },
// org.apache.commons.math.geometry.CardanEulerSingularityException
{ "Cardan angles singularity",
diff --git a/src/java/org/apache/commons/math/estimation/AbstractEstimator.java b/src/java/org/apache/commons/math/estimation/AbstractEstimator.java
new file mode 100644
index 000000000..db23cba0b
--- /dev/null
+++ b/src/java/org/apache/commons/math/estimation/AbstractEstimator.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math.estimation;
+
+import java.util.Arrays;
+
+import org.apache.commons.math.linear.InvalidMatrixException;
+import org.apache.commons.math.linear.RealMatrixImpl;
+
+public abstract class AbstractEstimator implements Estimator {
+
+ /**
+ * Build an abstract estimator; the evaluations limit is 0 until {@link #setMaxCostEval} is called.
+ */
+ protected AbstractEstimator() {
+ }
+
+ /**
+ * Set the maximal number of cost evaluations allowed.
+ * Going past it makes {@link #updateResidualsAndCost()} throw an {@link EstimationException}.
+ * @param maxCostEval maximal number of cost evaluations allowed
+ * @see #estimate
+ */
+ public void setMaxCostEval(int maxCostEval) {
+ this.maxCostEval = maxCostEval;
+ }
+
+ /**
+ * Get the number of cost evaluations counted since {@link #initializeEstimate} was last called.
+ *
+ * @return number of cost evaluations
+ */
+ public int getCostEvaluations() {
+ return costEvaluations;
+ }
+
+ /**
+ * Get the number of jacobian evaluations counted since {@link #initializeEstimate} was last called.
+ *
+ * @return number of jacobian evaluations
+ */
+ public int getJacobianEvaluations() {
+ return jacobianEvaluations;
+ }
+
+ /**
+ * Recompute the jacobian, row by row: entry (i, j) is -sqrt(weight i) times the partial of measurement i with respect to parameter j.
+ */
+ protected void updateJacobian() {
+ ++jacobianEvaluations;
+ Arrays.fill(jacobian, 0);
+ for (int row = 0, offset = 0; row < rows; ++row) {
+ final WeightedMeasurement measurement = measurements[row];
+ final double scale = -Math.sqrt(measurement.getWeight());
+ for (int col = 0; col < cols; ++col, ++offset) {
+ jacobian[offset] = scale * measurement.getPartial(parameters[col]);
+ }
+ }
+ }
+
+ /**
+ * Update the residuals array and cost function value.
+ * Each stored residual is scaled by the square root of its weight; the cost
+ * is the square root of the weighted sum of the squared residuals.
+ * @exception EstimationException if the number of cost evaluations
+ * exceeds the maximum allowed
+ */
+ protected void updateResidualsAndCost()
+ throws EstimationException {
+
+ // the counter is bumped before the check, so the failing call is counted too
+ if (++costEvaluations > maxCostEval) {
+ throw new EstimationException("maximal number of evaluations exceeded ({0})",
+ new String[] { Integer.toString(maxCostEval) });
+ }
+
+ double sumOfSquares = 0;
+ for (int i = 0; i < rows; ++i) {
+ final WeightedMeasurement wm = measurements[i];
+ final double residual = wm.getResidual();
+ residuals[i] = Math.sqrt(wm.getWeight()) * residual;
+ sumOfSquares += wm.getWeight() * residual * residual;
+ }
+ cost = Math.sqrt(sumOfSquares);
+ }
+
+ /**
+ * Get the Root Mean Square value.
+ * The RMS is the square root of the arithmetic mean of the squares
+ * of all weighted residuals. It is related to the criterion that is
+ * minimized by the estimator as follows: if c is the criterion
+ * and n is the number of measurements, then the RMS is
+ * sqrt (c/n).
+ *
+ * @param problem estimation problem
+ * @return RMS value
+ */
+ public double getRMS(EstimationProblem problem) {
+ final WeightedMeasurement[] all = problem.getMeasurements();
+ double weightedSquares = 0;
+ for (int k = 0; k < all.length; ++k) {
+ final double r = all[k].getResidual();
+ weightedSquares += all[k].getWeight() * r * r;
+ }
+ return Math.sqrt(weightedSquares / all.length);
+ }
+
+ /**
+ * Get the Chi-Square value, summing each squared residual divided by its weight.
+ * @param problem estimation problem
+ * @return chi-square value (NOTE(review): divides by the weight; confirm, 1/sigma^2 weights would need a product)
+ */
+ public double getChiSquare(EstimationProblem problem) {
+ final WeightedMeasurement[] all = problem.getMeasurements();
+ double sum = 0;
+ for (int k = 0; k < all.length; ++k) {
+ final double r = all[k].getResidual();
+ sum += r * r / all[k].getWeight();
+ }
+ return sum;
+ }
+
+ /**
+ * Get the covariance matrix of unbound estimated parameters.
+ * {@link #estimate(EstimationProblem) estimate} must have been run on this problem first.
+ * @param problem estimation problem
+ * @return covariance matrix, one row and one column per unbound parameter
+ * @exception EstimationException if the covariance matrix
+ * cannot be computed (singular problem)
+ */
+ public double[][] getCovariances(EstimationProblem problem)
+ throws EstimationException {
+
+ // set up the jacobian
+ updateJacobian();
+
+ // compute transpose(J).J; J has one column per unbound parameter only
+ final int n = problem.getMeasurements().length;
+ final int m = problem.getUnboundParameters().length;
+ final int max = m * n;
+ double[][] jTj = new double[m][m];
+ for (int i = 0; i < m; ++i) {
+ for (int j = i; j < m; ++j) {
+ double sum = 0;
+ for (int k = 0; k < max; k += m) {
+ sum += jacobian[k + i] * jacobian[k + j];
+ }
+ jTj[i][j] = sum;
+ jTj[j][i] = sum;
+ }
+ }
+
+ try {
+ // compute the covariances matrix
+ return new RealMatrixImpl(jTj).inverse().getData();
+ } catch (InvalidMatrixException ime) {
+ throw new EstimationException("unable to compute covariances: singular problem",
+ new Object[0]);
+ }
+ }
+
+ /**
+ * Guess the errors in unbound estimated parameters.
+ * Guessing is covariance-based, it only gives rough order of magnitude.
+ * @param problem estimation problem
+ * @return errors in estimated parameters, one entry per row of the covariance matrix
+ * @exception EstimationException if the covariances matrix cannot be computed
+ * or the number of degrees of freedom is not positive (number of measurements
+ * lesser or equal to number of unbound parameters)
+ */
+ public double[] guessParametersErrors(EstimationProblem problem)
+ throws EstimationException {
+ int m = problem.getMeasurements().length;
+ int p = problem.getUnboundParameters().length;
+ if (m <= p) {
+ throw new EstimationException("no degrees of freedom ({0} measurements, {1} parameters)",
+ new Object[] { new Integer(m), new Integer(p)});
+ }
+ final double c = Math.sqrt(getChiSquare(problem) / (m - p));
+ double[][] covar = getCovariances(problem);
+ double[] errors = new double[covar.length];
+ for (int i = 0; i < errors.length; ++i) {
+ errors[i] = Math.sqrt(covar[i][i]) * c;
+ }
+ return errors;
+ }
+
+ /**
+ * Reset the state shared by all estimators before a new estimation.
+ * Concrete estimators must call this method first thing in their
+ * {@link #estimate(EstimationProblem) estimate} method, it clears the
+ * counters and sizes the work arrays for the given problem.
+ * @param problem estimation problem to solve
+ */
+ protected void initializeEstimate(EstimationProblem problem) {
+
+ // start counting evaluations from scratch
+ this.costEvaluations = 0;
+ this.jacobianEvaluations = 0;
+
+ // measurements and unbound parameters of the problem
+ this.measurements = problem.getMeasurements();
+ this.parameters = problem.getUnboundParameters();
+
+ // dimensions and work arrays used by the derived classes
+ this.rows = this.measurements.length;
+ this.cols = this.parameters.length;
+ this.jacobian = new double[this.rows * this.cols];
+ this.residuals = new double[this.rows];
+
+ // no cost has been evaluated yet
+ this.cost = Double.POSITIVE_INFINITY;
+ }
+ /** Solve the given estimation problem; the algorithm is supplied by each concrete estimator. */
+ public abstract void estimate(EstimationProblem problem)
+ throws EstimationException;
+
+ /** Array of measurements of the problem being estimated. */
+ protected WeightedMeasurement[] measurements;
+
+ /** Array of unbound parameters of the problem being estimated. */
+ protected EstimatedParameter[] parameters;
+
+ /**
+ * Jacobian matrix, stored row by row (one row per measurement).
+ * This matrix is in canonical form just after the calls to
+ * {@link #updateJacobian()}, but may be modified by the solver
+ * in the derived class (the {@link LevenbergMarquardtEstimator
+ * Levenberg-Marquardt estimator} does this).
+ */
+ protected double[] jacobian;
+
+ /** Number of columns of the jacobian matrix (number of unbound parameters). */
+ protected int cols;
+
+ /** Number of rows of the jacobian matrix (number of measurements). */
+ protected int rows;
+
+ /** Residuals array (each residual scaled by the square root of its weight).
+ * This array is in canonical form just after the calls to
+ * {@link #updateResidualsAndCost()}, but may be modified by the solver
+ * in the derived class (the {@link LevenbergMarquardtEstimator
+ * Levenberg-Marquardt estimator} does this).
+ */
+ protected double[] residuals;
+
+ /** Cost value (square root of the weighted sum of the squared residuals). */
+ protected double cost;
+
+ /** Maximal allowed number of cost evaluations. */
+ protected int maxCostEval;
+
+ /** Number of cost evaluations. */
+ protected int costEvaluations;
+
+ /** Number of jacobian evaluations. */
+ protected int jacobianEvaluations;
+
+}
\ No newline at end of file
diff --git a/src/java/org/apache/commons/math/estimation/EstimationException.java b/src/java/org/apache/commons/math/estimation/EstimationException.java
index 7eba1de7c..9399e8d17 100644
--- a/src/java/org/apache/commons/math/estimation/EstimationException.java
+++ b/src/java/org/apache/commons/math/estimation/EstimationException.java
@@ -30,7 +30,7 @@ public class EstimationException
extends MathException {
/** Serializable version identifier. */
- private static final long serialVersionUID = -7414806622114810487L;
+ private static final long serialVersionUID = -573038581493881337L;
/**
* Simple constructor.
@@ -38,17 +38,8 @@ extends MathException {
* @param specifier format specifier (to be translated)
* @param parts to insert in the format (no translation)
*/
- public EstimationException(String specifier, String[] parts) {
+ public EstimationException(String specifier, Object[] parts) {
super(specifier, parts);
}
- /**
- * Simple constructor.
- * Build an exception from a cause
- * @param cause cause of this exception
- */
- public EstimationException(Throwable cause) {
- super(cause);
- }
-
}
diff --git a/src/java/org/apache/commons/math/estimation/Estimator.java b/src/java/org/apache/commons/math/estimation/Estimator.java
index 5d8b4f07f..0dacd8e5d 100644
--- a/src/java/org/apache/commons/math/estimation/Estimator.java
+++ b/src/java/org/apache/commons/math/estimation/Estimator.java
@@ -60,10 +60,31 @@ public interface Estimator {
* criterion that is minimized by the estimator as follows: if
* c is the criterion, and n is the number of
* measurements, then the RMS is sqrt (c/n).
+ * @see #guessParametersErrors(EstimationProblem)
*
* @param problem estimation problem
* @return RMS value
*/
public double getRMS(EstimationProblem problem);
-
+
+ /**
+ * Get the covariance matrix of estimated parameters.
+ * @param problem estimation problem
+ * @return covariance matrix, as a square array
+ * @exception EstimationException if the covariance matrix
+ * cannot be computed (singular problem)
+ */
+ public double[][] getCovariances(EstimationProblem problem)
+ throws EstimationException;
+
+ /**
+ * Guess the errors in estimated parameters (one array entry per parameter).
+ * @param problem estimation problem
+ * @return errors in estimated parameters
+ * @exception EstimationException if the error cannot be guessed
+ * @see #getRMS(EstimationProblem)
+ */
+ public double[] guessParametersErrors(EstimationProblem problem)
+ throws EstimationException;
+
}
diff --git a/src/java/org/apache/commons/math/estimation/GaussNewtonEstimator.java b/src/java/org/apache/commons/math/estimation/GaussNewtonEstimator.java
index 3be1c4e95..80c0cba04 100644
--- a/src/java/org/apache/commons/math/estimation/GaussNewtonEstimator.java
+++ b/src/java/org/apache/commons/math/estimation/GaussNewtonEstimator.java
@@ -34,206 +34,145 @@ import org.apache.commons.math.linear.RealMatrixImpl;
*
*/
-public class GaussNewtonEstimator
- implements Estimator, Serializable {
+public class GaussNewtonEstimator extends AbstractEstimator implements Serializable {
- /**
- * Simple constructor.
- *
- * This constructor builds an estimator and stores its convergence
- * characteristics.
- *
- * An estimator is considered to have converged whenever either
- * the criterion goes below a physical threshold under which
- * improvements are considered useless or when the algorithm is
- * unable to improve it (even if it is still high). The first
- * condition that is met stops the iterations.
- *
- * The fact an estimator has converged does not mean that the
- * model accurately fits the measurements. It only means no better
- * solution can be found, it does not mean this one is good. Such an
- * analysis is left to the caller.
- *
- * If neither conditions are fulfilled before a given number of
- * iterations, the algorithm is considered to have failed and an
- * {@link EstimationException} is thrown.
- *
- * @param maxIterations maximum number of iterations allowed
- * @param convergence criterion threshold below which we do not need
- * to improve the criterion anymore
- * @param steadyStateThreshold steady state detection threshold, the
- * problem has converged has reached a steady state if
- * <code>Math.abs (Jn - Jn-1) &lt; Jn * convergence</code>, where
- * <code>Jn</code> and <code>Jn-1</code> are the current and
- * preceding criterion value (square sum of the weighted residuals
- * of considered measurements).
- */
- public GaussNewtonEstimator(int maxIterations,
- double convergence,
- double steadyStateThreshold) {
- this.maxIterations = maxIterations;
- this.steadyStateThreshold = steadyStateThreshold;
- this.convergence = convergence;
- }
+ /**
+ * Simple constructor.
+ *
+ * This constructor builds an estimator and stores its convergence
+ * characteristics.
+ *
+ * An estimator is considered to have converged whenever either
+ * the criterion goes below a physical threshold under which
+ * improvements are considered useless or when the algorithm is
+ * unable to improve it (even if it is still high). The first
+ * condition that is met stops the iterations.
+ *
+ * The fact an estimator has converged does not mean that the
+ * model accurately fits the measurements. It only means no better
+ * solution can be found, it does not mean this one is good. Such an
+ * analysis is left to the caller.
+ *
+ * If neither condition is fulfilled before the maximal number of
+ * cost evaluations is exceeded, the algorithm is considered to have failed and an
+ * {@link EstimationException} is thrown.
+ *
+ * @param maxCostEval maximal number of cost evaluations allowed
+ * @param convergence criterion threshold below which we do not need
+ * to improve the criterion anymore
+ * @param steadyStateThreshold steady state detection threshold, the
+ * problem is considered to have reached a steady state if
+ * <code>Math.abs (Jn - Jn-1) &lt;= Jn * steadyStateThreshold</code>, where
+ * <code>Jn</code> and <code>Jn-1</code> are the current and
+ * preceding cost values (square root of the weighted sum of the squared
+ * residuals of the measurements).
+ */
+ public GaussNewtonEstimator(int maxCostEval,
+ double convergence,
+ double steadyStateThreshold) {
+ setMaxCostEval(maxCostEval);
+ this.steadyStateThreshold = steadyStateThreshold;
+ this.convergence = convergence;
+ }
- /**
- * Solve an estimation problem using a least squares criterion.
- *
- * This method set the unbound parameters of the given problem
- * starting from their current values through several iterations. At
- * each step, the unbound parameters are changed in order to
- * minimize a weighted least square criterion based on the
- * measurements of the problem.
- *
- * The iterations are stopped either when the criterion goes
- * below a physical threshold under which improvement are considered
- * useless or when the algorithm is unable to improve it (even if it
- * is still high). The first condition that is met stops the
- * iterations. If the convergence it nos reached before the maximum
- * number of iterations, an {@link EstimationException} is
- * thrown.
- *
- * @param problem estimation problem to solve
- * @exception EstimationException if the problem cannot be solved
- *
- * @see EstimationProblem
- *
- */
- public void estimate(EstimationProblem problem)
- throws EstimationException {
- int iterations = 0;
- double previous = 0.0;
- double current = 0.0;
-
- // iterate until convergence is reached
- do {
-
- if (++iterations > maxIterations) {
- throw new EstimationException ("unable to converge in {0} iterations",
- new String[] {
- Integer.toString(maxIterations)
- });
- }
-
- // perform one iteration
- linearEstimate(problem);
-
- previous = current;
- current = evaluateCriterion(problem);
-
- } while ((iterations < 2)
- || (Math.abs(previous - current) > (current * steadyStateThreshold)
- && (Math.abs(current) > convergence)));
-
- }
-
- /**
- * Estimate the solution of a linear least square problem.
- *
- * The Gauss-Newton algorithm is iterative. Each iteration
- * consists in solving a linearized least square problem. Several
- * iterations are needed for general problems since the
- * linearization is only an approximation of the problem
- * behaviour. However, for linear problems one iteration is enough
- * to get the solution. This method is provided in the public
- * interface in order to handle more efficiently these linear
- * problems.
- *
- * @param problem estimation problem to solve
- * @exception EstimationException if the problem cannot be solved
- *
- */
- public void linearEstimate(EstimationProblem problem)
+ /**
+ * Solve an estimation problem using a least squares criterion.
+ *
+ * This method sets the unbound parameters of the given problem
+ * starting from their current values through several iterations. At
+ * each step, the unbound parameters are changed in order to
+ * minimize a weighted least square criterion based on the
+ * measurements of the problem.
+ *
+ * The iterations are stopped either when the criterion goes
+ * below a physical threshold under which improvement are considered
+ * useless or when the algorithm is unable to improve it (even if it
+ * is still high). The first condition that is met stops the
+ * iterations. If the convergence is not reached before the maximum
+ * number of cost evaluations, an {@link EstimationException} is
+ * thrown.
+ *
+ * @param problem estimation problem to solve
+ * @exception EstimationException if the problem cannot be solved
+ *
+ * @see EstimationProblem
+ *
+ */
+ public void estimate(EstimationProblem problem)
throws EstimationException {
- EstimatedParameter[] parameters = problem.getUnboundParameters();
- WeightedMeasurement[] measurements = problem.getMeasurements();
+ initializeEstimate(problem);
- // build the linear problem
- RealMatrix b = new RealMatrixImpl(parameters.length, 1);
- RealMatrix a = new RealMatrixImpl(parameters.length, parameters.length);
- double[] grad = new double[parameters.length];
- RealMatrixImpl bDecrement = new RealMatrixImpl(parameters.length, 1);
- double[][] bDecrementData = bDecrement.getDataRef();
- RealMatrixImpl wGradGradT = new RealMatrixImpl(parameters.length, parameters.length);
- double[][] wggData = wGradGradT.getDataRef();
- for (int i = 0; i < measurements.length; ++i) {
- if (! measurements [i].isIgnored()) {
+ // work matrices
+ double[] grad = new double[parameters.length];
+ RealMatrixImpl bDecrement = new RealMatrixImpl(parameters.length, 1);
+ double[][] bDecrementData = bDecrement.getDataRef();
+ RealMatrixImpl wGradGradT = new RealMatrixImpl(parameters.length, parameters.length);
+ double[][] wggData = wGradGradT.getDataRef();
- double weight = measurements[i].getWeight();
- double residual = measurements[i].getResidual();
+ // iterate until convergence is reached
+ double previous = Double.POSITIVE_INFINITY;
+ do {
- // compute the normal equation
- for (int j = 0; j < parameters.length; ++j) {
- grad[j] = measurements[i].getPartial(parameters[j]);
- bDecrementData[j][0] = weight * residual * grad[j];
- }
+ // build the linear problem
+ ++jacobianEvaluations;
+ RealMatrix b = new RealMatrixImpl(parameters.length, 1);
+ RealMatrix a = new RealMatrixImpl(parameters.length, parameters.length);
+ for (int i = 0; i < measurements.length; ++i) {
+ if (! measurements [i].isIgnored()) {
+
+ double weight = measurements[i].getWeight();
+ double residual = measurements[i].getResidual();
+
+ // compute the normal equation
+ for (int j = 0; j < parameters.length; ++j) {
+ grad[j] = measurements[i].getPartial(parameters[j]);
+ bDecrementData[j][0] = weight * residual * grad[j];
+ }
+
+ // build the contribution matrix for measurement i
+ for (int k = 0; k < parameters.length; ++k) {
+ double[] wggRow = wggData[k];
+ double gk = grad[k];
+ for (int l = 0; l < parameters.length; ++l) {
+ wggRow[l] = weight * gk * grad[l];
+ }
+ }
+
+ // update the matrices
+ a = a.add(wGradGradT);
+ b = b.add(bDecrement);
- // build the contribution matrix for measurement i
- for (int k = 0; k < parameters.length; ++k) {
- double[] wggRow = wggData[k];
- double gk = grad[k];
- for (int l = 0; l < parameters.length; ++l) {
- wggRow[l] = weight * gk * grad[l];
}
}
- // update the matrices
- a = a.add(wGradGradT);
- b = b.add(bDecrement);
+ try {
+
+ // solve the linearized least squares problem
+ RealMatrix dX = a.solve(b);
+
+ // update the estimated parameters
+ for (int i = 0; i < parameters.length; ++i) {
+ parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i, 0));
+ }
+
+ } catch(InvalidMatrixException e) {
+ throw new EstimationException("unable to solve: singular problem", new Object[0]);
+ }
+
+
+ previous = cost;
+ updateResidualsAndCost();
+
+ } while ((getCostEvaluations() < 2)
+ || (Math.abs(previous - cost) > (cost * steadyStateThreshold)
+ && (Math.abs(cost) > convergence)));
- }
}
- try {
+ private double steadyStateThreshold;
+ private double convergence;
- // solve the linearized least squares problem
- RealMatrix dX = a.solve(b);
-
- // update the estimated parameters
- for (int i = 0; i < parameters.length; ++i) {
- parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i, 0));
- }
-
- } catch(InvalidMatrixException e) {
- throw new EstimationException(e);
- }
-
- }
-
- private double evaluateCriterion(EstimationProblem problem) {
- double criterion = 0.0;
- WeightedMeasurement[] measurements = problem.getMeasurements();
-
- for (int i = 0; i < measurements.length; ++i) {
- double residual = measurements[i].getResidual();
- criterion += measurements[i].getWeight() * residual * residual;
- }
-
- return criterion;
-
- }
-
- /**
- * Get the Root Mean Square value.
- * Get the Root Mean Square value, i.e. the root of the arithmetic
- * mean of the square of all weighted residuals. This is related to the
- * criterion that is minimized by the estimator as follows: if
- * c if the criterion, and n is the number of
- * measurements, then the RMS is sqrt (c/n).
- * @param problem estimation problem
- * @return RMS value
- */
- public double getRMS(EstimationProblem problem) {
- double criterion = evaluateCriterion(problem);
- int n = problem.getMeasurements().length;
- return Math.sqrt(criterion / n);
- }
-
- private int maxIterations;
- private double steadyStateThreshold;
- private double convergence;
-
- private static final long serialVersionUID = -7606628156644194170L;
+ private static final long serialVersionUID = 5485001826076289109L;
}
diff --git a/src/java/org/apache/commons/math/estimation/LevenbergMarquardtEstimator.java b/src/java/org/apache/commons/math/estimation/LevenbergMarquardtEstimator.java
index 3eeee1576..858419e73 100644
--- a/src/java/org/apache/commons/math/estimation/LevenbergMarquardtEstimator.java
+++ b/src/java/org/apache/commons/math/estimation/LevenbergMarquardtEstimator.java
@@ -19,6 +19,7 @@ package org.apache.commons.math.estimation;
import java.io.Serializable;
import java.util.Arrays;
+
/**
* This class solves a least squares problem.
*
@@ -92,7 +93,7 @@ import java.util.Arrays;
* @author Kenneth E. Hillstrom (original fortran)
* @author Jorge J. More (original fortran)
*/
-public class LevenbergMarquardtEstimator implements Serializable, Estimator {
+public class LevenbergMarquardtEstimator extends AbstractEstimator implements Serializable {
/**
* Build an estimator for least squares problems.
@@ -107,12 +108,16 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
*
*/
public LevenbergMarquardtEstimator() {
+
+ // set up the superclass with a default max cost evaluations setting
+ setMaxCostEval(1000);
+
// default values for the tuning parameters
setInitialStepBoundFactor(100.0);
- setMaxCostEval(1000);
setCostRelativeTolerance(1.0e-10);
setParRelativeTolerance(1.0e-10);
setOrthoTolerance(1.0e-10);
+
}
/**
@@ -128,16 +133,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
this.initialStepBoundFactor = initialStepBoundFactor;
}
- /**
- * Set the maximal number of cost evaluations.
- *
- * @param maxCostEval maximal number of cost evaluations
- * @see #estimate
- */
- public void setMaxCostEval(int maxCostEval) {
- this.maxCostEval = maxCostEval;
- }
-
/**
* Set the desired relative error in the sum of squares.
*
@@ -170,75 +165,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
this.orthoTolerance = orthoTolerance;
}
- /**
- * Get the number of cost evaluations.
- *
- * @return number of cost evaluations
- * */
- public int getCostEvaluations() {
- return costEvaluations;
- }
-
- /**
- * Get the number of jacobian evaluations.
- *
- * @return number of jacobian evaluations
- * */
- public int getJacobianEvaluations() {
- return jacobianEvaluations;
- }
-
- /**
- * Update the jacobian matrix.
- */
- private void updateJacobian() {
- ++jacobianEvaluations;
- Arrays.fill(jacobian, 0);
- for (int i = 0, index = 0; i < rows; i++) {
- WeightedMeasurement wm = measurements[i];
- double factor = -Math.sqrt(wm.getWeight());
- for (int j = 0; j < cols; ++j) {
- jacobian[index++] = factor * wm.getPartial(parameters[j]);
- }
- }
- }
-
- /**
- * Update the residuals array and cost function value.
- */
- private void updateResidualsAndCost() {
- ++costEvaluations;
- cost = 0;
- for (int i = 0, index = 0; i < rows; i++, index += cols) {
- WeightedMeasurement wm = measurements[i];
- double residual = wm.getResidual();
- residuals[i] = Math.sqrt(wm.getWeight()) * residual;
- cost += wm.getWeight() * residual * residual;
- }
- cost = Math.sqrt(cost);
- }
-
- /**
- * Get the Root Mean Square value.
- * Get the Root Mean Square value, i.e. the root of the arithmetic
- * mean of the square of all weighted residuals. This is related to the
- * criterion that is minimized by the estimator as follows: if
- * c if the criterion, and n is the number of
- * measurements, then the RMS is sqrt (c/n).
- *
- * @param problem estimation problem
- * @return RMS value
- */
- public double getRMS(EstimationProblem problem) {
- WeightedMeasurement[] wm = problem.getMeasurements();
- double criterion = 0;
- for (int i = 0; i < wm.length; ++i) {
- double residual = wm[i].getResidual();
- criterion += wm[i].getWeight() * residual * residual;
- }
- return Math.sqrt(criterion / wm.length);
- }
-
/**
* Solve an estimation problem using the Levenberg-Marquardt algorithm.
* The algorithm used is a modified Levenberg-Marquardt one, based
@@ -263,7 +189,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
* reached with the specified algorithm settings or if there are more variables
* than equations
* @see #setInitialStepBoundFactor
- * @see #setMaxCostEval
* @see #setCostRelativeTolerance
* @see #setParRelativeTolerance
* @see #setOrthoTolerance
@@ -271,21 +196,15 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
public void estimate(EstimationProblem problem)
throws EstimationException {
- // retrieve the equations and the parameters
- measurements = problem.getMeasurements();
- parameters = problem.getUnboundParameters();
+ initializeEstimate(problem);
// arrays shared with the other private methods
- rows = measurements.length;
- cols = parameters.length;
solvedCols = Math.min(rows, cols);
- jacobian = new double[rows * cols];
diagR = new double[cols];
jacNorm = new double[cols];
beta = new double[cols];
permutation = new int[cols];
lmDir = new double[cols];
- residuals = new double[rows];
// local variables
double delta = 0, xNorm = 0;
@@ -300,11 +219,9 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
updateResidualsAndCost();
// outer loop
- lmPar = 0;
- costEvaluations = 0;
- jacobianEvaluations = 0;
+ lmPar = 0;
boolean firstIteration = true;
- while (costEvaluations < maxCostEval) {
+ while (true) {
// compute the Q.R. decomposition of the jacobian matrix
updateJacobian();
@@ -477,42 +394,28 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
// tests for termination and stringent tolerances
// (2.2204e-16 is the machine epsilon for IEEE754)
- if (costEvaluations >= maxCostEval) {
- break;
- }
if ((Math.abs(actRed) <= 2.2204e-16)
&& (preRed <= 2.2204e-16)
&& (ratio <= 2.0)) {
throw new EstimationException("cost relative tolerance is too small ({0}),"
+ " no further reduction in the"
+ " sum of squares is possible",
- new String[] {
- Double.toString(costRelativeTolerance)
- });
+ new Object[] { new Double(costRelativeTolerance) });
} else if (delta <= 2.2204e-16 * xNorm) {
throw new EstimationException("parameters relative tolerance is too small"
+ " ({0}), no further improvement in"
+ " the approximate solution is possible",
- new String[] {
- Double.toString(parRelativeTolerance)
- });
+ new Object[] { new Double(parRelativeTolerance) });
} else if (maxCosine <= 2.2204e-16) {
throw new EstimationException("orthogonality tolerance is too small ({0}),"
+ " solution is orthogonal to the jacobian",
- new String[] {
- Double.toString(orthoTolerance)
- });
+ new Object[] { new Double(orthoTolerance) });
}
}
}
- throw new EstimationException("maximal number of evaluations exceeded ({0})",
- new String[] {
- Integer.toString(maxCostEval)
- });
-
}
/**
@@ -919,29 +822,9 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
}
}
- /** Array of measurements. */
- private WeightedMeasurement[] measurements;
-
- /** Array of parameters. */
- private EstimatedParameter[] parameters;
-
- /**
- * Jacobian matrix.
- *
- * Depending on the computation phase, this matrix is either in
- * canonical form (just after the calls to updateJacobian) or in
- * Q.R. decomposed form (after calls to qrDecomposition)
- */
- private double[] jacobian;
-
- /** Number of columns of the jacobian matrix. */
- private int cols;
-
/** Number of solved variables. */
private int solvedCols;
- /** Number of rows of the jacobian matrix. */
- private int rows;
-
/** Diagonal elements of the R matrix in the Q.R. decomposition. */
private double[] diagR;
@@ -963,28 +846,9 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
/** Parameters evolution direction associated with lmPar. */
private double[] lmDir;
- /** Residuals array.
- * Depending on the computation phase, this array is either in
- * canonical form (just after the calls to updateResiduals) or in
- * premultiplied by Qt form (just after calls to qTy)
- */
- private double[] residuals;
-
- /** Cost value (square root of the sum of the residuals). */
- private double cost;
-
/** Positive input variable used in determining the initial step bound. */
private double initialStepBoundFactor;
- /** Maximal number of cost evaluations. */
- private int maxCostEval;
-
- /** Number of cost evaluations. */
- private int costEvaluations;
-
- /** Number of jacobian evaluations. */
- private int jacobianEvaluations;
-
/** Desired relative error in the sum of squares. */
private double costRelativeTolerance;
@@ -995,6 +859,6 @@ public class LevenbergMarquardtEstimator implements Serializable, Estimator {
* and the columns of the jacobian. */
private double orthoTolerance;
- private static final long serialVersionUID = 5387476316105068340L;
+ private static final long serialVersionUID = -5705952631533171019L;
}
diff --git a/src/test/org/apache/commons/math/estimation/LevenbergMarquardtEstimatorTest.java b/src/test/org/apache/commons/math/estimation/LevenbergMarquardtEstimatorTest.java
index cd6de0cc1..dc64695a1 100644
--- a/src/test/org/apache/commons/math/estimation/LevenbergMarquardtEstimatorTest.java
+++ b/src/test/org/apache/commons/math/estimation/LevenbergMarquardtEstimatorTest.java
@@ -110,6 +110,14 @@ public class LevenbergMarquardtEstimatorTest
LevenbergMarquardtEstimator estimator = new LevenbergMarquardtEstimator();
estimator.estimate(problem);
assertEquals(0, estimator.getRMS(problem), 1.0e-10);
+ try {
+ estimator.guessParametersErrors(problem);
+ fail("an exception should have been thrown");
+ } catch (EstimationException ee) {
+ // expected behavior
+ } catch (Exception e) {
+ fail("wrong exception caught");
+ }
assertEquals(1.5,
problem.getUnboundParameters()[0].getEstimate(),
1.0e-10);
@@ -267,7 +275,15 @@ public class LevenbergMarquardtEstimatorTest
estimator.estimate(problem);
assertTrue(estimator.getRMS(problem) < initialCost);
assertTrue(Math.sqrt(m.length) * estimator.getRMS(problem) > 0.6);
- double dJ0 = 2 * (m[0].getResidual() * m[0].getPartial(p[0])
+ try {
+ estimator.getCovariances(problem);
+ fail("an exception should have been thrown");
+ } catch (EstimationException ee) {
+ // expected behavior
+ } catch (Exception e) {
+ fail("wrong exception caught");
+ }
+ double dJ0 = 2 * (m[0].getResidual() * m[0].getPartial(p[0])
+ m[1].getResidual() * m[1].getPartial(p[0])
+ m[2].getResidual() * m[2].getPartial(p[0]));
double dJ1 = 2 * (m[0].getResidual() * m[0].getPartial(p[1])
@@ -496,7 +512,34 @@ public class LevenbergMarquardtEstimatorTest
assertEquals(69.96016176931406, circle.getRadius(), 1.0e-10);
assertEquals(96.07590211815305, circle.getX(), 1.0e-10);
assertEquals(48.13516790438953, circle.getY(), 1.0e-10);
- }
+ double[][] cov = estimator.getCovariances(circle);
+ assertEquals(1.839, cov[0][0], 0.001);
+ assertEquals(0.731, cov[0][1], 0.001);
+ assertEquals(cov[0][1], cov[1][0], 1.0e-14);
+ assertEquals(0.786, cov[1][1], 0.001);
+ double[] errors = estimator.guessParametersErrors(circle);
+ assertEquals(1.384, errors[0], 0.001);
+ assertEquals(0.905, errors[1], 0.001);
+
+ // add perfect measurements and check errors are reduced
+ double cx = circle.getX();
+ double cy = circle.getY();
+ double r = circle.getRadius();
+ for (double d= 0; d < 2 * Math.PI; d += 0.01) {
+ circle.addPoint(cx + r * Math.cos(d), cy + r * Math.sin(d));
+ }
+ estimator = new LevenbergMarquardtEstimator();
+ estimator.estimate(circle);
+ cov = estimator.getCovariances(circle);
+ assertEquals(0.004, cov[0][0], 0.001);
+ assertEquals(6.40e-7, cov[0][1], 1.0e-9);
+ assertEquals(cov[0][1], cov[1][0], 1.0e-14);
+ assertEquals(0.003, cov[1][1], 0.001);
+ errors = estimator.guessParametersErrors(circle);
+ assertEquals(0.004, errors[0], 0.001);
+ assertEquals(0.004, errors[1], 0.001);
+
+ }
public void testCircleFittingBadInit() throws EstimationException {
Circle circle = new Circle(-12, -12);
diff --git a/xdocs/changes.xml b/xdocs/changes.xml
index aade6532e..94f34f5b3 100644
--- a/xdocs/changes.xml
+++ b/xdocs/changes.xml
@@ -124,6 +124,9 @@ Commons Math Release Notes
Handle multiplication of Complex numbers with infinite parts specially.
+
+ Add errors guessing to least-squares estimators.
+