From 44f6b8b20f0ed5bf1965ef575b9ad30bd4f8b256 Mon Sep 17 00:00:00 2001 From: Gilles Sadowski Date: Mon, 6 Sep 2010 09:06:28 +0000 Subject: [PATCH] MATH-413 (points 1, 2 and 10) Reverted to the original version of the convergence checker (using only the previous and current best points). "LevenberMarquardtOptimizer": Removed setters (control parameters must be set at construction). Added a contructor. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@992976 13f79535-47bb-0310-9956-ffa450edef68 --- .../AbstractConvergenceChecker.java | 14 +- .../math/optimization/ConvergenceChecker.java | 19 +- .../optimization/SimpleRealPointChecker.java | 22 +-- .../SimpleScalarValueChecker.java | 22 +-- .../SimpleVectorialPointChecker.java | 22 +-- .../SimpleVectorialValueChecker.java | 22 +-- .../general/LevenbergMarquardtOptimizer.java | 176 ++++++++---------- .../optimization/general/PowellOptimizer.java | 75 ++++++-- .../univariate/BrentOptimizer.java | 135 +++++--------- .../LevenbergMarquardtOptimizerTest.java | 34 +--- .../optimization/general/MinpackTest.java | 8 +- .../general/PowellOptimizerTest.java | 3 +- .../univariate/BrentOptimizerTest.java | 30 ++- ...MultiStartUnivariateRealOptimizerTest.java | 6 +- 14 files changed, 252 insertions(+), 336 deletions(-) diff --git a/src/main/java/org/apache/commons/math/optimization/AbstractConvergenceChecker.java b/src/main/java/org/apache/commons/math/optimization/AbstractConvergenceChecker.java index 7bca96a31..f05774a2e 100644 --- a/src/main/java/org/apache/commons/math/optimization/AbstractConvergenceChecker.java +++ b/src/main/java/org/apache/commons/math/optimization/AbstractConvergenceChecker.java @@ -22,11 +22,13 @@ import org.apache.commons.math.util.MathUtils; /** * Base class for all convergence checker implementations. * + * Type of (point, value) pair. + * * @version $Revision$ $Date$ * @since 3.0 */ -public abstract class AbstractConvergenceChecker - implements ConvergenceChecker { +public abstract class AbstractConvergenceChecker + implements ConvergenceChecker { /** * Default relative threshold. */ @@ -65,14 +67,14 @@ public abstract class AbstractConvergenceChecker } /** - * {@inheritDoc} + * @return the relative threshold. */ public double getRelativeThreshold() { return relativeThreshold; } /** - * {@inheritDoc} + * @return the absolute threshold. */ public double getAbsoluteThreshold() { return absoluteThreshold; @@ -81,5 +83,7 @@ public abstract class AbstractConvergenceChecker /** * {@inheritDoc} */ - public abstract boolean converged(int iteration, T ... points); + public abstract boolean converged(int iteration, + PAIR previous, + PAIR current); } diff --git a/src/main/java/org/apache/commons/math/optimization/ConvergenceChecker.java b/src/main/java/org/apache/commons/math/optimization/ConvergenceChecker.java index d7e60a6e7..caede6242 100644 --- a/src/main/java/org/apache/commons/math/optimization/ConvergenceChecker.java +++ b/src/main/java/org/apache/commons/math/optimization/ConvergenceChecker.java @@ -37,22 +37,9 @@ public interface ConvergenceChecker { * Check if the optimization algorithm has converged. * * @param iteration Current iteration. - * @param points Data used for checking the convergence. + * @param previous Best point in the previous iteration. + * @param current Best point in the current iteration. * @return {@code true} if the algorithm is considered to have converged. */ - boolean converged(int iteration, PAIR ... points); - - /** - * Get the relative tolerance. - * - * @return the relative threshold. - */ - double getRelativeThreshold(); - - /** - * Get the absolute tolerance. - * - * @return the absolute threshold. - */ - double getAbsoluteThreshold(); + boolean converged(int iteration, PAIR previous, PAIR current); } diff --git a/src/main/java/org/apache/commons/math/optimization/SimpleRealPointChecker.java b/src/main/java/org/apache/commons/math/optimization/SimpleRealPointChecker.java index 6b872e907..3e1f5fdf0 100644 --- a/src/main/java/org/apache/commons/math/optimization/SimpleRealPointChecker.java +++ b/src/main/java/org/apache/commons/math/optimization/SimpleRealPointChecker.java @@ -19,7 +19,6 @@ package org.apache.commons.math.optimization; import org.apache.commons.math.util.MathUtils; import org.apache.commons.math.util.FastMath; -import org.apache.commons.math.exception.DimensionMismatchException; /** * Simple implementation of the {@link ConvergenceChecker} interface using @@ -66,24 +65,15 @@ public class SimpleRealPointChecker * not only for the best or worst ones. * * @param iteration Index of current iteration - * @param points Points used for checking convergence. The list must - * contain two elements: - *
    - *
  • the previous best point,
  • - *
  • the current best point.
  • - *
+ * @param previous Best point in the previous iteration. + * @param current Best point in the current iteration. * @return {@code true} if the algorithm has converged. - * @throws DimensionMismatchException if the length of the {@code points} - * list is not equal to 2. */ public boolean converged(final int iteration, - final RealPointValuePair ... points) { - if (points.length != 2) { - throw new DimensionMismatchException(points.length, 2); - } - - final double[] p = points[0].getPoint(); - final double[] c = points[1].getPoint(); + final RealPointValuePair previous, + final RealPointValuePair current) { + final double[] p = previous.getPoint(); + final double[] c = current.getPoint(); for (int i = 0; i < p.length; ++i) { final double difference = FastMath.abs(p[i] - c[i]); final double size = FastMath.max(FastMath.abs(p[i]), FastMath.abs(c[i])); diff --git a/src/main/java/org/apache/commons/math/optimization/SimpleScalarValueChecker.java b/src/main/java/org/apache/commons/math/optimization/SimpleScalarValueChecker.java index 2e71805cb..95664b6f3 100644 --- a/src/main/java/org/apache/commons/math/optimization/SimpleScalarValueChecker.java +++ b/src/main/java/org/apache/commons/math/optimization/SimpleScalarValueChecker.java @@ -19,7 +19,6 @@ package org.apache.commons.math.optimization; import org.apache.commons.math.util.MathUtils; import org.apache.commons.math.util.FastMath; -import org.apache.commons.math.exception.DimensionMismatchException; /** * Simple implementation of the {@link ConvergenceChecker} interface using @@ -66,24 +65,15 @@ public class SimpleScalarValueChecker * not only for the best or worst ones. * * @param iteration Index of current iteration - * @param points Points used for checking convergence. The list must - * contain two elements: - *
    - *
  • the previous best point,
  • - *
  • the current best point.
  • - *
+ * @param previous Best point in the previous iteration. + * @param current Best point in the current iteration. * @return {@code true} if the algorithm has converged. - * @throws DimensionMismatchException if the length of the {@code points} - * list is not equal to 2. */ public boolean converged(final int iteration, - final RealPointValuePair ... points) { - if (points.length != 2) { - throw new DimensionMismatchException(points.length, 2); - } - - final double p = points[0].getValue(); - final double c = points[1].getValue(); + final RealPointValuePair previous, + final RealPointValuePair current) { + final double p = previous.getValue(); + final double c = current.getValue(); final double difference = FastMath.abs(p - c); final double size = FastMath.max(FastMath.abs(p), FastMath.abs(c)); return (difference <= size * getRelativeThreshold() || diff --git a/src/main/java/org/apache/commons/math/optimization/SimpleVectorialPointChecker.java b/src/main/java/org/apache/commons/math/optimization/SimpleVectorialPointChecker.java index 9ebf33a33..d0c4de7c3 100644 --- a/src/main/java/org/apache/commons/math/optimization/SimpleVectorialPointChecker.java +++ b/src/main/java/org/apache/commons/math/optimization/SimpleVectorialPointChecker.java @@ -18,7 +18,6 @@ package org.apache.commons.math.optimization; import org.apache.commons.math.util.MathUtils; -import org.apache.commons.math.exception.DimensionMismatchException; /** * Simple implementation of the {@link ConvergenceChecker} interface using @@ -66,24 +65,15 @@ public class SimpleVectorialPointChecker * not only for the best or worst ones. * * @param iteration Index of current iteration - * @param points Points used for checking convergence. The list must - * contain two elements: - *
    - *
  • the previous best point,
  • - *
  • the current best point.
  • - *
+ * @param previous Best point in the previous iteration. + * @param current Best point in the current iteration. * @return {@code true} if the algorithm has converged. - * @throws DimensionMismatchException if the length of the {@code points} - * list is not equal to 2. */ public boolean converged(final int iteration, - final VectorialPointValuePair ... points) { - if (points.length != 2) { - throw new DimensionMismatchException(points.length, 2); - } - - final double[] p = points[0].getPointRef(); - final double[] c = points[1].getPointRef(); + final VectorialPointValuePair previous, + final VectorialPointValuePair current) { + final double[] p = previous.getPointRef(); + final double[] c = current.getPointRef(); for (int i = 0; i < p.length; ++i) { final double pi = p[i]; final double ci = c[i]; diff --git a/src/main/java/org/apache/commons/math/optimization/SimpleVectorialValueChecker.java b/src/main/java/org/apache/commons/math/optimization/SimpleVectorialValueChecker.java index 576462b76..31ae71364 100644 --- a/src/main/java/org/apache/commons/math/optimization/SimpleVectorialValueChecker.java +++ b/src/main/java/org/apache/commons/math/optimization/SimpleVectorialValueChecker.java @@ -19,7 +19,6 @@ package org.apache.commons.math.optimization; import org.apache.commons.math.util.FastMath; import org.apache.commons.math.util.MathUtils; -import org.apache.commons.math.exception.DimensionMismatchException; /** * Simple implementation of the {@link ConvergenceChecker} interface using @@ -67,24 +66,15 @@ public class SimpleVectorialValueChecker * not only for the best or worst ones. * * @param iteration Index of current iteration - * @param points Points used for checking convergence. The list must - * contain two elements: - *
    - *
  • the previous best point,
  • - *
  • the current best point.
  • - *
+ * @param previous Best point in the previous iteration. + * @param current Best point in the current iteration. * @return {@code true} if the algorithm has converged. - * @throws DimensionMismatchException if the length of the {@code points} - * list is not equal to 2. */ public boolean converged(final int iteration, - final VectorialPointValuePair ... points) { - if (points.length != 2) { - throw new DimensionMismatchException(points.length, 2); - } - - final double[] p = points[0].getValueRef(); - final double[] c = points[1].getValueRef(); + final VectorialPointValuePair previous, + final VectorialPointValuePair current) { + final double[] p = previous.getValueRef(); + final double[] c = current.getValueRef(); for (int i = 0; i < p.length; ++i) { final double pi = p[i]; final double ci = c[i]; diff --git a/src/main/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java b/src/main/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java index 2d166f068..161229980 100644 --- a/src/main/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java +++ b/src/main/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java @@ -106,136 +106,117 @@ import org.apache.commons.math.util.FastMath; * */ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { - /** Number of solved point. */ private int solvedCols; - /** Diagonal elements of the R matrix in the Q.R. decomposition. */ private double[] diagR; - /** Norms of the columns of the jacobian matrix. */ private double[] jacNorm; - /** Coefficients of the Householder transforms vectors. */ private double[] beta; - /** Columns permutation array. */ private int[] permutation; - /** Rank of the jacobian matrix. */ private int rank; - /** Levenberg-Marquardt parameter. */ private double lmPar; - /** Parameters evolution direction associated with lmPar. */ private double[] lmDir; - /** Positive input variable used in determining the initial step bound. */ - private double initialStepBoundFactor; - + private final double initialStepBoundFactor; /** Desired relative error in the sum of squares. */ - private double costRelativeTolerance; - + private final double costRelativeTolerance; /** Desired relative error in the approximate solution parameters. */ - private double parRelativeTolerance; - + private final double parRelativeTolerance; /** Desired max cosine on the orthogonality between the function vector * and the columns of the jacobian. */ - private double orthoTolerance; - + private final double orthoTolerance; /** Threshold for QR ranking. */ - private double qrRankingThreshold; + private final double qrRankingThreshold; /** - * Build an optimizer for least squares problems. - *

The default values for the algorithm settings are: - *

    - *
  • {@link #setConvergenceChecker(ConvergenceChecker) vectorial convergence checker}: null
  • - *
  • {@link #setInitialStepBoundFactor(double) initial step bound factor}: 100.0
  • - *
  • {@link #setCostRelativeTolerance(double) cost relative tolerance}: 1.0e-10
  • - *
  • {@link #setParRelativeTolerance(double) parameters relative tolerance}: 1.0e-10
  • - *
  • {@link #setOrthoTolerance(double) orthogonality tolerance}: 1.0e-10
  • - *
  • {@link #setQRRankingThreshold(double) QR ranking threshold}: {@link MathUtils#SAFE_MIN}
  • - *
- *

- *

These default values may be overridden after construction. If the {@link - * #setConvergenceChecker vectorial convergence checker} is set to a non-null value, it - * will be used instead of the {@link #setCostRelativeTolerance cost relative tolerance} - * and {@link #setParRelativeTolerance parameters relative tolerance} settings. + * Build an optimizer for least squares problems with default values + * for all the tuning parameters (see the {@link + * #LevenbergMarquardtOptimizer(double,double,double,double,double) + * other contructor}. + * The default values for the algorithm settings are: + *

    + *
  • Initial step bound factor}: 100
  • + *
  • Cost relative tolerance}: 1e-10
  • + *
  • Parameters relative tolerance}: 1e-10
  • + *
  • Orthogonality tolerance}: 1e-10
  • + *
  • QR ranking threshold}: {@link MathUtils#SAFE_MIN}
  • + *
*/ public LevenbergMarquardtOptimizer() { - // default values for the tuning parameters - setConvergenceChecker(null); - setInitialStepBoundFactor(100.0); - setCostRelativeTolerance(1.0e-10); - setParRelativeTolerance(1.0e-10); - setOrthoTolerance(1.0e-10); - setQRRankingThreshold(MathUtils.SAFE_MIN); + this(100, 1e-10, 1e-10, 1e-10, MathUtils.SAFE_MIN); } /** - * Set the positive input variable used in determining the initial step bound. - * This bound is set to the product of initialStepBoundFactor and the euclidean - * norm of diag*x if nonzero, or else to initialStepBoundFactor itself. In most - * cases factor should lie in the interval (0.1, 100.0). 100.0 is a generally - * recommended value. + * Build an optimizer for least squares problems with default values + * for some of the tuning parameters (see the {@link + * #LevenbergMarquardtOptimizer(double,double,double,double,double) + * other contructor}. + * The default values for the algorithm settings are: + *
    + *
  • Initial step bound factor}: 100
  • + *
  • QR ranking threshold}: {@link MathUtils#SAFE_MIN}
  • + *
* - * @param initialStepBoundFactor initial step bound factor + * @param costRelativeTolerance Desired relative error in the sum of + * squares. + * @param parRelativeTolerance Desired relative error in the approximate + * solution parameters. + * @param orthoTolerance Desired max cosine on the orthogonality between + * the function vector and the columns of the Jacobian. */ - public void setInitialStepBoundFactor(double initialStepBoundFactor) { + public LevenbergMarquardtOptimizer(double costRelativeTolerance, + double parRelativeTolerance, + double orthoTolerance) { + this(100, + costRelativeTolerance, parRelativeTolerance, orthoTolerance, + MathUtils.SAFE_MIN); + } + + /** + * The arguments control the behaviour of the default convergence checking + * procedure. + * Additional criteria can defined through the setting of a {@link + * ConvergenceChecker}. + * + * @param initialStepBoundFactor Positive input variable used in + * determining the initial step bound. This bound is set to the + * product of initialStepBoundFactor and the euclidean norm of + * {@code diag * x} if non-zero, or else to {@code initialStepBoundFactor} + * itself. In most cases factor should lie in the interval + * {@code (0.1, 100.0)}. {@code 100} is a generally recommended value. + * @param costRelativeTolerance Desired relative error in the sum of + * squares. + * @param parRelativeTolerance Desired relative error in the approximate + * solution parameters. + * @param orthoTolerance Desired max cosine on the orthogonality between + * the function vector and the columns of the Jacobian. + * @param threshold Desired threshold for QR ranking. If the squared norm + * of a column vector is smaller or equal to this threshold during QR + * decomposition, it is considered to be a zero vector and hence the rank + * of the matrix is reduced. + */ + public LevenbergMarquardtOptimizer(double initialStepBoundFactor, + double costRelativeTolerance, + double parRelativeTolerance, + double orthoTolerance, + double threshold) { this.initialStepBoundFactor = initialStepBoundFactor; - } - - /** - * Set the desired relative error in the sum of squares. - *

This setting is used only if the {@link #setConvergenceChecker vectorial - * convergence checker} is set to null.

- * @param costRelativeTolerance desired relative error in the sum of squares - */ - public void setCostRelativeTolerance(double costRelativeTolerance) { this.costRelativeTolerance = costRelativeTolerance; - } - - /** - * Set the desired relative error in the approximate solution parameters. - *

This setting is used only if the {@link #setConvergenceChecker vectorial - * convergence checker} is set to null.

- * @param parRelativeTolerance desired relative error - * in the approximate solution parameters - */ - public void setParRelativeTolerance(double parRelativeTolerance) { this.parRelativeTolerance = parRelativeTolerance; - } - - /** - * Set the desired max cosine on the orthogonality. - *

This setting is always used, regardless of the {@link #setConvergenceChecker - * vectorial convergence checker} being null or non-null.

- * @param orthoTolerance desired max cosine on the orthogonality - * between the function vector and the columns of the jacobian - */ - public void setOrthoTolerance(double orthoTolerance) { this.orthoTolerance = orthoTolerance; - } - - /** - * Set the desired threshold for QR ranking. - *

- * If the squared norm of a column vector is smaller or equal to this threshold - * during QR decomposition, it is considered to be a zero vector and hence the - * rank of the matrix is reduced. - *

- * @param threshold threshold for QR ranking - */ - public void setQRRankingThreshold(final double threshold) { this.qrRankingThreshold = threshold; } /** {@inheritDoc} */ @Override - protected VectorialPointValuePair doOptimize() throws FunctionEvaluationException { - + protected VectorialPointValuePair doOptimize() + throws FunctionEvaluationException { // arrays shared with the other private methods solvedCols = FastMath.min(rows, cols); diagR = new double[cols]; @@ -446,14 +427,15 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { objective = oldObj; oldObj = tmpVec; } - if (checker==null) { - if (((FastMath.abs(actRed) <= costRelativeTolerance) && - (preRed <= costRelativeTolerance) && - (ratio <= 2.0)) || - (delta <= parRelativeTolerance * xNorm)) { - return current; - } + + // Default convergence criteria. + if ((FastMath.abs(actRed) <= costRelativeTolerance && + preRed <= costRelativeTolerance && + ratio <= 2.0) || + delta <= parRelativeTolerance * xNorm) { + return current; } + // tests for termination and stringent tolerances // (2.2204e-16 is the machine epsilon for IEEE754) if ((FastMath.abs(actRed) <= 2.2204e-16) && (preRed <= 2.2204e-16) && (ratio <= 2.0)) { diff --git a/src/main/java/org/apache/commons/math/optimization/general/PowellOptimizer.java b/src/main/java/org/apache/commons/math/optimization/general/PowellOptimizer.java index 45eef67c4..6597e52b8 100644 --- a/src/main/java/org/apache/commons/math/optimization/general/PowellOptimizer.java +++ b/src/main/java/org/apache/commons/math/optimization/general/PowellOptimizer.java @@ -20,7 +20,10 @@ package org.apache.commons.math.optimization.general; import java.util.Arrays; import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.util.FastMath; import org.apache.commons.math.analysis.UnivariateRealFunction; +import org.apache.commons.math.exception.NumberIsTooSmallException; +import org.apache.commons.math.exception.NotStrictlyPositiveException; import org.apache.commons.math.optimization.GoalType; import org.apache.commons.math.optimization.RealPointValuePair; import org.apache.commons.math.optimization.ConvergenceChecker; @@ -35,8 +38,10 @@ import org.apache.commons.math.optimization.univariate.UnivariateRealPointValueP * algorithm (as implemented in module {@code optimize.py} v0.5 of * SciPy). *
- * The user is responsible for calling {@link - * #setConvergenceChecker(ConvergenceChecker) ConvergenceChecker} + * The default stopping criterion is based on the differences of the + * function value between two successive iterations. It is however possible + * to define custom convergence criteria by calling a {@link + * #setConvergenceChecker(ConvergenceChecker) setConvergenceChecker} * prior to using the optimizer. * * @version $Revision$ $Date$ @@ -44,27 +49,49 @@ import org.apache.commons.math.optimization.univariate.UnivariateRealPointValueP */ public class PowellOptimizer extends AbstractScalarOptimizer { + /** + * Minimum relative tolerance. + */ + private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d); + /** + * Relative threshold. + */ + private double relativeThreshold; + /** + * Absolute threshold. + */ + private double absoluteThreshold; /** * Line search. */ - private LineSearch line = new LineSearch(); + private LineSearch line; /** - * Set the convergence checker. - * It also indirectly sets the line search tolerances to the square-root - * of the correponding tolerances in the checker. + * The arguments control the behaviour of the default convergence + * checking procedure. * - * @param checker Convergence checker. + * @param rel Relative threshold. + * @param abs Absolute threshold. + * @throws NotStrictlyPositiveException if {@code abs <= 0}. + * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}. */ - public void setConvergenceChecker(ConvergenceChecker checker) { - super.setConvergenceChecker(checker); + public PowellOptimizer(double rel, + double abs) { + if (rel < MIN_RELATIVE_TOLERANCE) { + throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true); + } + if (abs <= 0) { + throw new NotStrictlyPositiveException(abs); + } + relativeThreshold = rel; + absoluteThreshold = abs; // Line search tolerances can be much lower than the tolerances // required for the optimizer itself. final double minTol = 1e-4; - final double rel = Math.min(Math.sqrt(checker.getRelativeThreshold()), minTol); - final double abs = Math.min(Math.sqrt(checker.getAbsoluteThreshold()), minTol); - line.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(rel, abs)); + final double lsRel = Math.min(FastMath.sqrt(relativeThreshold), minTol); + final double lsAbs = Math.min(FastMath.sqrt(absoluteThreshold), minTol); + line = new LineSearch(lsRel, lsAbs); } /** {@inheritDoc} */ @@ -92,6 +119,9 @@ public class PowellOptimizer direc[i][i] = 1; } + final ConvergenceChecker checker + = getConvergenceChecker(); + double[] x = guess; double fVal = computeObjectiveValue(x); double[] x1 = x.clone(); @@ -122,9 +152,19 @@ public class PowellOptimizer } } + // Default convergence check. + boolean stop = 2 * (fX - fVal) <= (relativeThreshold * (FastMath.abs(fX) + + FastMath.abs(fVal)) + + absoluteThreshold); + final RealPointValuePair previous = new RealPointValuePair(x1, fX); final RealPointValuePair current = new RealPointValuePair(x, fVal); - if (getConvergenceChecker().converged(iter, previous, current)) { + if (!stop) { // User-defined stopping criteria. + if (checker != null) { + stop = checker.converged(iter, previous, current); + } + } + if (stop) { if (goal == GoalType.MINIMIZE) { return (fVal < fX) ? current : previous; } else { @@ -203,6 +243,15 @@ public class PowellOptimizer */ private UnivariateRealPointValuePair optimum; + /** + * @param rel Relative threshold. + * @param rel Absolute threshold. + */ + LineSearch(double rel, + double abs) { + super(rel, abs); + } + /** * Find the minimum of the function {@code f(p + alpha * d)}. * diff --git a/src/main/java/org/apache/commons/math/optimization/univariate/BrentOptimizer.java b/src/main/java/org/apache/commons/math/optimization/univariate/BrentOptimizer.java index c10581f23..7936a28cd 100644 --- a/src/main/java/org/apache/commons/math/optimization/univariate/BrentOptimizer.java +++ b/src/main/java/org/apache/commons/math/optimization/univariate/BrentOptimizer.java @@ -19,7 +19,6 @@ package org.apache.commons.math.optimization.univariate; import org.apache.commons.math.FunctionEvaluationException; import org.apache.commons.math.util.MathUtils; import org.apache.commons.math.util.FastMath; -import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.exception.NumberIsTooSmallException; import org.apache.commons.math.exception.NotStrictlyPositiveException; import org.apache.commons.math.exception.MathUnsupportedOperationException; @@ -48,9 +47,21 @@ public class BrentOptimizer extends AbstractUnivariateRealOptimizer { * Golden section. */ private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5)); + /** + * Minimum relative tolerance. + */ + private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d); + /** + * Relative threshold. + */ + private final double relativeThreshold; + /** + * Absolute threshold. + */ + private final double absoluteThreshold; /** - * Convergence checker that implements the original stopping criterion + * The arguments are used implement the original stopping criterion * of Brent's algorithm. * {@code abs} and {@code rel} define a tolerance * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than @@ -58,82 +69,21 @@ public class BrentOptimizer extends AbstractUnivariateRealOptimizer { * where macheps is the relative machine precision. {@code abs} must * be positive. * - * @since 3.0 + * @param rel Relative threshold. + * @param abs Absolute threshold. + * @throws NotStrictlyPositiveException if {@code abs <= 0}. + * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}. */ - public static class BrentConvergenceChecker - extends AbstractConvergenceChecker { - /** - * Minimum relative tolerance. - */ - private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d); - - /** - * Build an instance with specified thresholds. - * - * @param rel Relative tolerance threshold - * @param abs Absolute tolerance threshold - */ - public BrentConvergenceChecker(final double rel, - final double abs) { - super(rel, abs); - - if (rel < MIN_RELATIVE_TOLERANCE) { - throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true); - } - if (abs <= 0) { - throw new NotStrictlyPositiveException(abs); - } + public BrentOptimizer(double rel, + double abs) { + if (rel < MIN_RELATIVE_TOLERANCE) { + throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true); } - - /** - * Convergence criterion. - * - * @param iteration Current iteration. - * @param points Points used for checking the stopping criterion. The list - * must contain 3 points (in the following order): - *
    - *
  • the lower end of the current interval
  • - *
  • the current best point
  • - *
  • the higher end of the current interval
  • - *
- * @return {@code true} if the stopping criterion is satisfied. - * @throws DimensionMismatchException if the length of the {@code points} - * list is not equal to 3. - */ - public boolean converged(final int iteration, - final UnivariateRealPointValuePair ... points) { - if (points.length != 3) { - throw new DimensionMismatchException(points.length, 3); - } - - final double a = points[0].getPoint(); - final double x = points[1].getPoint(); - final double b = points[2].getPoint(); - - final double tol1 = getRelativeThreshold() * FastMath.abs(x) + getAbsoluteThreshold(); - final double tol2 = 2 * tol1; - - final double m = 0.5 * (a + b); - return FastMath.abs(x - m) <= tol2 - 0.5 * (b - a); - } - } - - /** - * Set the convergence checker. - * Since this algorithm requires a specific checker, this method will throw - * an {@code UnsupportedOperationexception} if the argument type is not - * {@link BrentConvergenceChecker}. - * - * @throws MathUnsupportedOperationexception if the checker is not an - * instance of {@link BrentConvergenceChecker}. - */ - @Override - public void setConvergenceChecker(ConvergenceChecker checker) { - if (checker instanceof BrentConvergenceChecker) { - super.setConvergenceChecker(checker); - } else { - throw new MathUnsupportedOperationException(); + if (abs <= 0) { + throw new NotStrictlyPositiveException(abs); } + relativeThreshold = rel; + absoluteThreshold = abs; } /** {@inheritDoc} */ @@ -144,10 +94,9 @@ public class BrentOptimizer extends AbstractUnivariateRealOptimizer { final double mid = getStartValue(); final double hi = getMax(); + // Optional additional convergence criteria. final ConvergenceChecker checker = getConvergenceChecker(); - final double eps = checker.getRelativeThreshold(); - final double t = checker.getAbsoluteThreshold(); double a; double b; @@ -171,19 +120,19 @@ public class BrentOptimizer extends AbstractUnivariateRealOptimizer { double fv = fx; double fw = fx; + UnivariateRealPointValuePair previous = null; + UnivariateRealPointValuePair current + = new UnivariateRealPointValuePair(x, (isMinim ? fx : -fx)); + int iter = 0; while (true) { - double m = 0.5 * (a + b); - final double tol1 = eps * FastMath.abs(x) + t; + final double m = 0.5 * (a + b); + final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold; final double tol2 = 2 * tol1; - // Check stopping criterion. - // This test will work only if the "checker" is an instance of - // "BrentOptimizer.BrentConvergenceChecker". - if (!getConvergenceChecker().converged(iter, - new UnivariateRealPointValuePair(a, Double.NaN), - new UnivariateRealPointValuePair(x, Double.NaN), - new UnivariateRealPointValuePair(b, Double.NaN))) { + // Default stopping criterion. + final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a); + if (!stop) { double p = 0; double q = 0; double r = 0; @@ -283,8 +232,18 @@ public class BrentOptimizer extends AbstractUnivariateRealOptimizer { fv = fu; } } - } else { // Termination. - return new UnivariateRealPointValuePair(x, (isMinim ? fx : -fx)); + + previous = current; + current = new UnivariateRealPointValuePair(x, (isMinim ? fx : -fx)); + + // User-defined convergence checker. + if (checker != null) { + if (checker.converged(iter, previous, current)) { + return current; + } + } + } else { // Default termination (Brent's criterion). + return current; } ++iter; } diff --git a/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java index f80dc460b..0acfa5921 100644 --- a/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java +++ b/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java @@ -35,6 +35,7 @@ import org.apache.commons.math.linear.BlockRealMatrix; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.optimization.SimpleVectorialValueChecker; import org.apache.commons.math.optimization.VectorialPointValuePair; +import org.apache.commons.math.util.MathUtils; import org.apache.commons.math.util.FastMath; /** @@ -140,7 +141,6 @@ public class LevenbergMarquardtOptimizerTest assertEquals(4.0, optimum.getValue()[0], 1.0e-10); assertEquals(6.0, optimum.getValue()[1], 1.0e-10); assertEquals(1.0, optimum.getValue()[2], 1.0e-10); - } public void testNoDependency() throws FunctionEvaluationException { @@ -176,7 +176,6 @@ public class LevenbergMarquardtOptimizerTest assertEquals(1.0, optimum.getPoint()[0], 1.0e-10); assertEquals(2.0, optimum.getPoint()[1], 1.0e-10); assertEquals(3.0, optimum.getPoint()[2], 1.0e-10); - } public void testTwoSets() throws FunctionEvaluationException { @@ -201,7 +200,6 @@ public class LevenbergMarquardtOptimizerTest assertEquals(-2.0, optimum.getPoint()[3], 1.0e-10); assertEquals( 1.0 + epsilon, optimum.getPoint()[4], 1.0e-10); assertEquals( 1.0 - epsilon, optimum.getPoint()[5], 1.0e-10); - } public void testNonInversible() throws FunctionEvaluationException { @@ -223,7 +221,6 @@ public class LevenbergMarquardtOptimizerTest } catch (Exception e) { fail("wrong exception caught"); } - } public void testIllConditioned() throws FunctionEvaluationException { @@ -257,7 +254,6 @@ public class LevenbergMarquardtOptimizerTest assertEquals(137.0, optimum2.getPoint()[1], 1.0e-8); assertEquals(-34.0, optimum2.getPoint()[2], 1.0e-8); assertEquals( 22.0, optimum2.getPoint()[3], 1.0e-8); - } public void testMoreEstimatedParametersSimple() throws FunctionEvaluationException { @@ -272,7 +268,6 @@ public class LevenbergMarquardtOptimizerTest optimizer.optimize(problem, problem.target, new double[] { 1, 1, 1 }, new double[] { 7, 6, 5, 4 }); assertEquals(0, optimizer.getRMS(), 1.0e-10); - } public void testMoreEstimatedParametersUnsorted() throws FunctionEvaluationException { @@ -293,7 +288,6 @@ public class LevenbergMarquardtOptimizerTest assertEquals(4.0, optimum.getPointRef()[3], 1.0e-10); assertEquals(5.0, optimum.getPointRef()[4], 1.0e-10); assertEquals(6.0, optimum.getPointRef()[5], 1.0e-10); - } public void testRedundantEquations() throws FunctionEvaluationException { @@ -310,7 +304,6 @@ public class LevenbergMarquardtOptimizerTest assertEquals(0, optimizer.getRMS(), 1.0e-10); assertEquals(2.0, optimum.getPointRef()[0], 1.0e-10); assertEquals(1.0, optimum.getPointRef()[1], 1.0e-10); - } public void testInconsistentEquations() throws FunctionEvaluationException { @@ -323,7 +316,6 @@ public class LevenbergMarquardtOptimizerTest LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(); optimizer.optimize(problem, problem.target, new double[] { 1, 1, 1 }, new double[] { 1, 1 }); assertTrue(optimizer.getRMS() > 0.1); - } public void testInconsistentSizes() throws FunctionEvaluationException { @@ -358,7 +350,6 @@ public class LevenbergMarquardtOptimizerTest } catch (Exception e) { fail("wrong exception caught"); } - } public void testControlParameters() { @@ -380,12 +371,13 @@ public class LevenbergMarquardtOptimizerTest double costRelativeTolerance, double parRelativeTolerance, double orthoTolerance, boolean shouldFail) { try { - LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(); - optimizer.setInitialStepBoundFactor(initialStepBoundFactor); + LevenbergMarquardtOptimizer optimizer + = new LevenbergMarquardtOptimizer(initialStepBoundFactor, + costRelativeTolerance, + parRelativeTolerance, + orthoTolerance, + MathUtils.SAFE_MIN); optimizer.setMaxEvaluations(maxCostEval); - optimizer.setCostRelativeTolerance(costRelativeTolerance); - optimizer.setParRelativeTolerance(parRelativeTolerance); - optimizer.setOrthoTolerance(orthoTolerance); optimizer.optimize(problem, new double[] { 0, 0, 0, 0, 0 }, new double[] { 1, 1, 1, 1, 1 }, new double[] { 98.680, 47.345 }); assertTrue(!shouldFail); @@ -444,7 +436,6 @@ public class LevenbergMarquardtOptimizerTest errors = optimizer.guessParametersErrors(); assertEquals(0.004, errors[0], 0.001); assertEquals(0.004, errors[1], 0.001); - } public void testCircleFittingBadInit() throws FunctionEvaluationException { @@ -508,8 +499,8 @@ public class LevenbergMarquardtOptimizerTest problem.addPoint (2, -2.1488478161387325); problem.addPoint (3, -1.9122489313410047); problem.addPoint (4, 1.7785661310051026); - LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(); - optimizer.setQRRankingThreshold(0); + LevenbergMarquardtOptimizer optimizer + = new LevenbergMarquardtOptimizer(100, 1e-10, 1e-10, 1e-10, 0); optimizer.optimize(problem, new double[] { 0, 0, 0, 0, 0 }, new double[] { 0.0, 4.4e-323, 1.0, 4.4e-323, 0.0 }, @@ -518,7 +509,6 @@ public class LevenbergMarquardtOptimizerTest } catch (ConvergenceException ee) { // expected behavior } - } private static class LinearProblem implements DifferentiableMultivariateVectorialFunction, Serializable { @@ -543,7 +533,6 @@ public class LevenbergMarquardtOptimizerTest } }; } - } private static class Circle implements DifferentiableMultivariateVectorialFunction, Serializable { @@ -598,7 +587,6 @@ public class LevenbergMarquardtOptimizerTest } return jacobian; - } public double[] value(double[] variables) @@ -613,7 +601,6 @@ public class LevenbergMarquardtOptimizerTest } return residuals; - } public MultivariateMatrixFunction jacobian() { @@ -624,7 +611,6 @@ public class LevenbergMarquardtOptimizerTest } }; } - } private static class QuadraticProblem implements DifferentiableMultivariateVectorialFunction, Serializable { @@ -669,7 +655,5 @@ public class LevenbergMarquardtOptimizerTest } }; } - } - } diff --git a/src/test/java/org/apache/commons/math/optimization/general/MinpackTest.java b/src/test/java/org/apache/commons/math/optimization/general/MinpackTest.java index f6404eb46..47055c644 100644 --- a/src/test/java/org/apache/commons/math/optimization/general/MinpackTest.java +++ b/src/test/java/org/apache/commons/math/optimization/general/MinpackTest.java @@ -489,11 +489,11 @@ public class MinpackTest extends TestCase { } private void minpackTest(MinpackFunction function, boolean exceptionExpected) { - LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(); + LevenbergMarquardtOptimizer optimizer + = new LevenbergMarquardtOptimizer(FastMath.sqrt(2.22044604926e-16), + FastMath.sqrt(2.22044604926e-16), + 2.22044604926e-16); optimizer.setMaxEvaluations(400 * (function.getN() + 1)); - optimizer.setCostRelativeTolerance(FastMath.sqrt(2.22044604926e-16)); - optimizer.setParRelativeTolerance(FastMath.sqrt(2.22044604926e-16)); - optimizer.setOrthoTolerance(2.22044604926e-16); // assertTrue(function.checkTheoreticalStartCost(optimizer.getRMS())); try { VectorialPointValuePair optimum = diff --git a/src/test/java/org/apache/commons/math/optimization/general/PowellOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/general/PowellOptimizerTest.java index 9a8ece534..e6af4f206 100644 --- a/src/test/java/org/apache/commons/math/optimization/general/PowellOptimizerTest.java +++ b/src/test/java/org/apache/commons/math/optimization/general/PowellOptimizerTest.java @@ -137,9 +137,8 @@ public class PowellOptimizerTest { double fTol, double pointTol) throws MathException { - final MultivariateRealOptimizer optim = new PowellOptimizer(); + final MultivariateRealOptimizer optim = new PowellOptimizer(fTol, Math.ulp(1d)); optim.setMaxEvaluations(1000); - optim.setConvergenceChecker(new SimpleScalarValueChecker(fTol, Math.ulp(1d))); final RealPointValuePair result = optim.optimize(func, goal, init); final double[] found = result.getPoint(); diff --git a/src/test/java/org/apache/commons/math/optimization/univariate/BrentOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/univariate/BrentOptimizerTest.java index 8b747d0a1..9e1fdb069 100644 --- a/src/test/java/org/apache/commons/math/optimization/univariate/BrentOptimizerTest.java +++ b/src/test/java/org/apache/commons/math/optimization/univariate/BrentOptimizerTest.java @@ -38,15 +38,12 @@ public final class BrentOptimizerTest { @Test public void testSinMin() throws MathException { UnivariateRealFunction f = new SinFunction(); - UnivariateRealOptimizer optimizer = new BrentOptimizer(); - optimizer.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(1e-10, 1e-14)); + UnivariateRealOptimizer optimizer = new BrentOptimizer(1e-10, 1e-14); optimizer.setMaxEvaluations(200); assertEquals(200, optimizer.getMaxEvaluations()); - assertEquals(3 * Math.PI / 2, optimizer.optimize(f, GoalType.MINIMIZE, 4, 5).getPoint(), - 100 * optimizer.getConvergenceChecker().getRelativeThreshold()); + assertEquals(3 * Math.PI / 2, optimizer.optimize(f, GoalType.MINIMIZE, 4, 5).getPoint(),1e-8); assertTrue(optimizer.getEvaluations() <= 50); - assertEquals(3 * Math.PI / 2, optimizer.optimize(f, GoalType.MINIMIZE, 1, 5).getPoint(), - 100 * optimizer.getConvergenceChecker().getRelativeThreshold()); + assertEquals(3 * Math.PI / 2, optimizer.optimize(f, GoalType.MINIMIZE, 1, 5).getPoint(), 1e-8); assertTrue(optimizer.getEvaluations() <= 100); assertTrue(optimizer.getEvaluations() >= 15); optimizer.setMaxEvaluations(10); @@ -64,8 +61,7 @@ public final class BrentOptimizerTest { public void testQuinticMin() throws MathException { // The function has local minima at -0.27195613 and 0.82221643. UnivariateRealFunction f = new QuinticFunction(); - UnivariateRealOptimizer optimizer = new BrentOptimizer(); - optimizer.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(1e-10, 1e-14)); + UnivariateRealOptimizer optimizer = new BrentOptimizer(1e-10, 1e-14); optimizer.setMaxEvaluations(200); assertEquals(-0.27195613, optimizer.optimize(f, GoalType.MINIMIZE, -0.3, -0.2).getPoint(), 1.0e-8); assertEquals( 0.82221643, optimizer.optimize(f, GoalType.MINIMIZE, 0.3, 0.9).getPoint(), 1.0e-8); @@ -80,8 +76,7 @@ public final class BrentOptimizerTest { public void testQuinticMinStatistics() throws MathException { // The function has local minima at -0.27195613 and 0.82221643. UnivariateRealFunction f = new QuinticFunction(); - UnivariateRealOptimizer optimizer = new BrentOptimizer(); - optimizer.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(1e-12, 1e-14)); + UnivariateRealOptimizer optimizer = new BrentOptimizer(1e-11, 1e-14); optimizer.setMaxEvaluations(40); final DescriptiveStatistics[] stat = new DescriptiveStatistics[2]; @@ -101,8 +96,9 @@ public final class BrentOptimizerTest { final double meanOptValue = stat[0].getMean(); final double medianEval = stat[1].getPercentile(50); - assertTrue(meanOptValue > -0.2719561281 && meanOptValue < -0.2719561280); - assertEquals((int) medianEval, 27); + assertTrue(meanOptValue > -0.2719561281); + assertTrue(meanOptValue < -0.2719561280); + assertEquals(23, (int) medianEval); } @Test(expected = TooManyEvaluationsException.class) @@ -110,8 +106,7 @@ public final class BrentOptimizerTest { // The quintic function has zeros at 0, +-0.5 and +-1. // The function has a local maximum at 0.27195613. UnivariateRealFunction f = new QuinticFunction(); - UnivariateRealOptimizer optimizer = new BrentOptimizer(); - optimizer.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(1e-12, 1e-14)); + UnivariateRealOptimizer optimizer = new BrentOptimizer(1e-12, 1e-14); assertEquals(0.27195613, optimizer.optimize(f, GoalType.MAXIMIZE, 0.2, 0.3).getPoint(), 1e-8); optimizer.setMaxEvaluations(5); try { @@ -127,15 +122,14 @@ public final class BrentOptimizerTest { @Test public void testMinEndpoints() throws Exception { UnivariateRealFunction f = new SinFunction(); - UnivariateRealOptimizer optimizer = new BrentOptimizer(); - optimizer.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(1e-8, 1e-14)); + UnivariateRealOptimizer optimizer = new BrentOptimizer(1e-8, 1e-14); optimizer.setMaxEvaluations(50); // endpoint is minimum double result = optimizer.optimize(f, GoalType.MINIMIZE, 3 * Math.PI / 2, 5).getPoint(); - assertEquals(3 * Math.PI / 2, result, 100 * optimizer.getConvergenceChecker().getRelativeThreshold()); + assertEquals(3 * Math.PI / 2, result, 1e-6); result = optimizer.optimize(f, GoalType.MINIMIZE, 4, 3 * Math.PI / 2).getPoint(); - assertEquals(3 * Math.PI / 2, result, 100 * optimizer.getConvergenceChecker().getRelativeThreshold()); + assertEquals(3 * Math.PI / 2, result, 1e-6); } } diff --git a/src/test/java/org/apache/commons/math/optimization/univariate/MultiStartUnivariateRealOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/univariate/MultiStartUnivariateRealOptimizerTest.java index 47864e2a6..a688df91b 100644 --- a/src/test/java/org/apache/commons/math/optimization/univariate/MultiStartUnivariateRealOptimizerTest.java +++ b/src/test/java/org/apache/commons/math/optimization/univariate/MultiStartUnivariateRealOptimizerTest.java @@ -36,8 +36,7 @@ public class MultiStartUnivariateRealOptimizerTest { @Test public void testSinMin() throws MathException { UnivariateRealFunction f = new SinFunction(); - UnivariateRealOptimizer underlying = new BrentOptimizer(); - underlying.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(1e-10, 1e-14)); + UnivariateRealOptimizer underlying = new BrentOptimizer(1e-10, 1e-14); underlying.setMaxEvaluations(300); JDKRandomGenerator g = new JDKRandomGenerator(); g.setSeed(44428400075l); @@ -60,8 +59,7 @@ public class MultiStartUnivariateRealOptimizerTest { // The quintic function has zeros at 0, +-0.5 and +-1. // The function has extrema (first derivative is zero) at 0.27195613 and 0.82221643, UnivariateRealFunction f = new QuinticFunction(); - UnivariateRealOptimizer underlying = new BrentOptimizer(); - underlying.setConvergenceChecker(new BrentOptimizer.BrentConvergenceChecker(1e-9, 1e-14)); + UnivariateRealOptimizer underlying = new BrentOptimizer(1e-9, 1e-14); underlying.setMaxEvaluations(300); JDKRandomGenerator g = new JDKRandomGenerator(); g.setSeed(4312000053L);