Improved bracketing utility for univariate solvers.

Bracketing utility for univariate root solvers now returns a tighter
interval than before. It also allows choosing the search interval
expansion rate, supporting both linear and asymptotically exponential
rates.
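
For reference, the two expansion regimes mentioned above follow from the recurrence documented in the new Javadoc, \( \delta_{k+1} = r \delta_k + q \) with \( \delta_0 = 0 \). The closed form below is a short derivation from that recurrence, not text from the commit:

```latex
\delta_k = q \sum_{j=0}^{k-1} r^{j} =
\begin{cases}
  k\,q, & r = 1 \quad \text{(linear growth)} \\[4pt]
  q\,\dfrac{r^{k} - 1}{r - 1}, & r \neq 1 \quad \text{(asymptotically exponential growth for } r > 1\text{)}
\end{cases}
```

For instance, with q = 1 and r = 2 the half-widths are 1, 3, 7, 15, ..., so ten iterations reach 1023, whereas with r = 1 they reach only 10.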

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1578428 13f79535-47bb-0310-9956-ffa450edef68
Luc Maisonobe 2014-03-17 15:14:07 +00:00
parent 5929846a20
commit c5ae09d77e
3 changed files with 191 additions and 61 deletions
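
The expansion scheme described in the commit message can be sketched without the library as follows. This is a minimal, self-contained illustration, not the Commons Math implementation: the class name `BracketSketch` is hypothetical, the standard `IllegalArgumentException` and `IllegalStateException` stand in for the library's own exception types, and the loop is assumed to continue while either endpoint can still move toward its bound.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical stand-alone sketch of the (q, r) bracketing recurrence. */
public class BracketSketch {

    /**
     * Expands [initial - delta, initial + delta] using delta <- r * delta + q
     * until a sign change is seen, and returns the tightest known bracket.
     */
    static double[] bracket(DoubleUnaryOperator f, double initial,
                            double lower, double upper,
                            double q, double r, int maxIterations) {
        if (q <= 0) {
            throw new IllegalArgumentException("q must be strictly positive");
        }
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be strictly positive");
        }
        if (!(lower < initial && initial < upper)) {
            throw new IllegalArgumentException("need lower < initial < upper");
        }

        double a = initial;
        double b = initial;
        double fa = Double.NaN;
        double fb = Double.NaN;
        double delta = 0;

        // Assumption: keep going while at least one endpoint has room to move.
        for (int k = 0; k < maxIterations && (a > lower || b < upper); ++k) {
            final double previousA = a;
            final double previousFa = fa;
            final double previousB = b;
            final double previousFb = fb;

            delta = r * delta + q;
            a = Math.max(initial - delta, lower);
            b = Math.min(initial + delta, upper);
            fa = f.applyAsDouble(a);
            fb = f.applyAsDouble(b);

            if (k == 0) {
                // No previous interval yet: compare both sides directly.
                if (fa * fb <= 0) {
                    return new double[] { a, b };
                }
            } else if (fa * previousFa <= 0) {
                // Sign change between the new and previous lower endpoints.
                return new double[] { a, previousA };
            } else if (fb * previousFb <= 0) {
                // Sign change between the previous and new upper endpoints.
                return new double[] { previousB, b };
            }
        }
        throw new IllegalStateException("no bracketing found");
    }

    public static void main(String[] args) {
        // Javadoc example: f(x) = 1 - x, initial = 4, q = 2, r = 1.
        double[] ab = bracket(x -> 1 - x, 4.0,
                              Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
                              2.0, 1.0, 10);
        System.out.println(Arrays.toString(ab)); // prints [0.0, 2.0]
    }
}
```

The returned interval `[0, 2]` does not contain the initial guess `4`, which matches the behaviour the Javadoc example in this commit describes.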


@ -51,6 +51,11 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties> </properties>
<body> <body>
<release version="3.3" date="TBD" description="TBD"> <release version="3.3" date="TBD" description="TBD">
<action dev="luc" type="update" >
Bracketing utility for univariate root solvers returns a tighter interval than before.
It also allows choosing the search interval expansion rate, supporting both linear
and asymptotically exponential rates.
</action>
<action dev="luc" type="fix" issue="MATH-1107" due-to="Bruce A Johnson"> <action dev="luc" type="fix" issue="MATH-1107" due-to="Bruce A Johnson">
Prevent penalties to grow multiplicatively in CMAES for out of bounds points. Prevent penalties to grow multiplicatively in CMAES for out of bounds points.
</action> </action>


@ -171,31 +171,16 @@ public class UnivariateSolverUtils {
} }
/** /**
* This method attempts to find two values a and b satisfying <ul> * This method simply calls {@link #bracket(UnivariateFunction, double, double, double,
* <li> <code> lowerBound <= a < initial < b <= upperBound</code> </li> * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}
* <li> <code> f(a) * f(b) < 0 </code></li> * with {@code q} and {@code r} set to 1.0 and {@code maximumIterations} set to {@code Integer.MAX_VALUE}.
* </ul>
* If f is continuous on <code>[a,b],</code> this means that <code>a</code>
* and <code>b</code> bracket a root of f.
* <p>
* The algorithm starts by setting
* <code>a := initial -1; b := initial +1,</code> examines the value of the
* function at <code>a</code> and <code>b</code> and keeps moving
* the endpoints out by one unit each time through a loop that terminates
* when one of the following happens: <ul>
* <li> <code> f(a) * f(b) < 0 </code> -- success!</li>
* <li> <code> a = lower </code> and <code> b = upper</code>
* -- NoBracketingException </li>
* <li> <code> Integer.MAX_VALUE</code> iterations elapse
* -- NoBracketingException </li>
* </ul></p>
* <p>
* <strong>Note: </strong> this method can take * <strong>Note: </strong> this method can take
* <code>Integer.MAX_VALUE</code> iterations to throw a * <code>Integer.MAX_VALUE</code> iterations to throw a
* <code>ConvergenceException.</code> Unless you are confident that there * <code>ConvergenceException.</code> Unless you are confident that there
* is a root between <code>lowerBound</code> and <code>upperBound</code> * is a root between <code>lowerBound</code> and <code>upperBound</code>
* near <code>initial,</code> it is better to use * near <code>initial,</code> it is better to use
* {@link #bracket(UnivariateFunction, double, double, double, int)}, * {@link #bracket(UnivariateFunction, double, double, double,
* double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)},
* explicitly specifying the maximum number of iterations.</p> * explicitly specifying the maximum number of iterations.</p>
* *
* @param function Function. * @param function Function.
@ -215,28 +200,13 @@ public class UnivariateSolverUtils {
throws NullArgumentException, throws NullArgumentException,
NotStrictlyPositiveException, NotStrictlyPositiveException,
NoBracketingException { NoBracketingException {
return bracket(function, initial, lowerBound, upperBound, Integer.MAX_VALUE); return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, Integer.MAX_VALUE);
} }
/** /**
* This method attempts to find two values a and b satisfying <ul> * This method simply calls {@link #bracket(UnivariateFunction, double, double, double,
* <li> <code> lowerBound <= a < initial < b <= upperBound</code> </li> * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}
* <li> <code> f(a) * f(b) <= 0 </code> </li> * with {@code q} and {@code r} set to 1.0.
* </ul>
* If f is continuous on <code>[a,b],</code> this means that <code>a</code>
* and <code>b</code> bracket a root of f.
* <p>
* The algorithm starts by setting
* <code>a := initial -1; b := initial +1,</code> examines the value of the
* function at <code>a</code> and <code>b</code> and keeps moving
* the endpoints out by one unit each time through a loop that terminates
* when one of the following happens: <ul>
* <li> <code> f(a) * f(b) <= 0 </code> -- success!</li>
* <li> <code> a = lower </code> and <code> b = upper</code>
* -- NoBracketingException </li>
* <li> <code> maximumIterations</code> iterations elapse
* -- NoBracketingException </li></ul></p>
*
* @param function Function. * @param function Function.
* @param initial Initial midpoint of interval being expanded to * @param initial Initial midpoint of interval being expanded to
* bracket a root. * bracket a root.
@ -257,38 +227,132 @@ public class UnivariateSolverUtils {
throws NullArgumentException, throws NullArgumentException,
NotStrictlyPositiveException, NotStrictlyPositiveException,
NoBracketingException { NoBracketingException {
return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, maximumIterations);
}
/**
* This method attempts to find two values a and b satisfying <ul>
* <li> {@code lowerBound <= a < initial < b <= upperBound} </li>
* <li> {@code f(a) * f(b) <= 0} </li>
* </ul>
* If {@code f} is continuous on {@code [a,b]}, this means that {@code a}
* and {@code b} bracket a root of {@code f}.
* <p>
* The algorithm checks the sign of \( f(l_k) \) and \( f(u_k) \) for increasing
* values of k, where \( l_k = max(lower, initial - \delta_k) \),
* \( u_k = min(upper, initial + \delta_k) \), using recurrence
* \( \delta_{k+1} = r \delta_k + q, \delta_0 = 0\) and starting search with \( k=1 \).
* The algorithm stops when one of the following happens: <ul>
* <li> at least one positive and one negative value have been found -- success!</li>
* <li> both endpoints have reached their respective limits -- NoBracketingException </li>
* <li> {@code maximumIterations} iterations elapse -- NoBracketingException </li></ul></p>
* <p>
* If different signs are found at first iteration ({@code k=1}), then the returned
* interval will be \( [a, b] = [l_1, u_1] \). If different signs are found at a later
* iteration ({@code k>1}), then the returned interval will be either
* \( [a, b] = [l_{k}, l_{k-1}] \) or \( [a, b] = [u_{k-1}, u_{k}] \). A root solver called
* with these parameters will therefore start with the smallest bracketing interval known
* at this step.
* </p>
* <p>
* Interval expansion rate is tuned by changing the recurrence parameters {@code r} and
* {@code q}. When the multiplicative factor {@code r} is set to 1, the sequence is a
* simple arithmetic sequence with linear increase. When the multiplicative factor {@code r}
* is larger than 1, the sequence has an asymptotically exponential rate. Note that the
* additive parameter {@code q} should never be set to zero, otherwise the interval would
* degenerate to the single initial point for all values of {@code k}.
* </p>
* <p>
* As a rule of thumb, when the location of the root is expected to be approximately known
* within some error margin, {@code r} should be set to 1 and {@code q} should be set to the
* order of magnitude of the error margin. When the location of the root is really a wild guess,
* then {@code r} should be set to a value larger than 1 (typically 2 to double the interval
* length at each iteration) and {@code q} should be set according to half the initial
* search interval length.
* </p>
* <p>
* As an example, if we consider the trivial function {@code f(x) = 1 - x} and use
* {@code initial = 4}, {@code r = 1}, {@code q = 2}, the algorithm will compute
* {@code f(4-2) = f(2) = -1} and {@code f(4+2) = f(6) = -5} for {@code k = 1}, then
* {@code f(4-4) = f(0) = +1} and {@code f(4+4) = f(8) = -7} for {@code k = 2}. Then it will
* return the interval {@code [0, 2]} as the smallest one known to be bracketing the root.
* As shown by this example, the initial value (here {@code 4}) may lie outside of the returned
* bracketing interval.
* </p>
* @param function function to check
* @param initial Initial midpoint of interval being expanded to
* bracket a root.
* @param lowerBound Lower bound (a is never lower than this value).
* @param upperBound Upper bound (b is never greater than this
* value).
* @param q additive offset used to compute bounds sequence (must be strictly positive)
* @param r multiplicative factor used to compute bounds sequence
* @param maximumIterations Maximum number of iterations to perform
* @return a two element array holding the bracketing values.
* @exception NoBracketingException if function cannot be bracketed in the search interval
*/
public static double[] bracket(final UnivariateFunction function, final double initial,
final double lowerBound, final double upperBound,
final double q, final double r, final int maximumIterations)
throws NoBracketingException {
if (function == null) { if (function == null) {
throw new NullArgumentException(LocalizedFormats.FUNCTION); throw new NullArgumentException(LocalizedFormats.FUNCTION);
} }
if (q <= 0) {
throw new NotStrictlyPositiveException(q);
}
if (maximumIterations <= 0) { if (maximumIterations <= 0) {
throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations); throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations);
} }
verifySequence(lowerBound, initial, upperBound); verifySequence(lowerBound, initial, upperBound);
double a = initial; // initialize the recurrence
double b = initial; double a = initial;
double fa; double b = initial;
double fb; double fa = Double.NaN;
int numIterations = 0; double fb = Double.NaN;
double delta = 0;
do { for (int numIterations = 0;
a = FastMath.max(a - 1.0, lowerBound); (numIterations < maximumIterations) && (a > lowerBound || b < upperBound);
b = FastMath.min(b + 1.0, upperBound); ++numIterations) {
fa = function.value(a);
fb = function.value(b); final double previousA = a;
++numIterations; final double previousFa = fa;
} while ((fa * fb > 0.0) && (numIterations < maximumIterations) && final double previousB = b;
((a > lowerBound) || (b < upperBound))); final double previousFb = fb;
delta = r * delta + q;
a = FastMath.max(initial - delta, lowerBound);
b = FastMath.min(initial + delta, upperBound);
fa = function.value(a);
fb = function.value(b);
if (numIterations == 0) {
// at first iteration, we don't have a previous interval
// we simply compare both sides of the initial interval
if (fa * fb <= 0) {
// the first interval already brackets a root
return new double[] { a, b };
}
} else {
// we have a previous interval with constant sign and expand it,
// we expect sign changes to occur at boundaries
if (fa * previousFa <= 0) {
// sign change detected near the lower bound
return new double[] { a, previousA };
} else if (fb * previousFb <= 0) {
// sign change detected near the upper bound
return new double[] { previousB, b };
}
}
if (fa * fb > 0.0) {
throw new NoBracketingException(LocalizedFormats.FAILED_BRACKETING,
a, b, fa, fb,
numIterations, maximumIterations, initial,
lowerBound, upperBound);
} }
return new double[] {a, b}; // no bracketing found
throw new NoBracketingException(a, b, fa, fb);
} }
/** /**


@ -21,6 +21,7 @@ import org.apache.commons.math3.analysis.QuinticFunction;
import org.apache.commons.math3.analysis.UnivariateFunction; import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.function.Sin; import org.apache.commons.math3.analysis.function.Sin;
import org.apache.commons.math3.exception.MathIllegalArgumentException; import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NoBracketingException;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -86,6 +87,56 @@ public class UnivariateSolverUtilsTest {
Assert.assertTrue(sin.value(result[1]) > 0); Assert.assertTrue(sin.value(result[1]) > 0);
} }
@Test
public void testBracketCentered() {
double initial = 0.1;
double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100);
Assert.assertTrue(result[0] < initial);
Assert.assertTrue(result[1] > initial);
Assert.assertTrue(sin.value(result[0]) < 0);
Assert.assertTrue(sin.value(result[1]) > 0);
}
@Test
public void testBracketLow() {
double initial = 0.5;
double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100);
Assert.assertTrue(result[0] < initial);
Assert.assertTrue(result[1] < initial);
Assert.assertTrue(sin.value(result[0]) < 0);
Assert.assertTrue(sin.value(result[1]) > 0);
}
@Test
public void testBracketHigh(){
double initial = -0.5;
double[] result = UnivariateSolverUtils.bracket(sin, initial, -2.0, 2.0, 0.2, 1.0, 100);
Assert.assertTrue(result[0] > initial);
Assert.assertTrue(result[1] > initial);
Assert.assertTrue(sin.value(result[0]) < 0);
Assert.assertTrue(sin.value(result[1]) > 0);
}
@Test(expected=NoBracketingException.class)
public void testBracketLinear(){
UnivariateSolverUtils.bracket(new UnivariateFunction() {
public double value(double x) {
return 1 - x;
}
}, 1000, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0, 1.0, 100);
}
@Test
public void testBracketExponential(){
double[] result = UnivariateSolverUtils.bracket(new UnivariateFunction() {
public double value(double x) {
return 1 - x;
}
}, 1000, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0, 2.0, 10);
Assert.assertTrue(result[0] <= 1);
Assert.assertTrue(result[1] >= 1);
}
@Test @Test
public void testBracketEndpointRoot() { public void testBracketEndpointRoot() {
double[] result = UnivariateSolverUtils.bracket(sin, 1.5, 0, 2.0); double[] result = UnivariateSolverUtils.bracket(sin, 1.5, 0, 2.0);
@ -97,18 +148,28 @@ public class UnivariateSolverUtilsTest {
public void testNullFunction() { public void testNullFunction() {
UnivariateSolverUtils.bracket(null, 1.5, 0, 2.0); UnivariateSolverUtils.bracket(null, 1.5, 0, 2.0);
} }
@Test(expected=MathIllegalArgumentException.class) @Test(expected=MathIllegalArgumentException.class)
public void testBadInitial() { public void testBadInitial() {
UnivariateSolverUtils.bracket(sin, 2.5, 0, 2.0); UnivariateSolverUtils.bracket(sin, 2.5, 0, 2.0);
} }
@Test(expected=MathIllegalArgumentException.class)
public void testBadAdditive() {
UnivariateSolverUtils.bracket(sin, 1.0, -2.0, 3.0, -1.0, 1.0, 100);
}
@Test(expected=NoBracketingException.class)
public void testIterationExceeded() {
UnivariateSolverUtils.bracket(sin, 1.0, -2.0, 3.0, 1.0e-5, 1.0, 100);
}
@Test(expected=MathIllegalArgumentException.class) @Test(expected=MathIllegalArgumentException.class)
public void testBadEndpoints() { public void testBadEndpoints() {
// endpoints not valid // endpoints not valid
UnivariateSolverUtils.bracket(sin, 1.5, 2.0, 1.0); UnivariateSolverUtils.bracket(sin, 1.5, 2.0, 1.0);
} }
@Test(expected=MathIllegalArgumentException.class) @Test(expected=MathIllegalArgumentException.class)
public void testBadMaximumIterations() { public void testBadMaximumIterations() {
// bad maximum iterations // bad maximum iterations