Prevent infinite loops in multi-directional direct optimization method when the start point is exactly at the optimal point
JIRA: MATH-283 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@804328 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
9cb0ca6b0f
commit
8fe6a83eb6
|
@ -21,6 +21,7 @@ import java.util.Comparator;
|
||||||
|
|
||||||
import org.apache.commons.math.FunctionEvaluationException;
|
import org.apache.commons.math.FunctionEvaluationException;
|
||||||
import org.apache.commons.math.optimization.OptimizationException;
|
import org.apache.commons.math.optimization.OptimizationException;
|
||||||
|
import org.apache.commons.math.optimization.RealConvergenceChecker;
|
||||||
import org.apache.commons.math.optimization.RealPointValuePair;
|
import org.apache.commons.math.optimization.RealPointValuePair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -60,6 +61,7 @@ public class MultiDirectional extends DirectSearchOptimizer {
|
||||||
protected void iterateSimplex(final Comparator<RealPointValuePair> comparator)
|
protected void iterateSimplex(final Comparator<RealPointValuePair> comparator)
|
||||||
throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
|
throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
|
||||||
|
|
||||||
|
final RealConvergenceChecker checker = getConvergenceChecker();
|
||||||
while (true) {
|
while (true) {
|
||||||
|
|
||||||
incrementIterationsCounter();
|
incrementIterationsCounter();
|
||||||
|
@ -91,6 +93,16 @@ public class MultiDirectional extends DirectSearchOptimizer {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check convergence
|
||||||
|
final int iter = getIterations();
|
||||||
|
boolean converged = true;
|
||||||
|
for (int i = 0; i < simplex.length; ++i) {
|
||||||
|
converged &= checker.converged(iter, original[i], simplex[i]);
|
||||||
|
}
|
||||||
|
if (converged) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,12 @@ The <action> type attribute can be add,update,fix,remove.
|
||||||
<title>Commons Math Release Notes</title>
|
<title>Commons Math Release Notes</title>
|
||||||
</properties>
|
</properties>
|
||||||
<body>
|
<body>
|
||||||
|
<release version="2.1" date="TBD" description="TBD">
|
||||||
|
<action dev="luc" type="fix" issue="MATH-283" due-to="Michael Nischt">
|
||||||
|
Prevent infinite loops in multi-directional direct optimization method when
|
||||||
|
the start point is exactly at the optimal point
|
||||||
|
</action>
|
||||||
|
</release>
|
||||||
<release version="2.0" date="2009-08-03" description="
|
<release version="2.0" date="2009-08-03" description="
|
||||||
This is a major release. It combines bug fixes, new features and
|
This is a major release. It combines bug fixes, new features and
|
||||||
changes to existing features. Most notable among the new features are:
|
changes to existing features. Most notable among the new features are:
|
||||||
|
|
|
@ -17,24 +17,19 @@
|
||||||
|
|
||||||
package org.apache.commons.math.optimization.direct;
|
package org.apache.commons.math.optimization.direct;
|
||||||
|
|
||||||
import junit.framework.Test;
|
|
||||||
import junit.framework.TestCase;
|
|
||||||
import junit.framework.TestSuite;
|
|
||||||
|
|
||||||
import org.apache.commons.math.ConvergenceException;
|
import org.apache.commons.math.ConvergenceException;
|
||||||
import org.apache.commons.math.FunctionEvaluationException;
|
import org.apache.commons.math.FunctionEvaluationException;
|
||||||
import org.apache.commons.math.analysis.MultivariateRealFunction;
|
import org.apache.commons.math.analysis.MultivariateRealFunction;
|
||||||
import org.apache.commons.math.optimization.GoalType;
|
import org.apache.commons.math.optimization.GoalType;
|
||||||
|
import org.apache.commons.math.optimization.OptimizationException;
|
||||||
import org.apache.commons.math.optimization.RealPointValuePair;
|
import org.apache.commons.math.optimization.RealPointValuePair;
|
||||||
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
|
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
public class MultiDirectionalTest
|
public class MultiDirectionalTest {
|
||||||
extends TestCase {
|
|
||||||
|
|
||||||
public MultiDirectionalTest(String name) {
|
|
||||||
super(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testFunctionEvaluationExceptions() {
|
public void testFunctionEvaluationExceptions() {
|
||||||
MultivariateRealFunction wrong =
|
MultivariateRealFunction wrong =
|
||||||
new MultivariateRealFunction() {
|
new MultivariateRealFunction() {
|
||||||
|
@ -52,25 +47,26 @@ public class MultiDirectionalTest
|
||||||
try {
|
try {
|
||||||
MultiDirectional optimizer = new MultiDirectional(0.9, 1.9);
|
MultiDirectional optimizer = new MultiDirectional(0.9, 1.9);
|
||||||
optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 });
|
optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 });
|
||||||
fail("an exception should have been thrown");
|
Assert.fail("an exception should have been thrown");
|
||||||
} catch (FunctionEvaluationException ce) {
|
} catch (FunctionEvaluationException ce) {
|
||||||
// expected behavior
|
// expected behavior
|
||||||
assertNull(ce.getCause());
|
Assert.assertNull(ce.getCause());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
fail("wrong exception caught: " + e.getMessage());
|
Assert.fail("wrong exception caught: " + e.getMessage());
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
MultiDirectional optimizer = new MultiDirectional(0.9, 1.9);
|
MultiDirectional optimizer = new MultiDirectional(0.9, 1.9);
|
||||||
optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 });
|
optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 });
|
||||||
fail("an exception should have been thrown");
|
Assert.fail("an exception should have been thrown");
|
||||||
} catch (FunctionEvaluationException ce) {
|
} catch (FunctionEvaluationException ce) {
|
||||||
// expected behavior
|
// expected behavior
|
||||||
assertNotNull(ce.getCause());
|
Assert.assertNotNull(ce.getCause());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
fail("wrong exception caught: " + e.getMessage());
|
Assert.fail("wrong exception caught: " + e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testMinimizeMaximize()
|
public void testMinimizeMaximize()
|
||||||
throws FunctionEvaluationException, ConvergenceException {
|
throws FunctionEvaluationException, ConvergenceException {
|
||||||
|
|
||||||
|
@ -93,43 +89,45 @@ public class MultiDirectionalTest
|
||||||
};
|
};
|
||||||
|
|
||||||
MultiDirectional optimizer = new MultiDirectional();
|
MultiDirectional optimizer = new MultiDirectional();
|
||||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-30));
|
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-11, 1.0e-30));
|
||||||
optimizer.setMaxIterations(200);
|
optimizer.setMaxIterations(200);
|
||||||
optimizer.setStartConfiguration(new double[] { 0.2, 0.2 });
|
optimizer.setStartConfiguration(new double[] { 0.2, 0.2 });
|
||||||
RealPointValuePair optimum;
|
RealPointValuePair optimum;
|
||||||
|
|
||||||
// minimization
|
// minimization
|
||||||
optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 });
|
optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 });
|
||||||
assertEquals(xM, optimum.getPoint()[0], 4.0e-6);
|
Assert.assertEquals(xM, optimum.getPoint()[0], 4.0e-6);
|
||||||
assertEquals(yP, optimum.getPoint()[1], 3.0e-6);
|
Assert.assertEquals(yP, optimum.getPoint()[1], 3.0e-6);
|
||||||
assertEquals(valueXmYp, optimum.getValue(), 8.0e-13);
|
Assert.assertEquals(valueXmYp, optimum.getValue(), 8.0e-13);
|
||||||
assertTrue(optimizer.getEvaluations() > 120);
|
Assert.assertTrue(optimizer.getEvaluations() > 120);
|
||||||
assertTrue(optimizer.getEvaluations() < 150);
|
Assert.assertTrue(optimizer.getEvaluations() < 150);
|
||||||
|
|
||||||
optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 });
|
optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 });
|
||||||
assertEquals(xP, optimum.getPoint()[0], 2.0e-8);
|
Assert.assertEquals(xP, optimum.getPoint()[0], 2.0e-8);
|
||||||
assertEquals(yM, optimum.getPoint()[1], 3.0e-6);
|
Assert.assertEquals(yM, optimum.getPoint()[1], 3.0e-6);
|
||||||
assertEquals(valueXpYm, optimum.getValue(), 2.0e-12);
|
Assert.assertEquals(valueXpYm, optimum.getValue(), 2.0e-12);
|
||||||
assertTrue(optimizer.getEvaluations() > 120);
|
Assert.assertTrue(optimizer.getEvaluations() > 120);
|
||||||
assertTrue(optimizer.getEvaluations() < 150);
|
Assert.assertTrue(optimizer.getEvaluations() < 150);
|
||||||
|
|
||||||
// maximization
|
// maximization
|
||||||
optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 });
|
optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 });
|
||||||
assertEquals(xM, optimum.getPoint()[0], 7.0e-7);
|
Assert.assertEquals(xM, optimum.getPoint()[0], 7.0e-7);
|
||||||
assertEquals(yM, optimum.getPoint()[1], 3.0e-7);
|
Assert.assertEquals(yM, optimum.getPoint()[1], 3.0e-7);
|
||||||
assertEquals(valueXmYm, optimum.getValue(), 2.0e-14);
|
Assert.assertEquals(valueXmYm, optimum.getValue(), 2.0e-14);
|
||||||
assertTrue(optimizer.getEvaluations() > 120);
|
Assert.assertTrue(optimizer.getEvaluations() > 120);
|
||||||
assertTrue(optimizer.getEvaluations() < 150);
|
Assert.assertTrue(optimizer.getEvaluations() < 150);
|
||||||
|
|
||||||
|
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-15, 1.0e-30));
|
||||||
optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 });
|
optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 });
|
||||||
assertEquals(xP, optimum.getPoint()[0], 2.0e-8);
|
Assert.assertEquals(xP, optimum.getPoint()[0], 2.0e-8);
|
||||||
assertEquals(yP, optimum.getPoint()[1], 3.0e-6);
|
Assert.assertEquals(yP, optimum.getPoint()[1], 3.0e-6);
|
||||||
assertEquals(valueXpYp, optimum.getValue(), 2.0e-12);
|
Assert.assertEquals(valueXpYp, optimum.getValue(), 2.0e-12);
|
||||||
assertTrue(optimizer.getEvaluations() > 120);
|
Assert.assertTrue(optimizer.getEvaluations() > 180);
|
||||||
assertTrue(optimizer.getEvaluations() < 150);
|
Assert.assertTrue(optimizer.getEvaluations() < 220);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testRosenbrock()
|
public void testRosenbrock()
|
||||||
throws FunctionEvaluationException, ConvergenceException {
|
throws FunctionEvaluationException, ConvergenceException {
|
||||||
|
|
||||||
|
@ -154,13 +152,14 @@ public class MultiDirectionalTest
|
||||||
RealPointValuePair optimum =
|
RealPointValuePair optimum =
|
||||||
optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 });
|
optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 });
|
||||||
|
|
||||||
assertEquals(count, optimizer.getEvaluations());
|
Assert.assertEquals(count, optimizer.getEvaluations());
|
||||||
assertTrue(optimizer.getEvaluations() > 70);
|
Assert.assertTrue(optimizer.getEvaluations() > 50);
|
||||||
assertTrue(optimizer.getEvaluations() < 100);
|
Assert.assertTrue(optimizer.getEvaluations() < 100);
|
||||||
assertTrue(optimum.getValue() > 1.0e-2);
|
Assert.assertTrue(optimum.getValue() > 1.0e-2);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testPowell()
|
public void testPowell()
|
||||||
throws FunctionEvaluationException, ConvergenceException {
|
throws FunctionEvaluationException, ConvergenceException {
|
||||||
|
|
||||||
|
@ -183,15 +182,64 @@ public class MultiDirectionalTest
|
||||||
optimizer.setMaxIterations(1000);
|
optimizer.setMaxIterations(1000);
|
||||||
RealPointValuePair optimum =
|
RealPointValuePair optimum =
|
||||||
optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
|
optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
|
||||||
assertEquals(count, optimizer.getEvaluations());
|
Assert.assertEquals(count, optimizer.getEvaluations());
|
||||||
assertTrue(optimizer.getEvaluations() > 800);
|
Assert.assertTrue(optimizer.getEvaluations() > 800);
|
||||||
assertTrue(optimizer.getEvaluations() < 900);
|
Assert.assertTrue(optimizer.getEvaluations() < 900);
|
||||||
assertTrue(optimum.getValue() > 1.0e-2);
|
Assert.assertTrue(optimum.getValue() > 1.0e-2);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Test suite() {
|
@Test
|
||||||
return new TestSuite(MultiDirectionalTest.class);
|
public void testMath283()
|
||||||
|
throws FunctionEvaluationException, OptimizationException {
|
||||||
|
// fails because MultiDirectional.iterateSimplex is looping forever
|
||||||
|
// the while(true) should be replaced with a convergence check
|
||||||
|
MultiDirectional multiDirectional = new MultiDirectional();
|
||||||
|
multiDirectional.setMaxIterations(100);
|
||||||
|
multiDirectional.setMaxEvaluations(1000);
|
||||||
|
|
||||||
|
final Gaussian2D function = new Gaussian2D(0.0, 0.0, 1.0);
|
||||||
|
|
||||||
|
RealPointValuePair estimate = multiDirectional.optimize(function,
|
||||||
|
GoalType.MAXIMIZE, function.getMaximumPosition());
|
||||||
|
|
||||||
|
final double EPSILON = 1e-5;
|
||||||
|
|
||||||
|
final double expectedMaximum = function.getMaximum();
|
||||||
|
final double actualMaximum = estimate.getValue();
|
||||||
|
Assert.assertEquals(expectedMaximum, actualMaximum, EPSILON);
|
||||||
|
|
||||||
|
final double[] expectedPosition = function.getMaximumPosition();
|
||||||
|
final double[] actualPosition = estimate.getPoint();
|
||||||
|
Assert.assertEquals(expectedPosition[0], actualPosition[0], EPSILON );
|
||||||
|
Assert.assertEquals(expectedPosition[1], actualPosition[1], EPSILON );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class Gaussian2D implements MultivariateRealFunction {
|
||||||
|
|
||||||
|
private final double[] maximumPosition;
|
||||||
|
|
||||||
|
private final double std;
|
||||||
|
|
||||||
|
public Gaussian2D(double xOpt, double yOpt, double std) {
|
||||||
|
maximumPosition = new double[] { xOpt, yOpt };
|
||||||
|
this.std = std;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getMaximum() {
|
||||||
|
return value(maximumPosition);
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] getMaximumPosition() {
|
||||||
|
return maximumPosition.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
public double value(double[] point) {
|
||||||
|
final double x = point[0], y = point[1];
|
||||||
|
final double twoS2 = 2.0 * std * std;
|
||||||
|
return 1.0 / (twoS2 * Math.PI) * Math.exp(-(x * x + y * y) / twoS2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private int count;
|
private int count;
|
||||||
|
|
Loading…
Reference in New Issue