Add unit tests.

Failing tests are disabled.
This commit is contained in:
Gilles Sadowski 2020-07-28 23:23:24 +02:00
parent 2470c3ff28
commit 3d2b2107b5
2 changed files with 313 additions and 139 deletions

View File

@ -30,14 +30,18 @@ import org.apache.commons.math4.optim.nonlinear.scalar.noderiv.MultiDirectionalS
import org.apache.commons.math4.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math4.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.MathArrays;
import org.junit.Assert;
import org.junit.Test;
import org.junit.Ignore;
public class SimplexOptimizerMultiDirectionalTest {
private static final int DIM = 13;
@Test(expected=MathUnsupportedOperationException.class)
public void testBoundsUnsupported() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(fourExtrema),
@ -51,7 +55,7 @@ public class SimplexOptimizerMultiDirectionalTest {
@Test
public void testMinimize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
@ -72,7 +76,7 @@ public class SimplexOptimizerMultiDirectionalTest {
@Test
public void testMinimize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
@ -93,7 +97,7 @@ public class SimplexOptimizerMultiDirectionalTest {
@Test
public void testMaximize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
@ -114,7 +118,7 @@ public class SimplexOptimizerMultiDirectionalTest {
@Test
public void testMaximize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(new SimpleValueChecker(1e-15, 1e-30));
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
@ -134,18 +138,7 @@ public class SimplexOptimizerMultiDirectionalTest {
@Test
public void testRosenbrock() {
MultivariateFunction rosenbrock
= new MultivariateFunction() {
@Override
public double value(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
};
count = 0;
final OptimTestUtils.Rosenbrock rosenbrock = new OptimTestUtils.Rosenbrock();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
@ -156,8 +149,6 @@ public class SimplexOptimizerMultiDirectionalTest {
{ -1.2, 1.0 },
{ 0.9, 1.2 },
{ 3.5, -2.3 } }));
Assert.assertEquals(count, optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 50);
Assert.assertTrue(optimizer.getEvaluations() < 100);
Assert.assertTrue(optimum.getValue() > 1e-2);
@ -165,20 +156,7 @@ public class SimplexOptimizerMultiDirectionalTest {
@Test
public void testPowell() {
MultivariateFunction powell
= new MultivariateFunction() {
@Override
public double value(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
};
count = 0;
final OptimTestUtils.Powell powell = new OptimTestUtils.Powell();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum
= optimizer.optimize(new MaxEval(1000),
@ -186,7 +164,6 @@ public class SimplexOptimizerMultiDirectionalTest {
GoalType.MINIMIZE,
new InitialGuess(new double[] { 3, -1, 0, 1 }),
new MultiDirectionalSimplex(4));
Assert.assertEquals(count, optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 800);
Assert.assertTrue(optimizer.getEvaluations() < 900);
Assert.assertTrue(optimum.getValue() > 1e-2);
@ -197,7 +174,7 @@ public class SimplexOptimizerMultiDirectionalTest {
// fails because MultiDirectional.iterateSimplex is looping forever
// the while(true) should be replaced with a convergence check
SimplexOptimizer optimizer = new SimplexOptimizer(1e-14, 1e-14);
final Gaussian2D function = new Gaussian2D(0, 0, 1);
final OptimTestUtils.Gaussian2D function = new OptimTestUtils.Gaussian2D(0, 0, 1);
PointValuePair estimate = optimizer.optimize(new MaxEval(1000),
new ObjectiveFunction(function),
GoalType.MAXIMIZE,
@ -214,50 +191,152 @@ public class SimplexOptimizerMultiDirectionalTest {
Assert.assertEquals(expectedPosition[1], actualPosition[1], EPSILON );
}
private static class FourExtrema implements MultivariateFunction {
// The following function has 4 local extrema.
final double xM = -3.841947088256863675365;
final double yM = -1.391745200270734924416;
final double xP = 0.2286682237349059125691;
final double yP = -yM;
final double valueXmYm = 0.2373295333134216789769; // Local maximum.
final double valueXmYp = -valueXmYm; // Local minimum.
final double valueXpYm = -0.7290400707055187115322; // Global minimum.
final double valueXpYp = -valueXpYm; // Global maximum.
@Override
public double value(double[] variables) {
final double x = variables[0];
final double y = variables[1];
return (x == 0 || y == 0) ? 0 :
FastMath.atan(x) * FastMath.atan(x + 2) * FastMath.atan(y) * FastMath.atan(y) / (x * y);
}
@Test
public void testRosen() {
doTest(new OptimTestUtils.Rosen(),
OptimTestUtils.point(DIM, 0.1),
GoalType.MINIMIZE,
183861,
new PointValuePair(OptimTestUtils.point(DIM, 1.0), 0.0),
1e-4);
}
private static class Gaussian2D implements MultivariateFunction {
private final double[] maximumPosition;
private final double std;
@Test
public void testEllipse() {
doTest(new OptimTestUtils.Elli(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
873,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
public Gaussian2D(double xOpt, double yOpt, double std) {
maximumPosition = new double[] { xOpt, yOpt };
this.std = std;
}
public double getMaximum() {
return value(maximumPosition);
}
public double[] getMaximumPosition() {
return maximumPosition.clone();
}
@Override
public double value(double[] point) {
final double x = point[0], y = point[1];
final double twoS2 = 2.0 * std * std;
return 1.0 / (twoS2 * FastMath.PI) * FastMath.exp(-(x * x + y * y) / twoS2);
}
//@Ignore
@Test
public void testElliRotated() {
doTest(new OptimTestUtils.ElliRotated(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
873,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
private int count;
@Test
public void testCigar() {
doTest(new OptimTestUtils.Cigar(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
925,
new PointValuePair(OptimTestUtils.point(DIM,0.0), 0.0),
1e-14);
}
@Test
public void testTwoAxes() {
doTest(new OptimTestUtils.TwoAxes(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
1159,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Test
public void testCigTab() {
doTest(new OptimTestUtils.CigTab(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
795,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Test
public void testSphere() {
doTest(new OptimTestUtils.Sphere(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
665,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Test
public void testTablet() {
doTest(new OptimTestUtils.Tablet(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
873,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Test
public void testDiffPow() {
doTest(new OptimTestUtils.DiffPow(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
614,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Test
public void testSsDiffPow() {
doTest(new OptimTestUtils.SsDiffPow(),
OptimTestUtils.point(DIM / 2, 1.0),
GoalType.MINIMIZE,
656,
new PointValuePair(OptimTestUtils.point(DIM / 2, 0.0), 0.0),
1e-15);
}
@Ignore
@Test
public void testAckley() {
doTest(new OptimTestUtils.Ackley(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
587,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
0);
}
@Ignore
@Test
public void testRastrigin() {
doTest(new OptimTestUtils.Rastrigin(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
535,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
0);
}
/**
* @param func Function to optimize.
* @param startPoint Starting point.
* @param goal Minimization or maximization.
* @param maxEvaluations Maximum number of evaluations.
* @param expected Expected optimum.
* @param tol Tolerance for checking that the optimum is correct.
*/
private void doTest(MultivariateFunction func,
double[] startPoint,
GoalType goal,
int maxEvaluations,
PointValuePair expected,
double tol) {
final int dim = startPoint.length;
final SimplexOptimizer optim = new SimplexOptimizer(1e-10, 1e-12);
final PointValuePair result = optim.optimize(new MaxEval(maxEvaluations),
new ObjectiveFunction(func),
goal,
new InitialGuess(startPoint),
new MultiDirectionalSimplex(dim, 0.1));
final double dist = MathArrays.distance(expected.getPoint(),
result.getPoint());
Assert.assertEquals(0d, dist, tol);
}
}

View File

@ -34,14 +34,18 @@ import org.apache.commons.math4.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math4.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math4.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.MathArrays;
import org.junit.Assert;
import org.junit.Test;
import org.junit.Ignore;
public class SimplexOptimizerNelderMeadTest {
private static final int DIM = 13;
@Test(expected=MathUnsupportedOperationException.class)
public void testBoundsUnsupported() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(fourExtrema),
@ -55,7 +59,7 @@ public class SimplexOptimizerNelderMeadTest {
@Test
public void testMinimize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
@ -76,7 +80,7 @@ public class SimplexOptimizerNelderMeadTest {
@Test
public void testMinimize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
@ -97,7 +101,7 @@ public class SimplexOptimizerNelderMeadTest {
@Test
public void testMaximize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
@ -118,7 +122,7 @@ public class SimplexOptimizerNelderMeadTest {
@Test
public void testMaximize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final OptimTestUtils.FourExtrema fourExtrema = new OptimTestUtils.FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
@ -139,7 +143,7 @@ public class SimplexOptimizerNelderMeadTest {
@Test
public void testRosenbrock() {
Rosenbrock rosenbrock = new Rosenbrock();
OptimTestUtils.Rosenbrock rosenbrock = new OptimTestUtils.Rosenbrock();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
@ -150,8 +154,6 @@ public class SimplexOptimizerNelderMeadTest {
{ -1.2, 1 },
{ 0.9, 1.2 },
{ 3.5, -2.3 } }));
Assert.assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 40);
Assert.assertTrue(optimizer.getEvaluations() < 50);
Assert.assertTrue(optimum.getValue() < 8e-4);
@ -159,7 +161,7 @@ public class SimplexOptimizerNelderMeadTest {
@Test
public void testPowell() {
Powell powell = new Powell();
OptimTestUtils.Powell powell = new OptimTestUtils.Powell();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum =
optimizer.optimize(new MaxEval(200),
@ -167,7 +169,6 @@ public class SimplexOptimizerNelderMeadTest {
GoalType.MINIMIZE,
new InitialGuess(new double[] { 3, -1, 0, 1 }),
new NelderMeadSimplex(4));
Assert.assertEquals(powell.getCount(), optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 110);
Assert.assertTrue(optimizer.getEvaluations() < 130);
Assert.assertTrue(optimum.getValue() < 2e-3);
@ -258,7 +259,7 @@ public class SimplexOptimizerNelderMeadTest {
@Test(expected=TooManyEvaluationsException.class)
public void testMaxIterations() {
Powell powell = new Powell();
OptimTestUtils.Powell powell = new OptimTestUtils.Powell();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
optimizer.optimize(new MaxEval(20),
new ObjectiveFunction(powell),
@ -267,65 +268,159 @@ public class SimplexOptimizerNelderMeadTest {
new NelderMeadSimplex(4));
}
private static class FourExtrema implements MultivariateFunction {
// The following function has 4 local extrema.
final double xM = -3.841947088256863675365;
final double yM = -1.391745200270734924416;
final double xP = 0.2286682237349059125691;
final double yP = -yM;
final double valueXmYm = 0.2373295333134216789769; // Local maximum.
final double valueXmYp = -valueXmYm; // Local minimum.
final double valueXpYm = -0.7290400707055187115322; // Global minimum.
final double valueXpYp = -valueXpYm; // Global maximum.
@Override
public double value(double[] variables) {
final double x = variables[0];
final double y = variables[1];
return (x == 0 || y == 0) ? 0 :
FastMath.atan(x) * FastMath.atan(x + 2) * FastMath.atan(y) * FastMath.atan(y) / (x * y);
}
@Test
public void testRosen() {
doTest(new OptimTestUtils.Rosen(),
OptimTestUtils.point(DIM, 0.1),
GoalType.MINIMIZE,
11975,
new PointValuePair(OptimTestUtils.point(DIM, 1.0), 0.0),
1e-6);
}
private static class Rosenbrock implements MultivariateFunction {
private int count;
@Ignore
@Test
public void testEllipse() {
doTest(new OptimTestUtils.Elli(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
7184,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
public Rosenbrock() {
count = 0;
}
@Override
public double value(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
public int getCount() {
return count;
}
@Ignore
@Test
public void testElliRotated() {
doTest(new OptimTestUtils.ElliRotated(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
7467,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
private static class Powell implements MultivariateFunction {
private int count;
@Test
public void testCigar() {
doTest(new OptimTestUtils.Cigar(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
9160,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-6);
}
public Powell() {
count = 0;
}
@Ignore
@Test
public void testTwoAxes() {
doTest(new OptimTestUtils.TwoAxes(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
3451,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Override
public double value(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
@Ignore
@Test
public void testCigTab() {
doTest(new OptimTestUtils.CigTab(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
7454,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
public int getCount() {
return count;
}
@Ignore
@Test
public void testSphere() {
doTest(new OptimTestUtils.Sphere(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
3881,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-6);
}
@Ignore
@Test
public void testTablet() {
doTest(new OptimTestUtils.Tablet(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
6639,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Ignore
@Test
public void testDiffPow() {
doTest(new OptimTestUtils.DiffPow(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
4105,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
1e-14);
}
@Ignore
@Test
public void testSsDiffPow() {
doTest(new OptimTestUtils.SsDiffPow(),
OptimTestUtils.point(DIM / 2, 1.0),
GoalType.MINIMIZE,
3990,
new PointValuePair(OptimTestUtils.point(DIM / 2, 0.0), 0.0),
1e-15);
}
@Ignore
@Test
public void testAckley() {
doTest(new OptimTestUtils.Ackley(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
2849,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
0);
}
@Ignore
@Test
public void testRastrigin() {
doTest(new OptimTestUtils.Rastrigin(),
OptimTestUtils.point(DIM, 1.0),
GoalType.MINIMIZE,
2166,
new PointValuePair(OptimTestUtils.point(DIM, 0.0), 0.0),
0);
}
/**
* @param func Function to optimize.
* @param startPoint Starting point.
* @param goal Minimization or maximization.
* @param maxEvaluations Maximum number of evaluations.
* @param expected Expected optimum.
* @param tol Tolerance for checking that the optimum is correct.
*/
private void doTest(MultivariateFunction func,
double[] startPoint,
GoalType goal,
int maxEvaluations,
PointValuePair expected,
double tol) {
final int dim = startPoint.length;
final SimplexOptimizer optim = new SimplexOptimizer(1e-13, 1e-14);
final PointValuePair result = optim.optimize(new MaxEval(maxEvaluations),
new ObjectiveFunction(func),
goal,
new InitialGuess(startPoint),
new NelderMeadSimplex(dim, 0.1));
final double dist = MathArrays.distance(expected.getPoint(),
result.getPoint());
Assert.assertEquals(0d, dist, tol);
}
}