diff --git a/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java b/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java index 955d0d837..929560c36 100644 --- a/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java +++ b/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java @@ -21,6 +21,7 @@ import java.util.Comparator; import org.apache.commons.math.FunctionEvaluationException; import org.apache.commons.math.optimization.OptimizationException; +import org.apache.commons.math.optimization.RealConvergenceChecker; import org.apache.commons.math.optimization.RealPointValuePair; /** @@ -60,6 +61,7 @@ public class MultiDirectional extends DirectSearchOptimizer { protected void iterateSimplex(final Comparator comparator) throws FunctionEvaluationException, OptimizationException, IllegalArgumentException { + final RealConvergenceChecker checker = getConvergenceChecker(); while (true) { incrementIterationsCounter(); @@ -91,6 +93,16 @@ public class MultiDirectional extends DirectSearchOptimizer { return; } + // check convergence + final int iter = getIterations(); + boolean converged = true; + for (int i = 0; i < simplex.length; ++i) { + converged &= checker.converged(iter, original[i], simplex[i]); + } + if (converged) { + return; + } + } } diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index 5cdc07df4..504b8e79f 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -38,6 +38,12 @@ The type attribute can be add,update,fix,remove. Commons Math Release Notes + + + Prevent infinite loops in multi-directional direct optimization method when + the start point is exactly at the optimal point + + 120); - assertTrue(optimizer.getEvaluations() < 150); + Assert.assertEquals(xM, optimum.getPoint()[0], 4.0e-6); + Assert.assertEquals(yP, optimum.getPoint()[1], 3.0e-6); + Assert.assertEquals(valueXmYp, optimum.getValue(), 8.0e-13); + Assert.assertTrue(optimizer.getEvaluations() > 120); + Assert.assertTrue(optimizer.getEvaluations() < 150); optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 }); - assertEquals(xP, optimum.getPoint()[0], 2.0e-8); - assertEquals(yM, optimum.getPoint()[1], 3.0e-6); - assertEquals(valueXpYm, optimum.getValue(), 2.0e-12); - assertTrue(optimizer.getEvaluations() > 120); - assertTrue(optimizer.getEvaluations() < 150); + Assert.assertEquals(xP, optimum.getPoint()[0], 2.0e-8); + Assert.assertEquals(yM, optimum.getPoint()[1], 3.0e-6); + Assert.assertEquals(valueXpYm, optimum.getValue(), 2.0e-12); + Assert.assertTrue(optimizer.getEvaluations() > 120); + Assert.assertTrue(optimizer.getEvaluations() < 150); // maximization optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 }); - assertEquals(xM, optimum.getPoint()[0], 7.0e-7); - assertEquals(yM, optimum.getPoint()[1], 3.0e-7); - assertEquals(valueXmYm, optimum.getValue(), 2.0e-14); - assertTrue(optimizer.getEvaluations() > 120); - assertTrue(optimizer.getEvaluations() < 150); + Assert.assertEquals(xM, optimum.getPoint()[0], 7.0e-7); + Assert.assertEquals(yM, optimum.getPoint()[1], 3.0e-7); + Assert.assertEquals(valueXmYm, optimum.getValue(), 2.0e-14); + Assert.assertTrue(optimizer.getEvaluations() > 120); + Assert.assertTrue(optimizer.getEvaluations() < 150); + optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-15, 1.0e-30)); optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 }); - assertEquals(xP, optimum.getPoint()[0], 2.0e-8); - assertEquals(yP, optimum.getPoint()[1], 3.0e-6); - assertEquals(valueXpYp, optimum.getValue(), 2.0e-12); - assertTrue(optimizer.getEvaluations() > 120); - assertTrue(optimizer.getEvaluations() < 150); + Assert.assertEquals(xP, optimum.getPoint()[0], 2.0e-8); + Assert.assertEquals(yP, optimum.getPoint()[1], 3.0e-6); + Assert.assertEquals(valueXpYp, optimum.getValue(), 2.0e-12); + Assert.assertTrue(optimizer.getEvaluations() > 180); + Assert.assertTrue(optimizer.getEvaluations() < 220); } + @Test public void testRosenbrock() throws FunctionEvaluationException, ConvergenceException { @@ -154,13 +152,14 @@ public class MultiDirectionalTest RealPointValuePair optimum = optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 }); - assertEquals(count, optimizer.getEvaluations()); - assertTrue(optimizer.getEvaluations() > 70); - assertTrue(optimizer.getEvaluations() < 100); - assertTrue(optimum.getValue() > 1.0e-2); + Assert.assertEquals(count, optimizer.getEvaluations()); + Assert.assertTrue(optimizer.getEvaluations() > 50); + Assert.assertTrue(optimizer.getEvaluations() < 100); + Assert.assertTrue(optimum.getValue() > 1.0e-2); } + @Test public void testPowell() throws FunctionEvaluationException, ConvergenceException { @@ -183,15 +182,64 @@ public class MultiDirectionalTest optimizer.setMaxIterations(1000); RealPointValuePair optimum = optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 }); - assertEquals(count, optimizer.getEvaluations()); - assertTrue(optimizer.getEvaluations() > 800); - assertTrue(optimizer.getEvaluations() < 900); - assertTrue(optimum.getValue() > 1.0e-2); + Assert.assertEquals(count, optimizer.getEvaluations()); + Assert.assertTrue(optimizer.getEvaluations() > 800); + Assert.assertTrue(optimizer.getEvaluations() < 900); + Assert.assertTrue(optimum.getValue() > 1.0e-2); } - public static Test suite() { - return new TestSuite(MultiDirectionalTest.class); + @Test + public void testMath283() + throws FunctionEvaluationException, OptimizationException { + // fails because MultiDirectional.iterateSimplex is looping forever + // the while(true) should be replaced with a convergence check + MultiDirectional multiDirectional = new MultiDirectional(); + multiDirectional.setMaxIterations(100); + multiDirectional.setMaxEvaluations(1000); + + final Gaussian2D function = new Gaussian2D(0.0, 0.0, 1.0); + + RealPointValuePair estimate = multiDirectional.optimize(function, + GoalType.MAXIMIZE, function.getMaximumPosition()); + + final double EPSILON = 1e-5; + + final double expectedMaximum = function.getMaximum(); + final double actualMaximum = estimate.getValue(); + Assert.assertEquals(expectedMaximum, actualMaximum, EPSILON); + + final double[] expectedPosition = function.getMaximumPosition(); + final double[] actualPosition = estimate.getPoint(); + Assert.assertEquals(expectedPosition[0], actualPosition[0], EPSILON ); + Assert.assertEquals(expectedPosition[1], actualPosition[1], EPSILON ); + + } + + private static class Gaussian2D implements MultivariateRealFunction { + + private final double[] maximumPosition; + + private final double std; + + public Gaussian2D(double xOpt, double yOpt, double std) { + maximumPosition = new double[] { xOpt, yOpt }; + this.std = std; + } + + public double getMaximum() { + return value(maximumPosition); + } + + public double[] getMaximumPosition() { + return maximumPosition.clone(); + } + + public double value(double[] point) { + final double x = point[0], y = point[1]; + final double twoS2 = 2.0 * std * std; + return 1.0 / (twoS2 * Math.PI) * Math.exp(-(x * x + y * y) / twoS2); + } } private int count;