From c5ae09d77e8756ab73e1c39f77f1b303aa4ba384 Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Mon, 17 Mar 2014 15:14:07 +0000 Subject: [PATCH] Improved brackting utility for univariate solvers. Bracketing utility for univariate root solvers now returns a tighter interval than before. It also allows choosing the search interval expansion rate, supporting both linear and asymptotically exponential rates. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1578428 13f79535-47bb-0310-9956-ffa450edef68 --- src/changes/changes.xml | 5 + .../solvers/UnivariateSolverUtils.java | 180 ++++++++++++------ .../solvers/UnivariateSolverUtilsTest.java | 67 ++++++- 3 files changed, 191 insertions(+), 61 deletions(-) diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 99692b549..aabcb2c0d 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,11 @@ If the output is not quite correct, check for invisible trailing spaces! + + Bracketing utility for univariate root solvers returns a tighter interval than before. + It also allows choosing the search interval expansion rate, supporting both linear + and asymptotically exponential rates. + Prevent penalties to grow multiplicatively in CMAES for out of bounds points. diff --git a/src/main/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtils.java b/src/main/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtils.java index ec205784a..18e8ebd5c 100644 --- a/src/main/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtils.java +++ b/src/main/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtils.java @@ -171,31 +171,16 @@ public class UnivariateSolverUtils { } /** - * This method attempts to find two values a and b satisfying - * If f is continuous on [a,b], this means that a - * and b bracket a root of f. - *

- * The algorithm starts by setting - * a := initial -1; b := initial +1, examines the value of the - * function at a and b and keeps moving - * the endpoints out by one unit each time through a loop that terminates - * when one of the following happens:

- *

+ * This method simply calls {@link #bracket(UnivariateFunction, double, double, double, + * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)} + * with {@code q} and {@code r} set to 1.0 and {@code maximumIterations} set to {@code Integer.MAX_VALUE}. * Note: this method can take * Integer.MAX_VALUE iterations to throw a * ConvergenceException. Unless you are confident that there * is a root between lowerBound and upperBound * near initial, it is better to use - * {@link #bracket(UnivariateFunction, double, double, double, int)}, + * {@link #bracket(UnivariateFunction, double, double, double, + * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}, * explicitly specifying the maximum number of iterations.

* * @param function Function. @@ -215,28 +200,13 @@ public class UnivariateSolverUtils { throws NullArgumentException, NotStrictlyPositiveException, NoBracketingException { - return bracket(function, initial, lowerBound, upperBound, Integer.MAX_VALUE); + return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, Integer.MAX_VALUE); } /** - * This method attempts to find two values a and b satisfying - * If f is continuous on [a,b], this means that a - * and b bracket a root of f. - *

- * The algorithm starts by setting - * a := initial -1; b := initial +1, examines the value of the - * function at a and b and keeps moving - * the endpoints out by one unit each time through a loop that terminates - * when one of the following happens:

- * + * This method simply calls {@link #bracket(UnivariateFunction, double, double, double, + * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)} + * with {@code q} and {@code r} set to 1.0. * @param function Function. * @param initial Initial midpoint of interval being expanded to * bracket a root. @@ -257,38 +227,132 @@ public class UnivariateSolverUtils { throws NullArgumentException, NotStrictlyPositiveException, NoBracketingException { + return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, maximumIterations); + } + + /** + * This method attempts to find two values a and b satisfying + * If {@code f} is continuous on {@code [a,b]}, this means that {@code a} + * and {@code b} bracket a root of {@code f}. + *

+ * The algorithm checks the sign of \( f(l_k) \) and \( f(u_k) \) for increasing + * values of k, where \( l_k = max(lower, initial - \delta_k) \), + * \( u_k = min(upper, initial + \delta_k) \), using recurrence + * \( \delta_{k+1} = r \delta_k + q, \delta_0 = 0\) and starting search with \( k=1 \). + * The algorithm stops when one of the following happens:

+ *

+ * If different signs are found at first iteration ({@code k=1}), then the returned + * interval will be \( [a, b] = [l_1, u_1] \). If different signs are found at a later + * iteration ({@code k>1}), then the returned interval will be either + * \( [a, b] = [l_{k+1}, l_{k}] \) or \( [a, b] = [u_{k}, u_{k+1}] \). A root solver called + * with these parameters will therefore start with the smallest bracketing interval known + * at this step. + *

+ *

+ * Interval expansion rate is tuned by changing the recurrence parameters {@code r} and + * {@code q}. When the multiplicative factor {@code r} is set to 1, the sequence is a + * simple arithmetic sequence with linear increase. When the multiplicative factor {@code r} + * is larger than 1, the sequence has an asymptotically exponential rate. Note that the + * additive parameter {@code q} should never be set to zero, otherwise the interval would + * degenerate to the single initial point for all values of {@code k}. + *

+ *

+ * As a rule of thumb, when the location of the root is expected to be approximately known + * within some error margin, {@code r} should be set to 1 and {@code q} should be set to the + * order of magnitude of the error margin. When the location of the root is really a wild guess, + * then {@code r} should be set to a value larger than 1 (typically 2 to double the interval + * length at each iteration) and {@code q} should be set according to half the initial + * search interval length. + *

+ *

+ * As an example, if we consider the trivial function {@code f(x) = 1 - x} and use + * {@code initial = 4}, {@code r = 1}, {@code q = 2}, the algorithm will compute + * {@code f(4-2) = f(2) = -1} and {@code f(4+2) = f(6) = -5} for {@code k = 1}, then + * {@code f(4-4) = f(0) = +1} and {@code f(4+4) = f(8) = -7} for {@code k = 2}. Then it will + * return the interval {@code [0, 2]} as the smallest one known to be bracketing the root. + * As shown by this example, the initial value (here {@code 4}) may lie outside of the returned + * bracketing interval. + *

+ * @param function function to check + * @param initial Initial midpoint of interval being expanded to + * bracket a root. + * @param lowerBound Lower bound (a is never lower than this value). + * @param upperBound Upper bound (b never is greater than this + * value). + * @param q additive offset used to compute bounds sequence (must be strictly positive) + * @param r multiplicative factor used to compute bounds sequence + * @param maximumIterations Maximum number of iterations to perform + * @return a two element array holding the bracketing values. + * @exception NoBracketingException if function cannot be bracketed in the search interval + */ + public static double[] bracket(final UnivariateFunction function, final double initial, + final double lowerBound, final double upperBound, + final double q, final double r, final int maximumIterations) + throws NoBracketingException { + if (function == null) { throw new NullArgumentException(LocalizedFormats.FUNCTION); } + if (q <= 0) { + throw new NotStrictlyPositiveException(q); + } if (maximumIterations <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations); } verifySequence(lowerBound, initial, upperBound); - double a = initial; - double b = initial; - double fa; - double fb; - int numIterations = 0; + // initialize the recurrence + double a = initial; + double b = initial; + double fa = Double.NaN; + double fb = Double.NaN; + double delta = 0; - do { - a = FastMath.max(a - 1.0, lowerBound); - b = FastMath.min(b + 1.0, upperBound); - fa = function.value(a); + for (int numIterations = 0; + (numIterations < maximumIterations) && (a > lowerBound || b > upperBound); + ++numIterations) { - fb = function.value(b); - ++numIterations; - } while ((fa * fb > 0.0) && (numIterations < maximumIterations) && - ((a > lowerBound) || (b < upperBound))); + final double previousA = a; + final double previousFa = fa; + final double previousB = b; + final double previousFb = fb; + + 
delta = r * delta + q; + a = FastMath.max(initial - delta, lowerBound); + b = FastMath.min(initial + delta, upperBound); + fa = function.value(a); + fb = function.value(b); + + if (numIterations == 0) { + // at first iteration, we don't have a previous interval + // we simply compare both sides of the initial interval + if (fa * fb <= 0) { + // the first interval already brackets a root + return new double[] { a, b }; + } + } else { + // we have a previous interval with constant sign and expand it, + // we expect sign changes to occur at boundaries + if (fa * previousFa <= 0) { + // sign change detected at near lower bound + return new double[] { a, previousA }; + } else if (fb * previousFb <= 0) { + // sign change detected at near upper bound + return new double[] { previousB, b }; + } + } - if (fa * fb > 0.0) { - throw new NoBracketingException(LocalizedFormats.FAILED_BRACKETING, - a, b, fa, fb, - numIterations, maximumIterations, initial, - lowerBound, upperBound); } - return new double[] {a, b}; + // no bracketing found + throw new NoBracketingException(a, b, fa, fb); + } /** diff --git a/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java b/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java index e37522f68..83b76e0dc 100644 --- a/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java +++ b/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java @@ -21,6 +21,7 @@ import org.apache.commons.math3.analysis.QuinticFunction; import org.apache.commons.math3.analysis.UnivariateFunction; import org.apache.commons.math3.analysis.function.Sin; import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NoBracketingException; import org.apache.commons.math3.util.FastMath; import org.junit.Assert; import org.junit.Test; @@ -86,6 +87,56 @@ public class UnivariateSolverUtilsTest { 
Assert.assertTrue(sin.value(result[1]) > 0); } + @Test + public void testBracketCentered() { + double initial = 0.1; + double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100); + Assert.assertTrue(result[0] < initial); + Assert.assertTrue(result[1] > initial); + Assert.assertTrue(sin.value(result[0]) < 0); + Assert.assertTrue(sin.value(result[1]) > 0); + } + + @Test + public void testBracketLow() { + double initial = 0.5; + double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100); + Assert.assertTrue(result[0] < initial); + Assert.assertTrue(result[1] < initial); + Assert.assertTrue(sin.value(result[0]) < 0); + Assert.assertTrue(sin.value(result[1]) > 0); + } + + @Test + public void testBracketHigh(){ + double initial = -0.5; + double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100); + Assert.assertTrue(result[0] > initial); + Assert.assertTrue(result[1] > initial); + Assert.assertTrue(sin.value(result[0]) < 0); + Assert.assertTrue(sin.value(result[1]) > 0); + } + + @Test(expected=NoBracketingException.class) + public void testBracketLinear(){ + UnivariateSolverUtils.bracket(new UnivariateFunction() { + public double value(double x) { + return 1 - x; + } + }, 1000, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0, 1.0, 100); + } + + @Test + public void testBracketExponential(){ + double[] result = UnivariateSolverUtils.bracket(new UnivariateFunction() { + public double value(double x) { + return 1 - x; + } + }, 1000, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0, 2.0, 10); + Assert.assertTrue(result[0] <= 1); + Assert.assertTrue(result[1] >= 1); + } + @Test public void testBracketEndpointRoot() { double[] result = UnivariateSolverUtils.bracket(sin, 1.5, 0, 2.0); @@ -97,18 +148,28 @@ public class UnivariateSolverUtilsTest { public void testNullFunction() { UnivariateSolverUtils.bracket(null, 1.5, 0, 2.0); } - + 
@Test(expected=MathIllegalArgumentException.class) public void testBadInitial() { UnivariateSolverUtils.bracket(sin, 2.5, 0, 2.0); } - + + @Test(expected=MathIllegalArgumentException.class) + public void testBadAdditive() { + UnivariateSolverUtils.bracket(sin, 1.0, -2.0, 3.0, -1.0, 1.0, 100); + } + + @Test(expected=NoBracketingException.class) + public void testIterationExceeded() { + UnivariateSolverUtils.bracket(sin, 1.0, -2.0, 3.0, 1.0e-5, 1.0, 100); + } + @Test(expected=MathIllegalArgumentException.class) public void testBadEndpoints() { // endpoints not valid UnivariateSolverUtils.bracket(sin, 1.5, 2.0, 1.0); } - + @Test(expected=MathIllegalArgumentException.class) public void testBadMaximumIterations() { // bad maximum iterations