Use Real{Vector,Matrix} in LeastSquares interfaces
Covered all of the interfaces in the leastsquares package to use RealVector and RealMatrix instead of double[] and double[][]. This reduced some duplicated code. For example Evaluation.computeResiduals() was a complete duplication of RealVector.subtract(). It also presents a consistent interface and allows data encapsulation. Lastly, this change enables [math] to "eat our own dog food." It enables the linear package to be used in the implementation of the optimization algorithms. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1569354 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
a7a380f934
commit
0079828734
|
@ -5,6 +5,7 @@ import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import org.apache.commons.math3.linear.DecompositionSolver;
|
import org.apache.commons.math3.linear.DecompositionSolver;
|
||||||
import org.apache.commons.math3.linear.QRDecomposition;
|
import org.apache.commons.math3.linear.QRDecomposition;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -31,7 +32,7 @@ abstract class AbstractEvaluation implements Evaluation {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[][] computeCovariances(double threshold) {
|
public RealMatrix computeCovariances(double threshold) {
|
||||||
// Set up the Jacobian.
|
// Set up the Jacobian.
|
||||||
final RealMatrix j = this.computeJacobian();
|
final RealMatrix j = this.computeJacobian();
|
||||||
|
|
||||||
|
@ -41,16 +42,16 @@ abstract class AbstractEvaluation implements Evaluation {
|
||||||
// Compute the covariances matrix.
|
// Compute the covariances matrix.
|
||||||
final DecompositionSolver solver
|
final DecompositionSolver solver
|
||||||
= new QRDecomposition(jTj, threshold).getSolver();
|
= new QRDecomposition(jTj, threshold).getSolver();
|
||||||
return solver.getInverse().getData();
|
return solver.getInverse();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] computeSigma(double covarianceSingularityThreshold) {
|
public RealVector computeSigma(double covarianceSingularityThreshold) {
|
||||||
final double[][] cov = this.computeCovariances(covarianceSingularityThreshold);
|
final RealMatrix cov = this.computeCovariances(covarianceSingularityThreshold);
|
||||||
final int nC = cov.length;
|
final int nC = cov.getColumnDimension();
|
||||||
final double[] sig = new double[nC];
|
final RealVector sig = new ArrayRealVector(nC);
|
||||||
for (int i = 0; i < nC; ++i) {
|
for (int i = 0; i < nC; ++i) {
|
||||||
sig[i] = FastMath.sqrt(cov[i][i]);
|
sig.setEntry(i, FastMath.sqrt(cov.getEntry(i,i)));
|
||||||
}
|
}
|
||||||
return sig;
|
return sig;
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Applies a dense weight matrix to an evaluation.
|
* Applies a dense weight matrix to an evaluation.
|
||||||
|
@ -37,19 +38,19 @@ class DenseWeightedEvaluation extends AbstractEvaluation {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] computeResiduals() {
|
public RealVector computeResiduals() {
|
||||||
return this.weightSqrt.operate(this.unweighted.computeResiduals());
|
return this.weightSqrt.operate(this.unweighted.computeResiduals());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* delegate */
|
/* delegate */
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] getPoint() {
|
public RealVector getPoint() {
|
||||||
return unweighted.getPoint();
|
return unweighted.getPoint();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] computeValue() {
|
public RealVector computeValue() {
|
||||||
return unweighted.computeValue();
|
return unweighted.computeValue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.commons.math3.linear.DecompositionSolver;
|
||||||
import org.apache.commons.math3.linear.LUDecomposition;
|
import org.apache.commons.math3.linear.LUDecomposition;
|
||||||
import org.apache.commons.math3.linear.QRDecomposition;
|
import org.apache.commons.math3.linear.QRDecomposition;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.linear.SingularMatrixException;
|
import org.apache.commons.math3.linear.SingularMatrixException;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.util.Incrementor;
|
import org.apache.commons.math3.util.Incrementor;
|
||||||
|
@ -133,7 +134,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
|
||||||
final int nR = lsp.getObservationSize(); // Number of observed data.
|
final int nR = lsp.getObservationSize(); // Number of observed data.
|
||||||
final int nC = lsp.getParameterSize();
|
final int nC = lsp.getParameterSize();
|
||||||
|
|
||||||
final double[] currentPoint = lsp.getStart();
|
final RealVector currentPoint = lsp.getStart();
|
||||||
|
|
||||||
// iterate until convergence is reached
|
// iterate until convergence is reached
|
||||||
Evaluation current = null;
|
Evaluation current = null;
|
||||||
|
@ -145,7 +146,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
|
||||||
// Value of the objective function at "currentPoint".
|
// Value of the objective function at "currentPoint".
|
||||||
evaluationCounter.incrementCount();
|
evaluationCounter.incrementCount();
|
||||||
current = lsp.evaluate(currentPoint);
|
current = lsp.evaluate(currentPoint);
|
||||||
final double[] currentResiduals = current.computeResiduals();
|
final RealVector currentResiduals = current.computeResiduals();
|
||||||
final RealMatrix weightedJacobian = current.computeJacobian();
|
final RealMatrix weightedJacobian = current.computeJacobian();
|
||||||
|
|
||||||
// Check convergence.
|
// Check convergence.
|
||||||
|
@ -164,7 +165,7 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
|
||||||
for (int i = 0; i < nR; ++i) {
|
for (int i = 0; i < nR; ++i) {
|
||||||
|
|
||||||
final double[] grad = weightedJacobian.getRow(i);
|
final double[] grad = weightedJacobian.getRow(i);
|
||||||
final double residual = currentResiduals[i];
|
final double residual = currentResiduals.getEntry(i);
|
||||||
|
|
||||||
// compute the normal equation
|
// compute the normal equation
|
||||||
//residual is already weighted
|
//residual is already weighted
|
||||||
|
@ -186,10 +187,10 @@ public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
|
||||||
// solve the linearized least squares problem
|
// solve the linearized least squares problem
|
||||||
RealMatrix mA = new BlockRealMatrix(a);
|
RealMatrix mA = new BlockRealMatrix(a);
|
||||||
DecompositionSolver solver = this.decomposition.getSolver(mA);
|
DecompositionSolver solver = this.decomposition.getSolver(mA);
|
||||||
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
|
final RealVector dX = solver.solve(new ArrayRealVector(b, false));
|
||||||
// update the estimated parameters
|
// update the estimated parameters
|
||||||
for (int i = 0; i < nC; ++i) {
|
for (int i = 0; i < nC; ++i) {
|
||||||
currentPoint[i] += dX[i];
|
currentPoint.setEntry(i, currentPoint.getEntry(i) + dX.getEntry(i));
|
||||||
}
|
}
|
||||||
} catch (SingularMatrixException e) {
|
} catch (SingularMatrixException e) {
|
||||||
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
|
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package org.apache.commons.math3.fitting.leastsquares;
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||||
import org.apache.commons.math3.util.Incrementor;
|
import org.apache.commons.math3.util.Incrementor;
|
||||||
|
|
||||||
|
@ -23,7 +24,7 @@ public class LeastSquaresAdapter implements LeastSquaresProblem {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] getStart() {
|
public RealVector getStart() {
|
||||||
return problem.getStart();
|
return problem.getStart();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,8 +38,9 @@ public class LeastSquaresAdapter implements LeastSquaresProblem {
|
||||||
return problem.getParameterSize();
|
return problem.getParameterSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc}
|
||||||
public Evaluation evaluate(final double[] point) {
|
* @param point*/
|
||||||
|
public Evaluation evaluate(final RealVector point) {
|
||||||
return problem.evaluate(point);
|
return problem.evaluate(point);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,8 +39,8 @@ public class LeastSquaresFactory {
|
||||||
* @return the specified General Least Squares problem.
|
* @return the specified General Least Squares problem.
|
||||||
*/
|
*/
|
||||||
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
|
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
|
||||||
final double[] observed,
|
final RealVector observed,
|
||||||
final double[] start,
|
final RealVector start,
|
||||||
final ConvergenceChecker<Evaluation> checker,
|
final ConvergenceChecker<Evaluation> checker,
|
||||||
final int maxEvaluations,
|
final int maxEvaluations,
|
||||||
final int maxIterations) {
|
final int maxIterations) {
|
||||||
|
@ -54,6 +54,34 @@ public class LeastSquaresFactory {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
||||||
|
* from the given elements. There will be no weights applied (Identity weights).
|
||||||
|
*
|
||||||
|
* @param model the model function. Produces the computed values.
|
||||||
|
* @param observed the observed (target) values
|
||||||
|
* @param start the initial guess.
|
||||||
|
* @param checker convergence checker
|
||||||
|
* @param maxEvaluations the maximum number of times to evaluate the model
|
||||||
|
* @param maxIterations the maximum number to times to iterate in the algorithm
|
||||||
|
* @return the specified General Least Squares problem.
|
||||||
|
*/
|
||||||
|
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
|
||||||
|
final double[] observed,
|
||||||
|
final double[] start,
|
||||||
|
final ConvergenceChecker<Evaluation> checker,
|
||||||
|
final int maxEvaluations,
|
||||||
|
final int maxIterations) {
|
||||||
|
return create(
|
||||||
|
model,
|
||||||
|
new ArrayRealVector(observed, false),
|
||||||
|
new ArrayRealVector(start, false),
|
||||||
|
checker,
|
||||||
|
maxEvaluations,
|
||||||
|
maxIterations
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
* Create a {@link org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem}
|
||||||
* from the given elements. There will be no weights applied (Identity weights).
|
* from the given elements. There will be no weights applied (Identity weights).
|
||||||
|
@ -132,7 +160,7 @@ public class LeastSquaresFactory {
|
||||||
final RealMatrix weightSquareRoot = squareRoot(weights);
|
final RealMatrix weightSquareRoot = squareRoot(weights);
|
||||||
return new LeastSquaresAdapter(problem) {
|
return new LeastSquaresAdapter(problem) {
|
||||||
@Override
|
@Override
|
||||||
public Evaluation evaluate(final double[] point) {
|
public Evaluation evaluate(final RealVector point) {
|
||||||
return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot);
|
return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -154,7 +182,7 @@ public class LeastSquaresFactory {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Count the evaluations of a particular problem. The {@code counter} will be
|
* Count the evaluations of a particular problem. The {@code counter} will be
|
||||||
* incremented every time {@link LeastSquaresProblem#evaluate(double[])} is called on
|
* incremented every time {@link LeastSquaresProblem#evaluate(RealVector)} is called on
|
||||||
* the <em>returned</em> problem.
|
* the <em>returned</em> problem.
|
||||||
*
|
*
|
||||||
* @param problem the problem to track.
|
* @param problem the problem to track.
|
||||||
|
@ -165,7 +193,7 @@ public class LeastSquaresFactory {
|
||||||
final Incrementor counter) {
|
final Incrementor counter) {
|
||||||
return new LeastSquaresAdapter(problem) {
|
return new LeastSquaresAdapter(problem) {
|
||||||
|
|
||||||
public Evaluation evaluate(final double[] point) {
|
public Evaluation evaluate(final RealVector point) {
|
||||||
counter.incrementCount();
|
counter.incrementCount();
|
||||||
return super.evaluate(point);
|
return super.evaluate(point);
|
||||||
}
|
}
|
||||||
|
@ -192,12 +220,12 @@ public class LeastSquaresFactory {
|
||||||
return checker.converged(
|
return checker.converged(
|
||||||
iteration,
|
iteration,
|
||||||
new PointVectorValuePair(
|
new PointVectorValuePair(
|
||||||
previous.getPoint(),
|
previous.getPoint().toArray(),
|
||||||
previous.computeValue(),
|
previous.computeValue().toArray(),
|
||||||
false),
|
false),
|
||||||
new PointVectorValuePair(
|
new PointVectorValuePair(
|
||||||
current.getPoint(),
|
current.getPoint().toArray(),
|
||||||
current.computeValue(),
|
current.computeValue().toArray(),
|
||||||
false)
|
false)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -237,11 +265,13 @@ public class LeastSquaresFactory {
|
||||||
final MultivariateMatrixFunction jacobian
|
final MultivariateMatrixFunction jacobian
|
||||||
) {
|
) {
|
||||||
return new MultivariateJacobianFunction() {
|
return new MultivariateJacobianFunction() {
|
||||||
public Pair<RealVector, RealMatrix> value(final double[] point) {
|
public Pair<RealVector, RealMatrix> value(final RealVector point) {
|
||||||
//evaluate and use Real* interfaces without copying
|
//TODO get array from RealVector without copying?
|
||||||
|
final double[] pointArray = point.toArray();
|
||||||
|
//evaluate and return data without copying
|
||||||
return new Pair<RealVector, RealMatrix>(
|
return new Pair<RealVector, RealMatrix>(
|
||||||
new ArrayRealVector(value.value(point), false),
|
new ArrayRealVector(value.value(pointArray), false),
|
||||||
new Array2DRowRealMatrix(jacobian.value(point), false));
|
new Array2DRowRealMatrix(jacobian.value(pointArray), false));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The data necessary to define a non-linear least squares problem. Includes the observed
|
* The data necessary to define a non-linear least squares problem. Includes the observed
|
||||||
|
@ -19,7 +20,7 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
|
||||||
*
|
*
|
||||||
* @return the initial guess values.
|
* @return the initial guess values.
|
||||||
*/
|
*/
|
||||||
double[] getStart();
|
RealVector getStart();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the number of observations (rows in the Jacobian) in this problem.
|
* Get the number of observations (rows in the Jacobian) in this problem.
|
||||||
|
@ -38,13 +39,14 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
|
||||||
/**
|
/**
|
||||||
* Evaluate the model at the specified point.
|
* Evaluate the model at the specified point.
|
||||||
*
|
*
|
||||||
|
*
|
||||||
* @param point the parameter values.
|
* @param point the parameter values.
|
||||||
* @return the model's value and derivative at the given point.
|
* @return the model's value and derivative at the given point.
|
||||||
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
|
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
|
||||||
* if the maximal number of evaluations (of the model vector function) is
|
* if the maximal number of evaluations (of the model vector function) is
|
||||||
* exceeded.
|
* exceeded.
|
||||||
*/
|
*/
|
||||||
Evaluation evaluate(double[] point);
|
Evaluation evaluate(RealVector point);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An evaluation of a {@link LeastSquaresProblem} at a particular point. This class
|
* An evaluation of a {@link LeastSquaresProblem} at a particular point. This class
|
||||||
|
@ -59,12 +61,13 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
|
||||||
* way for the caller to specify that the result of this computation should be
|
* way for the caller to specify that the result of this computation should be
|
||||||
* considered meaningless, and thus trigger an exception.
|
* considered meaningless, and thus trigger an exception.
|
||||||
*
|
*
|
||||||
|
*
|
||||||
* @param threshold Singularity threshold.
|
* @param threshold Singularity threshold.
|
||||||
* @return the covariance matrix.
|
* @return the covariance matrix.
|
||||||
* @throws org.apache.commons.math3.linear.SingularMatrixException
|
* @throws org.apache.commons.math3.linear.SingularMatrixException
|
||||||
* if the covariance matrix cannot be computed (singular problem).
|
* if the covariance matrix cannot be computed (singular problem).
|
||||||
*/
|
*/
|
||||||
double[][] computeCovariances(double threshold);
|
RealMatrix computeCovariances(double threshold);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes an estimate of the standard deviation of the parameters. The returned
|
* Computes an estimate of the standard deviation of the parameters. The returned
|
||||||
|
@ -72,13 +75,14 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
|
||||||
* matrix, {@code sd(a[i]) ~= sqrt(C[i][i])}, where {@code a[i]} is the optimized
|
* matrix, {@code sd(a[i]) ~= sqrt(C[i][i])}, where {@code a[i]} is the optimized
|
||||||
* value of the {@code i}-th parameter, and {@code C} is the covariance matrix.
|
* value of the {@code i}-th parameter, and {@code C} is the covariance matrix.
|
||||||
*
|
*
|
||||||
|
*
|
||||||
* @param covarianceSingularityThreshold Singularity threshold (see {@link
|
* @param covarianceSingularityThreshold Singularity threshold (see {@link
|
||||||
* #computeCovariances(double) computeCovariances}).
|
* #computeCovariances(double) computeCovariances}).
|
||||||
* @return an estimate of the standard deviation of the optimized parameters
|
* @return an estimate of the standard deviation of the optimized parameters
|
||||||
* @throws org.apache.commons.math3.linear.SingularMatrixException
|
* @throws org.apache.commons.math3.linear.SingularMatrixException
|
||||||
* if the covariance matrix cannot be computed.
|
* if the covariance matrix cannot be computed.
|
||||||
*/
|
*/
|
||||||
double[] computeSigma(double covarianceSingularityThreshold);
|
RealVector computeSigma(double covarianceSingularityThreshold);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the normalized cost. It is the square-root of the sum of squared of
|
* Computes the normalized cost. It is the square-root of the sum of squared of
|
||||||
|
@ -93,7 +97,7 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
|
||||||
*
|
*
|
||||||
* @return the objective function value at the specified point.
|
* @return the objective function value at the specified point.
|
||||||
*/
|
*/
|
||||||
double[] computeValue();
|
RealVector computeValue();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the weighted Jacobian matrix.
|
* Computes the weighted Jacobian matrix.
|
||||||
|
@ -121,13 +125,13 @@ public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
|
||||||
* @return the weighted residuals: W<sup>1/2</sup> K.
|
* @return the weighted residuals: W<sup>1/2</sup> K.
|
||||||
* @throws DimensionMismatchException if the residuals have the wrong length.
|
* @throws DimensionMismatchException if the residuals have the wrong length.
|
||||||
*/
|
*/
|
||||||
double[] computeResiduals();
|
RealVector computeResiduals();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the abscissa (independent variables) of this evaluation.
|
* Get the abscissa (independent variables) of this evaluation.
|
||||||
*
|
*
|
||||||
* @return the point provided to {@link #evaluate(double[])}.
|
* @return the point provided to {@link #evaluate(RealVector)}.
|
||||||
*/
|
*/
|
||||||
double[] getPoint();
|
RealVector getPoint();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.commons.math3.fitting.leastsquares;
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.linear.RealVector;
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
|
@ -36,15 +35,15 @@ class LeastSquaresProblemImpl
|
||||||
implements LeastSquaresProblem {
|
implements LeastSquaresProblem {
|
||||||
|
|
||||||
/** Target values for the model function at optimum. */
|
/** Target values for the model function at optimum. */
|
||||||
private double[] target;
|
private RealVector target;
|
||||||
/** Model function. */
|
/** Model function. */
|
||||||
private MultivariateJacobianFunction model;
|
private MultivariateJacobianFunction model;
|
||||||
/** Initial guess. */
|
/** Initial guess. */
|
||||||
private double[] start;
|
private RealVector start;
|
||||||
|
|
||||||
LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
|
LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
|
||||||
final double[] target,
|
final RealVector target,
|
||||||
final double[] start,
|
final RealVector start,
|
||||||
final ConvergenceChecker<Evaluation> checker,
|
final ConvergenceChecker<Evaluation> checker,
|
||||||
final int maxEvaluations,
|
final int maxEvaluations,
|
||||||
final int maxIterations) {
|
final int maxIterations) {
|
||||||
|
@ -55,18 +54,18 @@ class LeastSquaresProblemImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getObservationSize() {
|
public int getObservationSize() {
|
||||||
return target.length;
|
return target.getDimension();
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getParameterSize() {
|
public int getParameterSize() {
|
||||||
return start.length;
|
return start.getDimension();
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[] getStart() {
|
public RealVector getStart() {
|
||||||
return start == null ? null : start.clone();
|
return start == null ? null : start.copy();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Evaluation evaluate(final double[] point) {
|
public Evaluation evaluate(final RealVector point) {
|
||||||
//evaluate value and jacobian in one function call
|
//evaluate value and jacobian in one function call
|
||||||
final Pair<RealVector, RealMatrix> value = this.model.value(point);
|
final Pair<RealVector, RealMatrix> value = this.model.value(point);
|
||||||
return new UnweightedEvaluation(
|
return new UnweightedEvaluation(
|
||||||
|
@ -84,19 +83,19 @@ class LeastSquaresProblemImpl
|
||||||
private static class UnweightedEvaluation extends AbstractEvaluation {
|
private static class UnweightedEvaluation extends AbstractEvaluation {
|
||||||
|
|
||||||
/** the point of evaluation */
|
/** the point of evaluation */
|
||||||
private final double[] point;
|
private final RealVector point;
|
||||||
/** value at point */
|
/** value at point */
|
||||||
private final RealVector values;
|
private final RealVector values;
|
||||||
/** deriviative at point */
|
/** deriviative at point */
|
||||||
private final RealMatrix jacobian;
|
private final RealMatrix jacobian;
|
||||||
/** reference to the observed values */
|
/** reference to the observed values */
|
||||||
private final double[] target;
|
private final RealVector target;
|
||||||
|
|
||||||
private UnweightedEvaluation(final RealVector values,
|
private UnweightedEvaluation(final RealVector values,
|
||||||
final RealMatrix jacobian,
|
final RealMatrix jacobian,
|
||||||
final double[] target,
|
final RealVector target,
|
||||||
final double[] point) {
|
final RealVector point) {
|
||||||
super(target.length);
|
super(target.getDimension());
|
||||||
this.values = values;
|
this.values = values;
|
||||||
this.jacobian = jacobian;
|
this.jacobian = jacobian;
|
||||||
this.target = target;
|
this.target = target;
|
||||||
|
@ -104,31 +103,20 @@ class LeastSquaresProblemImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public double[] computeValue() {
|
public RealVector computeValue() {
|
||||||
return this.values.toArray();
|
return this.values;
|
||||||
}
|
}
|
||||||
|
|
||||||
public RealMatrix computeJacobian() {
|
public RealMatrix computeJacobian() {
|
||||||
return this.jacobian;
|
return this.jacobian;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[] getPoint() {
|
public RealVector getPoint() {
|
||||||
return this.point;
|
return this.point;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[] computeResiduals() {
|
public RealVector computeResiduals() {
|
||||||
final double[] objectiveValue = this.computeValue();
|
return target.subtract(this.computeValue());
|
||||||
if (objectiveValue.length != target.length) {
|
|
||||||
throw new DimensionMismatchException(target.length,
|
|
||||||
objectiveValue.length);
|
|
||||||
}
|
|
||||||
|
|
||||||
final double[] residuals = new double[target.length];
|
|
||||||
for (int i = 0; i < target.length; i++) {
|
|
||||||
residuals[i] = target[i] - objectiveValue[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
return residuals;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
import org.apache.commons.math3.exception.ConvergenceException;
|
import org.apache.commons.math3.exception.ConvergenceException;
|
||||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||||
|
@ -297,7 +298,7 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
//pull in relevant data from the problem as locals
|
//pull in relevant data from the problem as locals
|
||||||
final int nR = problem.getObservationSize(); // Number of observed data.
|
final int nR = problem.getObservationSize(); // Number of observed data.
|
||||||
final int nC = problem.getParameterSize(); // Number of parameters.
|
final int nC = problem.getParameterSize(); // Number of parameters.
|
||||||
final double[] currentPoint = problem.getStart();
|
final double[] currentPoint = problem.getStart().toArray();
|
||||||
//counters
|
//counters
|
||||||
final Incrementor iterationCounter = problem.getIterationCounter();
|
final Incrementor iterationCounter = problem.getIterationCounter();
|
||||||
final Incrementor evaluationCounter = problem.getEvaluationCounter();
|
final Incrementor evaluationCounter = problem.getEvaluationCounter();
|
||||||
|
@ -327,8 +328,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
// Evaluate the function at the starting point and calculate its norm.
|
// Evaluate the function at the starting point and calculate its norm.
|
||||||
evaluationCounter.incrementCount();
|
evaluationCounter.incrementCount();
|
||||||
//value will be reassigned in the loop
|
//value will be reassigned in the loop
|
||||||
Evaluation current = problem.evaluate(currentPoint);
|
Evaluation current = problem.evaluate(new ArrayRealVector(currentPoint, false));
|
||||||
double[] currentResiduals = current.computeResiduals();
|
double[] currentResiduals = current.computeResiduals().toArray();
|
||||||
double currentCost = current.computeCost();
|
double currentCost = current.computeCost();
|
||||||
|
|
||||||
// Outer loop.
|
// Outer loop.
|
||||||
|
@ -444,8 +445,8 @@ public class LevenbergMarquardtOptimizer implements LeastSquaresOptimizer {
|
||||||
|
|
||||||
// Evaluate the function at x + p and calculate its norm.
|
// Evaluate the function at x + p and calculate its norm.
|
||||||
evaluationCounter.incrementCount();
|
evaluationCounter.incrementCount();
|
||||||
current = problem.evaluate(currentPoint);
|
current = problem.evaluate(new ArrayRealVector(currentPoint,false));
|
||||||
currentResiduals = current.computeResiduals();
|
currentResiduals = current.computeResiduals().toArray();
|
||||||
currentCost = current.computeCost();
|
currentCost = current.computeCost();
|
||||||
|
|
||||||
// compute the scaled actual reduction
|
// compute the scaled actual reduction
|
||||||
|
|
|
@ -18,6 +18,6 @@ public interface MultivariateJacobianFunction {
|
||||||
* @param point the abscissae
|
* @param point the abscissae
|
||||||
* @return the values and their Jacobian of this vector valued function.
|
* @return the values and their Jacobian of this vector valued function.
|
||||||
*/
|
*/
|
||||||
Pair<RealVector, RealMatrix> value(double[] point);
|
Pair<RealVector, RealMatrix> value(RealVector point);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A pedantic implementation of {@link Optimum}.
|
* A pedantic implementation of {@link Optimum}.
|
||||||
|
@ -44,12 +45,12 @@ class OptimumImpl implements Optimum {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[][] computeCovariances(double threshold) {
|
public RealMatrix computeCovariances(double threshold) {
|
||||||
return value.computeCovariances(threshold);
|
return value.computeCovariances(threshold);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] computeSigma(double covarianceSingularityThreshold) {
|
public RealVector computeSigma(double covarianceSingularityThreshold) {
|
||||||
return value.computeSigma(covarianceSingularityThreshold);
|
return value.computeSigma(covarianceSingularityThreshold);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,7 +60,7 @@ class OptimumImpl implements Optimum {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] computeValue() {
|
public RealVector computeValue() {
|
||||||
return value.computeValue();
|
return value.computeValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,12 +75,12 @@ class OptimumImpl implements Optimum {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] computeResiduals() {
|
public RealVector computeResiduals() {
|
||||||
return value.computeResiduals();
|
return value.computeResiduals();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
public double[] getPoint() {
|
public RealVector getPoint() {
|
||||||
return value.getPoint();
|
return value.getPoint();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
||||||
import org.apache.commons.math3.linear.BlockRealMatrix;
|
import org.apache.commons.math3.linear.BlockRealMatrix;
|
||||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||||
import org.apache.commons.math3.linear.RealMatrix;
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
|
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
@ -44,6 +45,9 @@ import java.util.Arrays;
|
||||||
*/
|
*/
|
||||||
public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
|
/** default absolute tolerance of comparisons */
|
||||||
|
public static final double TOl = 1e-10;
|
||||||
|
|
||||||
public LeastSquaresBuilder base() {
|
public LeastSquaresBuilder base() {
|
||||||
return new LeastSquaresBuilder()
|
return new LeastSquaresBuilder()
|
||||||
.checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
|
.checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
|
||||||
|
@ -78,6 +82,19 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
Assert.fail("Expected Exception from: " + optimizer.toString());
|
Assert.fail("Expected Exception from: " + optimizer.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check the value of a vector.
|
||||||
|
* @param tolerance the absolute tolerance of comparisons
|
||||||
|
* @param actual the vector to test
|
||||||
|
* @param expected the expected values
|
||||||
|
*/
|
||||||
|
public void assertEquals(double tolerance, RealVector actual, double... expected){
|
||||||
|
for (int i = 0; i < expected.length; i++) {
|
||||||
|
Assert.assertEquals(expected[i], actual.getEntry(i), tolerance);
|
||||||
|
}
|
||||||
|
Assert.assertEquals(expected.length, actual.getDimension());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return the default number of allowed iterations (which will be used when not
|
* @return the default number of allowed iterations (which will be used when not
|
||||||
* specified otherwise).
|
* specified otherwise).
|
||||||
|
@ -150,9 +167,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(ls);
|
Optimum optimum = optimizer.optimize(ls);
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(1.5, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), 1.5);
|
||||||
Assert.assertEquals(3.0, optimum.computeValue()[0], 1e-10);
|
Assert.assertEquals(3.0, optimum.computeValue().getEntry(0), TOl);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testQRColumnsPermutation(LeastSquaresOptimizer optimizer) {
|
public void testQRColumnsPermutation(LeastSquaresOptimizer optimizer) {
|
||||||
|
@ -162,12 +179,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(7, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), 7, 3);
|
||||||
Assert.assertEquals(3, optimum.getPoint()[1], 1e-10);
|
assertEquals(TOl, optimum.computeValue(), 4, 6, 1);
|
||||||
Assert.assertEquals(4, optimum.computeValue()[0], 1e-10);
|
|
||||||
Assert.assertEquals(6, optimum.computeValue()[1], 1e-10);
|
|
||||||
Assert.assertEquals(1, optimum.computeValue()[2], 1e-10);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testNoDependency(LeastSquaresOptimizer optimizer) {
|
public void testNoDependency(LeastSquaresOptimizer optimizer) {
|
||||||
|
@ -182,9 +196,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
for (int i = 0; i < problem.target.length; ++i) {
|
for (int i = 0; i < problem.target.length; ++i) {
|
||||||
Assert.assertEquals(0.55 * i, optimum.getPoint()[i], 1e-10);
|
Assert.assertEquals(0.55 * i, optimum.getPoint().getEntry(i), TOl);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,10 +211,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(1, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), 1, 2, 3);
|
||||||
Assert.assertEquals(2, optimum.getPoint()[1], 1e-10);
|
|
||||||
Assert.assertEquals(3, optimum.getPoint()[2], 1e-10);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testTwoSets(LeastSquaresOptimizer optimizer) {
|
public void testTwoSets(LeastSquaresOptimizer optimizer) {
|
||||||
|
@ -216,13 +228,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(3, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), 3, 4, -1, -2, 1 + epsilon, 1 - epsilon);
|
||||||
Assert.assertEquals(4, optimum.getPoint()[1], 1e-10);
|
|
||||||
Assert.assertEquals(-1, optimum.getPoint()[2], 1e-10);
|
|
||||||
Assert.assertEquals(-2, optimum.getPoint()[3], 1e-10);
|
|
||||||
Assert.assertEquals(1 + epsilon, optimum.getPoint()[4], 1e-10);
|
|
||||||
Assert.assertEquals(1 - epsilon, optimum.getPoint()[5], 1e-10);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testNonInvertible(LeastSquaresOptimizer optimizer) throws Exception {
|
public void testNonInvertible(LeastSquaresOptimizer optimizer) throws Exception {
|
||||||
|
@ -253,11 +260,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
Optimum optimum = optimizer
|
Optimum optimum = optimizer
|
||||||
.optimize(problem1.getBuilder().start(start).build());
|
.optimize(problem1.getBuilder().start(start).build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(1, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), 1, 1, 1, 1);
|
||||||
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
|
|
||||||
Assert.assertEquals(1, optimum.getPoint()[2], 1e-10);
|
|
||||||
Assert.assertEquals(1, optimum.getPoint()[3], 1e-10);
|
|
||||||
|
|
||||||
LinearProblem problem2 = new LinearProblem(new double[][]{
|
LinearProblem problem2 = new LinearProblem(new double[][]{
|
||||||
{10.00, 7.00, 8.10, 7.20},
|
{10.00, 7.00, 8.10, 7.20},
|
||||||
|
@ -268,11 +272,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
optimum = optimizer.optimize(problem2.getBuilder().start(start).build());
|
optimum = optimizer.optimize(problem2.getBuilder().start(start).build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(-81, optimum.getPoint()[0], 1e-8);
|
assertEquals(1e-8, optimum.getPoint(), -81, 137, -34, 22);
|
||||||
Assert.assertEquals(137, optimum.getPoint()[1], 1e-8);
|
|
||||||
Assert.assertEquals(-34, optimum.getPoint()[2], 1e-8);
|
|
||||||
Assert.assertEquals(22, optimum.getPoint()[3], 1e-8);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testMoreEstimatedParametersSimple(LeastSquaresOptimizer optimizer) {
|
public void testMoreEstimatedParametersSimple(LeastSquaresOptimizer optimizer) {
|
||||||
|
@ -285,7 +286,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
Optimum optimum = optimizer
|
Optimum optimum = optimizer
|
||||||
.optimize(problem.getBuilder().start(new double[]{7, 6, 5, 4}).build());
|
.optimize(problem.getBuilder().start(new double[]{7, 6, 5, 4}).build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testMoreEstimatedParametersUnsorted(LeastSquaresOptimizer optimizer) {
|
public void testMoreEstimatedParametersUnsorted(LeastSquaresOptimizer optimizer) {
|
||||||
|
@ -300,11 +301,9 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
Optimum optimum = optimizer.optimize(
|
Optimum optimum = optimizer.optimize(
|
||||||
problem.getBuilder().start(new double[]{2, 2, 2, 2, 2, 2}).build());
|
problem.getBuilder().start(new double[]{2, 2, 2, 2, 2, 2}).build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(3, optimum.getPoint()[2], 1e-10);
|
//TODO the first two elements of point were not previously checked
|
||||||
Assert.assertEquals(4, optimum.getPoint()[3], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), 2, 1, 3, 4, 5, 6);
|
||||||
Assert.assertEquals(5, optimum.getPoint()[4], 1e-10);
|
|
||||||
Assert.assertEquals(6, optimum.getPoint()[5], 1e-10);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testRedundantEquations(LeastSquaresOptimizer optimizer) {
|
public void testRedundantEquations(LeastSquaresOptimizer optimizer) {
|
||||||
|
@ -317,9 +316,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
Optimum optimum = optimizer
|
Optimum optimum = optimizer
|
||||||
.optimize(problem.getBuilder().start(new double[]{1, 1}).build());
|
.optimize(problem.getBuilder().start(new double[]{1, 1}).build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(2, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), 2, 1);
|
||||||
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testInconsistentEquations(LeastSquaresOptimizer optimizer) {
|
public void testInconsistentEquations(LeastSquaresOptimizer optimizer) {
|
||||||
|
@ -346,9 +344,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
//TODO why is this part here? hasn't it been tested already?
|
//TODO why is this part here? hasn't it been tested already?
|
||||||
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(-1, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), -1, 1);
|
||||||
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
|
|
||||||
|
|
||||||
//TODO move to builder test
|
//TODO move to builder test
|
||||||
optimizer.optimize(
|
optimizer.optimize(
|
||||||
|
@ -368,9 +365,8 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
Optimum optimum = optimizer.optimize(problem.getBuilder().build());
|
||||||
|
|
||||||
Assert.assertEquals(0, optimum.computeRMS(), 1e-10);
|
Assert.assertEquals(0, optimum.computeRMS(), TOl);
|
||||||
Assert.assertEquals(-1, optimum.getPoint()[0], 1e-10);
|
assertEquals(TOl, optimum.getPoint(), -1, 1);
|
||||||
Assert.assertEquals(1, optimum.getPoint()[1], 1e-10);
|
|
||||||
|
|
||||||
//TODO move to builder test
|
//TODO move to builder test
|
||||||
optimizer.optimize(
|
optimizer.optimize(
|
||||||
|
@ -400,14 +396,14 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
Assert.assertTrue(optimum.getEvaluations() < 10);
|
Assert.assertTrue(optimum.getEvaluations() < 10);
|
||||||
|
|
||||||
double rms = optimum.computeRMS();
|
double rms = optimum.computeRMS();
|
||||||
Assert.assertEquals(1.768262623567235, FastMath.sqrt(circle.getN()) * rms, 1e-10);
|
Assert.assertEquals(1.768262623567235, FastMath.sqrt(circle.getN()) * rms, TOl);
|
||||||
|
|
||||||
Vector2D center = new Vector2D(optimum.getPoint()[0], optimum.getPoint()[1]);
|
Vector2D center = new Vector2D(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
|
||||||
Assert.assertEquals(69.96016176931406, circle.getRadius(center), 1e-6);
|
Assert.assertEquals(69.96016176931406, circle.getRadius(center), 1e-6);
|
||||||
Assert.assertEquals(96.07590211815305, center.getX(), 1e-6);
|
Assert.assertEquals(96.07590211815305, center.getX(), 1e-6);
|
||||||
Assert.assertEquals(48.13516790438953, center.getY(), 1e-6);
|
Assert.assertEquals(48.13516790438953, center.getY(), 1e-6);
|
||||||
|
|
||||||
double[][] cov = optimum.computeCovariances(1e-14);
|
double[][] cov = optimum.computeCovariances(1e-14).getData();
|
||||||
Assert.assertEquals(1.839, cov[0][0], 0.001);
|
Assert.assertEquals(1.839, cov[0][0], 0.001);
|
||||||
Assert.assertEquals(0.731, cov[0][1], 0.001);
|
Assert.assertEquals(0.731, cov[0][1], 0.001);
|
||||||
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
|
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
|
||||||
|
@ -425,7 +421,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
optimum = optimizer.optimize(
|
optimum = optimizer.optimize(
|
||||||
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
|
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
|
||||||
|
|
||||||
cov = optimum.computeCovariances(1e-14);
|
cov = optimum.computeCovariances(1e-14).getData();
|
||||||
Assert.assertEquals(0.0016, cov[0][0], 0.001);
|
Assert.assertEquals(0.0016, cov[0][0], 0.001);
|
||||||
Assert.assertEquals(3.2e-7, cov[0][1], 1e-9);
|
Assert.assertEquals(3.2e-7, cov[0][1], 1e-9);
|
||||||
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
|
Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
|
||||||
|
@ -444,7 +440,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
Optimum optimum = optimizer.optimize(builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
|
Optimum optimum = optimizer.optimize(builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
|
||||||
|
|
||||||
Vector2D center = new Vector2D(optimum.getPoint()[0], optimum.getPoint()[1]);
|
Vector2D center = new Vector2D(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
|
||||||
Assert.assertTrue(optimum.getEvaluations() < 25);
|
Assert.assertTrue(optimum.getEvaluations() < 25);
|
||||||
Assert.assertEquals(0.043, optimum.computeRMS(), 1e-3);
|
Assert.assertEquals(0.043, optimum.computeRMS(), 1e-3);
|
||||||
Assert.assertEquals(0.292235, circle.getRadius(center), 1e-6);
|
Assert.assertEquals(0.292235, circle.getRadius(center), 1e-6);
|
||||||
|
@ -465,8 +461,7 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
Optimum optimum = optimizer.optimize(
|
Optimum optimum = optimizer.optimize(
|
||||||
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
|
builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
|
||||||
|
|
||||||
Assert.assertEquals(-0.1517383071957963, optimum.getPoint()[0], 1e-6);
|
assertEquals(1e-6, optimum.getPoint(), -0.1517383071957963, 0.2074999736353867);
|
||||||
Assert.assertEquals(0.2074999736353867, optimum.getPoint()[1], 1e-6);
|
|
||||||
Assert.assertEquals(0.04268731682389561, optimum.computeRMS(), 1e-8);
|
Assert.assertEquals(0.04268731682389561, optimum.computeRMS(), 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -509,12 +504,12 @@ public abstract class AbstractLeastSquaresOptimizerAbstractTest {
|
||||||
|
|
||||||
final Optimum optimum = optimizer.optimize(builder(dataset).build());
|
final Optimum optimum = optimizer.optimize(builder(dataset).build());
|
||||||
|
|
||||||
final double[] actual = optimum.getPoint();
|
final RealVector actual = optimum.getPoint();
|
||||||
for (int i = 0; i < actual.length; i++) {
|
for (int i = 0; i < actual.getDimension(); i++) {
|
||||||
double expected = dataset.getParameter(i);
|
double expected = dataset.getParameter(i);
|
||||||
double delta = FastMath.abs(errParams * expected);
|
double delta = FastMath.abs(errParams * expected);
|
||||||
Assert.assertEquals(dataset.getName() + ", param #" + i,
|
Assert.assertEquals(dataset.getName() + ", param #" + i,
|
||||||
expected, actual[i], delta);
|
expected, actual.getEntry(i), delta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
|
||||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -87,10 +88,10 @@ public class EvaluationTest {
|
||||||
|
|
||||||
final Evaluation evaluation = lsp.evaluate(lsp.getStart());
|
final Evaluation evaluation = lsp.evaluate(lsp.getStart());
|
||||||
final double cost = evaluation.computeCost();
|
final double cost = evaluation.computeCost();
|
||||||
final double[] sig = evaluation.computeSigma(1e-14);
|
final RealVector sig = evaluation.computeSigma(1e-14);
|
||||||
final int dof = lsp.getObservationSize() - lsp.getParameterSize();
|
final int dof = lsp.getObservationSize() - lsp.getParameterSize();
|
||||||
for (int i = 0; i < sig.length; i++) {
|
for (int i = 0; i < sig.getDimension(); i++) {
|
||||||
final double actual = FastMath.sqrt(cost * cost / dof) * sig[i];
|
final double actual = FastMath.sqrt(cost * cost / dof) * sig.getEntry(i);
|
||||||
Assert.assertEquals(dataset.getName() + ", parameter #" + i,
|
Assert.assertEquals(dataset.getName() + ", parameter #" + i,
|
||||||
expected[i], actual, 1e-6 * expected[i]);
|
expected[i], actual, 1e-6 * expected[i]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,17 +13,18 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.commons.math3.fitting.leastsquares;
|
package org.apache.commons.math3.fitting.leastsquares;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
import java.util.List;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.awt.geom.Point2D;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
|
||||||
import org.apache.commons.math3.stat.descriptive.StatisticalSummary;
|
|
||||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
|
import org.apache.commons.math3.stat.descriptive.StatisticalSummary;
|
||||||
|
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.awt.geom.Point2D;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class demonstrates the main functionality of the
|
* This class demonstrates the main functionality of the
|
||||||
|
@ -95,7 +96,7 @@ public class EvaluationTestValidation {
|
||||||
sigmaEstimate[i] = new SummaryStatistics();
|
sigmaEstimate[i] = new SummaryStatistics();
|
||||||
}
|
}
|
||||||
|
|
||||||
final double[] init = { slope, offset };
|
final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
|
||||||
|
|
||||||
// Monte-Carlo (generates many sets of observations).
|
// Monte-Carlo (generates many sets of observations).
|
||||||
final int mcRepeat = MONTE_CARLO_RUNS;
|
final int mcRepeat = MONTE_CARLO_RUNS;
|
||||||
|
@ -117,12 +118,12 @@ public class EvaluationTestValidation {
|
||||||
// covariance matrix).
|
// covariance matrix).
|
||||||
final LeastSquaresProblem lsp = builder(problem).build();
|
final LeastSquaresProblem lsp = builder(problem).build();
|
||||||
|
|
||||||
final double[] sigma = lsp.evaluate(init).computeSigma(1e-14);
|
final RealVector sigma = lsp.evaluate(init).computeSigma(1e-14);
|
||||||
|
|
||||||
// Accumulate statistics.
|
// Accumulate statistics.
|
||||||
for (int i = 0; i < numParams; i++) {
|
for (int i = 0; i < numParams; i++) {
|
||||||
paramsFoundByDirectSolution[i].addValue(regress[i]);
|
paramsFoundByDirectSolution[i].addValue(regress[i]);
|
||||||
sigmaEstimate[i].addValue(sigma[i]);
|
sigmaEstimate[i].addValue(sigma.getEntry(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next Monte-Carlo.
|
// Next Monte-Carlo.
|
||||||
|
@ -138,7 +139,7 @@ public class EvaluationTestValidation {
|
||||||
|
|
||||||
StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary();
|
StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary();
|
||||||
System.out.printf(" %+.6e %+.6e %+.6e\n",
|
System.out.printf(" %+.6e %+.6e %+.6e\n",
|
||||||
init[i],
|
init.getEntry(i),
|
||||||
s.getMean(),
|
s.getMean(),
|
||||||
s.getStandardDeviation());
|
s.getStandardDeviation());
|
||||||
|
|
||||||
|
@ -212,7 +213,7 @@ public class EvaluationTestValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Direct solution (using simple regression).
|
// Direct solution (using simple regression).
|
||||||
final double[] regress = problem.solve();
|
final RealVector regress = new ArrayRealVector(problem.solve(), false);
|
||||||
|
|
||||||
// Dummy optimizer (to compute the chi-square).
|
// Dummy optimizer (to compute the chi-square).
|
||||||
final LeastSquaresProblem lsp = builder(problem).build();
|
final LeastSquaresProblem lsp = builder(problem).build();
|
||||||
|
@ -221,7 +222,7 @@ public class EvaluationTestValidation {
|
||||||
// Get chi-square of the best parameters set for the given set of
|
// Get chi-square of the best parameters set for the given set of
|
||||||
// observations.
|
// observations.
|
||||||
final double bestChi2N = getChi2N(lsp, regress);
|
final double bestChi2N = getChi2N(lsp, regress);
|
||||||
final double[] sigma = lsp.evaluate(regress).computeSigma(1e-14);
|
final RealVector sigma = lsp.evaluate(regress).computeSigma(1e-14);
|
||||||
|
|
||||||
// Monte-Carlo (generates a grid of parameters).
|
// Monte-Carlo (generates a grid of parameters).
|
||||||
final int mcRepeat = MONTE_CARLO_RUNS;
|
final int mcRepeat = MONTE_CARLO_RUNS;
|
||||||
|
@ -233,8 +234,8 @@ public class EvaluationTestValidation {
|
||||||
// Index 2 = normalized chi2
|
// Index 2 = normalized chi2
|
||||||
final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize);
|
final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize);
|
||||||
|
|
||||||
final double slopeRange = 10 * sigma[0];
|
final double slopeRange = 10 * sigma.getEntry(0);
|
||||||
final double offsetRange = 10 * sigma[1];
|
final double offsetRange = 10 * sigma.getEntry(1);
|
||||||
final double minSlope = slope - 0.5 * slopeRange;
|
final double minSlope = slope - 0.5 * slopeRange;
|
||||||
final double minOffset = offset - 0.5 * offsetRange;
|
final double minOffset = offset - 0.5 * offsetRange;
|
||||||
final double deltaSlope = slopeRange/ gridSize;
|
final double deltaSlope = slopeRange/ gridSize;
|
||||||
|
@ -243,7 +244,8 @@ public class EvaluationTestValidation {
|
||||||
final double s = minSlope + i * deltaSlope;
|
final double s = minSlope + i * deltaSlope;
|
||||||
for (int j = 0; j < gridSize; j++) {
|
for (int j = 0; j < gridSize; j++) {
|
||||||
final double o = minOffset + j * deltaOffset;
|
final double o = minOffset + j * deltaOffset;
|
||||||
final double chi2N = getChi2N(lsp, new double[] {s, o});
|
final double chi2N = getChi2N(lsp,
|
||||||
|
new ArrayRealVector(new double[] {s, o}, false));
|
||||||
|
|
||||||
paramsAndChi2.add(new double[] {s, o, chi2N});
|
paramsAndChi2.add(new double[] {s, o, chi2N});
|
||||||
}
|
}
|
||||||
|
@ -260,7 +262,7 @@ public class EvaluationTestValidation {
|
||||||
final String lineFmt = "%+.10e %+.10e %.8e\n";
|
final String lineFmt = "%+.10e %+.10e %.8e\n";
|
||||||
|
|
||||||
// Point with smallest chi-square.
|
// Point with smallest chi-square.
|
||||||
System.out.printf(lineFmt, regress[0], regress[1], bestChi2N);
|
System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
|
||||||
System.out.println(); // Empty line.
|
System.out.println(); // Empty line.
|
||||||
|
|
||||||
// Points within the confidence interval.
|
// Points within the confidence interval.
|
||||||
|
@ -280,7 +282,7 @@ public class EvaluationTestValidation {
|
||||||
}
|
}
|
||||||
System.out.println(); // Empty line.
|
System.out.println(); // Empty line.
|
||||||
|
|
||||||
System.out.println("# sigma=" + Arrays.toString(sigma));
|
System.out.println("# sigma=" + sigma.toString());
|
||||||
System.out.println("# " + numLarger + " sets filtered out");
|
System.out.println("# " + numLarger + " sets filtered out");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,15 +291,17 @@ public class EvaluationTestValidation {
|
||||||
.model(problem.getModelFunction())
|
.model(problem.getModelFunction())
|
||||||
.jacobian(problem.getModelFunctionJacobian())
|
.jacobian(problem.getModelFunctionJacobian())
|
||||||
.target(problem.target())
|
.target(problem.target())
|
||||||
.weight(new DiagonalMatrix(problem.weight()));
|
.weight(new DiagonalMatrix(problem.weight()))
|
||||||
|
//unused start point to avoid NPE
|
||||||
|
.start(new double[2]);
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* @return the normalized chi-square.
|
* @return the normalized chi-square.
|
||||||
*/
|
*/
|
||||||
private double getChi2N(LeastSquaresProblem lsp,
|
private double getChi2N(LeastSquaresProblem lsp,
|
||||||
double[] params) {
|
RealVector params) {
|
||||||
final double cost = lsp.evaluate(params).computeCost();
|
final double cost = lsp.evaluate(params).computeCost();
|
||||||
return cost * cost / (lsp.getObservationSize() - params.length);
|
return cost * cost / (lsp.getObservationSize() - params.getDimension());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,8 @@ import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||||
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
|
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
|
||||||
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
||||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealMatrix;
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
import org.apache.commons.math3.linear.SingularMatrixException;
|
import org.apache.commons.math3.linear.SingularMatrixException;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.apache.commons.math3.util.Precision;
|
import org.apache.commons.math3.util.Precision;
|
||||||
|
@ -201,10 +203,10 @@ public class LevenbergMarquardtOptimizerTest
|
||||||
.build()
|
.build()
|
||||||
);
|
);
|
||||||
|
|
||||||
final double[] solution = optimum.getPoint();
|
final RealVector solution = optimum.getPoint();
|
||||||
final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };
|
final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };
|
||||||
|
|
||||||
final double[][] covarMatrix = optimum.computeCovariances(1e-14);
|
final RealMatrix covarMatrix = optimum.computeCovariances(1e-14);
|
||||||
final double[][] expectedCovarMatrix = {
|
final double[][] expectedCovarMatrix = {
|
||||||
{ 3.38, -3.69, 27.98, -2.34, -49.24 },
|
{ 3.38, -3.69, 27.98, -2.34, -49.24 },
|
||||||
{ -3.69, 2492.26, 81.89, -69.21, -8.9 },
|
{ -3.69, 2492.26, 81.89, -69.21, -8.9 },
|
||||||
|
@ -218,7 +220,7 @@ public class LevenbergMarquardtOptimizerTest
|
||||||
// Check that the computed solution is within the reference error range.
|
// Check that the computed solution is within the reference error range.
|
||||||
for (int i = 0; i < numParams; i++) {
|
for (int i = 0; i < numParams; i++) {
|
||||||
final double error = FastMath.sqrt(expectedCovarMatrix[i][i]);
|
final double error = FastMath.sqrt(expectedCovarMatrix[i][i]);
|
||||||
Assert.assertEquals("Parameter " + i, expectedSolution[i], solution[i], error);
|
Assert.assertEquals("Parameter " + i, expectedSolution[i], solution.getEntry(i), error);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that each entry of the computed covariance matrix is within 10%
|
// Check that each entry of the computed covariance matrix is within 10%
|
||||||
|
@ -227,7 +229,7 @@ public class LevenbergMarquardtOptimizerTest
|
||||||
for (int j = 0; j < numParams; j++) {
|
for (int j = 0; j < numParams; j++) {
|
||||||
Assert.assertEquals("Covariance matrix [" + i + "][" + j + "]",
|
Assert.assertEquals("Covariance matrix [" + i + "][" + j + "]",
|
||||||
expectedCovarMatrix[i][j],
|
expectedCovarMatrix[i][j],
|
||||||
covarMatrix[i][j],
|
covarMatrix.getEntry(i, j),
|
||||||
FastMath.abs(0.1 * expectedCovarMatrix[i][j]));
|
FastMath.abs(0.1 * expectedCovarMatrix[i][j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -258,10 +260,10 @@ public class LevenbergMarquardtOptimizerTest
|
||||||
final Optimum optimum = optimizer.optimize(
|
final Optimum optimum = optimizer.optimize(
|
||||||
builder(circle).maxIterations(50).start(init).build());
|
builder(circle).maxIterations(50).start(init).build());
|
||||||
|
|
||||||
final double[] paramFound = optimum.getPoint();
|
final double[] paramFound = optimum.getPoint().toArray();
|
||||||
|
|
||||||
// Retrieve errors estimation.
|
// Retrieve errors estimation.
|
||||||
final double[] asymptoticStandardErrorFound = optimum.computeSigma(1e-14);
|
final double[] asymptoticStandardErrorFound = optimum.computeSigma(1e-14).toArray();
|
||||||
|
|
||||||
// Check that the parameters are found within the assumed error bars.
|
// Check that the parameters are found within the assumed error bars.
|
||||||
Assert.assertEquals(xCenter, paramFound[0], asymptoticStandardErrorFound[0]);
|
Assert.assertEquals(xCenter, paramFound[0], asymptoticStandardErrorFound[0]);
|
||||||
|
|
|
@ -518,7 +518,7 @@ public class MinpackTest {
|
||||||
final Optimum optimum = optimizer.optimize(problem);
|
final Optimum optimum = optimizer.optimize(problem);
|
||||||
Assert.assertFalse(exceptionExpected);
|
Assert.assertFalse(exceptionExpected);
|
||||||
function.checkTheoreticalMinCost(optimum.computeRMS());
|
function.checkTheoreticalMinCost(optimum.computeRMS());
|
||||||
function.checkTheoreticalMinParams(optimum.getPoint());
|
function.checkTheoreticalMinParams(optimum.getPoint().toArray());
|
||||||
} catch (TooManyEvaluationsException e) {
|
} catch (TooManyEvaluationsException e) {
|
||||||
Assert.assertTrue(exceptionExpected);
|
Assert.assertTrue(exceptionExpected);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue