From c5ae09d77e8756ab73e1c39f77f1b303aa4ba384 Mon Sep 17 00:00:00 2001
From: Luc Maisonobe
- * The algorithm starts by setting
- *
- *
- * If f is continuous on lowerBound <= a < initial < b <= upperBound
f(a) * f(b) < 0
[a,b],
this means that a
- * and b
bracket a root of f.
- * a := initial -1; b := initial +1,
examines the value of the
- * function at a
and b
and keeps moving
- * the endpoints out by one unit each time through a loop that terminates
- * when one of the following happens:
- *
f(a) * f(b) < 0
-- success! a = lower
and b = upper
- * -- NoBracketingException Integer.MAX_VALUE
iterations elapse
- * -- NoBracketingException
+ * This method simply calls {@link #bracket(UnivariateFunction, double, double, double,
+ * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}
+ * with {@code q} and {@code r} set to 1.0 and {@code maximumIterations} set to {@code Integer.MAX_VALUE}.
* Note: this method can take
* Integer.MAX_VALUE
iterations to throw a
* ConvergenceException.
Unless you are confident that there
* is a root between lowerBound
and upperBound
* near initial,
it is better to use
- * {@link #bracket(UnivariateFunction, double, double, double, int)},
+ * {@link #bracket(UnivariateFunction, double, double, double,
+ * double, int) bracket(function, initial, lowerBound, upperBound, delta, maximumIterations)},
* explicitly specifying the maximum number of iterations.
lowerBound <= a < initial < b <= upperBound
f(a) * f(b) <= 0
[a,b],
this means that a
- * and b
bracket a root of f.
- *
- * The algorithm starts by setting
- * a := initial -1; b := initial +1,
examines the value of the
- * function at a
and b
and keeps moving
- * the endpoints out by one unit each time through a loop that terminates
- * when one of the following happens:
f(a) * f(b) <= 0
-- success! a = lower
and b = upper
- * -- NoBracketingException maximumIterations
iterations elapse
- * -- NoBracketingException + * The algorithm checks the sign of \( f(l_k) \) and \( f(u_k) \) for increasing + * values of k, where \( l_k = max(lower, initial - \delta_k) \), + * \( u_k = min(upper, initial + \delta_k) \), using recurrence + * \( \delta_{k+1} = r \delta_k + q, \delta_0 = 0\) and starting search with \( k=1 \). + * The algorithm stops when one of the following happens:
+ * If different signs are found at first iteration ({@code k=1}), then the returned + * interval will be \( [a, b] = [l_1, u_1] \). If different signs are found at a later + * iteration ({code k>1}, then the returned interval will be either + * \( [a, b] = [l_{k+1}, l_{k}] \) or ( [a, b] = [u_{k}, u_{k+1}] \). A root solver called + * with these parameters will therefore start with the smallest bracketing interval known + * at this step. + *
+ *+ * Interval expansion rate is tuned by changing the recurrence parameters {@code r} and + * {@code q}. When the multiplicative factor {@code r} is set to 1, the sequence is a + * simple arithmetic sequence with linear increase. When the multiplicative factor {@code r} + * is larger than 1, the sequence has an asymtotically exponential rate. Note than the + * additive parameter {@code q} should never be set to zero, otherwise the interval would + * degenerate to the single initial point for all values of {@code k}. + *
+ *+ * As a rule of thumb, when the location of the root is expected to be approximately known + * within some error margin, {@code r} should be set to 1 and {@code q} should be set to the + * order of magnitude of the error margin. When the location of the root is really a wild guess, + * then {@code r} should be set to a value larger than 1 (typically 2 to double the interval + * length at each iteration) and {@code q} should be set according to half the initial + * search interval length. + *
+ *+ * As an example, if we consider the trivial function {@code f(x) = 1 - x} and use + * {@code initial = 4}, {@code r = 1}, {@code q = 2}, the algorithm will compute + * {@code f(4-2) = f(2) = -1} and {@code f(4+2) = f(6) = -5} for {@code k = 1}, then + * {@code f(4-4) = f(0) = +1} and {@code f(4+4) = f(8) = -7} for {@code k = 2}. Then it will + * return the interval {@code [0, 2]} as the smallest one known to be bracketing the root. + * As shown by this example, the initial value (here {@code 4}) may lie outside of the returned + * bracketing interval. + *
+ * @param function function to check + * @param initial Initial midpoint of interval being expanded to + * bracket a root. + * @param lowerBound Lower bound (a is never lower than this value). + * @param upperBound Upper bound (b never is greater than this + * value). + * @param q additive offset used to compute bounds sequence (must be strictly positive) + * @param r multiplicative factor used to compute bounds sequence + * @param maximumIterations Maximum number of iterations to perform + * @return a two element array holding the bracketing values. + * @exception NoBracketingException if function cannot be bracketed in the search interval + */ + public static double[] bracket(final UnivariateFunction function, final double initial, + final double lowerBound, final double upperBound, + final double q, final double r, final int maximumIterations) + throws NoBracketingException { + if (function == null) { throw new NullArgumentException(LocalizedFormats.FUNCTION); } + if (q <= 0) { + throw new NotStrictlyPositiveException(q); + } if (maximumIterations <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations); } verifySequence(lowerBound, initial, upperBound); - double a = initial; - double b = initial; - double fa; - double fb; - int numIterations = 0; + // initialize the recurrence + double a = initial; + double b = initial; + double fa = Double.NaN; + double fb = Double.NaN; + double delta = 0; - do { - a = FastMath.max(a - 1.0, lowerBound); - b = FastMath.min(b + 1.0, upperBound); - fa = function.value(a); + for (int numIterations = 0; + (numIterations < maximumIterations) && (a > lowerBound || b > upperBound); + ++numIterations) { - fb = function.value(b); - ++numIterations; - } while ((fa * fb > 0.0) && (numIterations < maximumIterations) && - ((a > lowerBound) || (b < upperBound))); + final double previousA = a; + final double previousFa = fa; + final double previousB = b; + final double previousFb = fb; + + delta = r * delta + q; + a = FastMath.max(initial - delta, lowerBound); + b = FastMath.min(initial + delta, upperBound); + fa = function.value(a); + fb = function.value(b); + + if (numIterations == 0) { + // at first iteration, we don't have a previous interval + // we simply compare both sides of the initial interval + if (fa * fb <= 0) { + // the first interval already brackets a root + return new double[] { a, b }; + } + } else { + // we have a previous interval with constant sign and expand it, + // we expect sign changes to occur at boundaries + if (fa * previousFa <= 0) { + // sign change detected at near lower bound + return new double[] { a, previousA }; + } else if (fb * previousFb <= 0) { + // sign change detected at near upper bound + return new double[] { previousB, b }; + } + } - if (fa * fb > 0.0) { - throw new NoBracketingException(LocalizedFormats.FAILED_BRACKETING, - a, b, fa, fb, - numIterations, maximumIterations, initial, - lowerBound, upperBound); } - return new double[] {a, b}; + // no bracketing found + throw new NoBracketingException(a, b, fa, fb); + } /** diff --git a/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java b/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java index e37522f68..83b76e0dc 100644 --- a/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java +++ b/src/test/java/org/apache/commons/math3/analysis/solvers/UnivariateSolverUtilsTest.java @@ -21,6 +21,7 @@ import org.apache.commons.math3.analysis.QuinticFunction; import org.apache.commons.math3.analysis.UnivariateFunction; import org.apache.commons.math3.analysis.function.Sin; import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NoBracketingException; import org.apache.commons.math3.util.FastMath; import org.junit.Assert; import org.junit.Test; @@ -86,6 +87,56 @@ public class UnivariateSolverUtilsTest { Assert.assertTrue(sin.value(result[1]) > 0); } + @Test + public void testBracketCentered() { + double initial = 0.1; + double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100); + Assert.assertTrue(result[0] < initial); + Assert.assertTrue(result[1] > initial); + Assert.assertTrue(sin.value(result[0]) < 0); + Assert.assertTrue(sin.value(result[1]) > 0); + } + + @Test + public void testBracketLow() { + double initial = 0.5; + double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100); + Assert.assertTrue(result[0] < initial); + Assert.assertTrue(result[1] < initial); + Assert.assertTrue(sin.value(result[0]) < 0); + Assert.assertTrue(sin.value(result[1]) > 0); + } + + @Test + public void testBracketHigh(){ + double initial = -0.5; + double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100); + Assert.assertTrue(result[0] > initial); + Assert.assertTrue(result[1] > initial); + Assert.assertTrue(sin.value(result[0]) < 0); + Assert.assertTrue(sin.value(result[1]) > 0); + } + + @Test(expected=NoBracketingException.class) + public void testBracketLinear(){ + UnivariateSolverUtils.bracket(new UnivariateFunction() { + public double value(double x) { + return 1 - x; + } + }, 1000, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0, 1.0, 100); + } + + @Test + public void testBracketExponential(){ + double[] result = UnivariateSolverUtils.bracket(new UnivariateFunction() { + public double value(double x) { + return 1 - x; + } + }, 1000, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0, 2.0, 10); + Assert.assertTrue(result[0] <= 1); + Assert.assertTrue(result[1] >= 1); + } + @Test public void testBracketEndpointRoot() { double[] result = UnivariateSolverUtils.bracket(sin, 1.5, 0, 2.0); @@ -97,18 +148,28 @@ public class UnivariateSolverUtilsTest { public void testNullFunction() { UnivariateSolverUtils.bracket(null, 1.5, 0, 2.0); } - + @Test(expected=MathIllegalArgumentException.class) public void testBadInitial() { UnivariateSolverUtils.bracket(sin, 2.5, 0, 2.0); } - + + @Test(expected=MathIllegalArgumentException.class) + public void testBadAdditive() { + UnivariateSolverUtils.bracket(sin, 1.0, -2.0, 3.0, -1.0, 1.0, 100); + } + + @Test(expected=NoBracketingException.class) + public void testIterationExceeded() { + UnivariateSolverUtils.bracket(sin, 1.0, -2.0, 3.0, 1.0e-5, 1.0, 100); + } + @Test(expected=MathIllegalArgumentException.class) public void testBadEndpoints() { // endpoints not valid UnivariateSolverUtils.bracket(sin, 1.5, 2.0, 1.0); } - + @Test(expected=MathIllegalArgumentException.class) public void testBadMaximumIterations() { // bad maximum iterations