improved test coverage
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@574082 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
5ae04469dc
commit
c5cb64a7e3
|
@ -33,6 +33,41 @@ public class MultiDirectionalTest
|
|||
super(name);
|
||||
}
|
||||
|
||||
public void testCostExceptions() throws ConvergenceException {
|
||||
CostFunction wrong =
|
||||
new CostFunction() {
|
||||
public double cost(double[] x) throws CostException {
|
||||
if (x[0] < 0) {
|
||||
throw new CostException("{0}", new Object[] { "oops"});
|
||||
} else if (x[0] > 1) {
|
||||
throw new CostException(new RuntimeException("oops"));
|
||||
} else {
|
||||
return x[0] * (1 - x[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
try {
|
||||
new MultiDirectional(1.9, 0.4).minimizes(wrong, 10, new ValueChecker(1.0e-3),
|
||||
new double[] { -0.5 }, new double[] { 0.5 });
|
||||
fail("an exception should have been thrown");
|
||||
} catch (CostException ce) {
|
||||
// expected behavior
|
||||
assertNull(ce.getCause());
|
||||
} catch (Exception e) {
|
||||
fail("wrong exception caught: " + e.getMessage());
|
||||
}
|
||||
try {
|
||||
new MultiDirectional(1.9, 0.4).minimizes(wrong, 10, new ValueChecker(1.0e-3),
|
||||
new double[] { 0.5 }, new double[] { 1.5 });
|
||||
fail("an exception should have been thrown");
|
||||
} catch (CostException ce) {
|
||||
// expected behavior
|
||||
assertNotNull(ce.getCause());
|
||||
} catch (Exception e) {
|
||||
fail("wrong exception caught: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public void testRosenbrock()
|
||||
throws CostException, ConvergenceException {
|
||||
|
||||
|
@ -49,11 +84,12 @@ public class MultiDirectionalTest
|
|||
count = 0;
|
||||
PointCostPair optimum =
|
||||
new MultiDirectional().minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
|
||||
new double[] { -1.2, 1.0 },
|
||||
new double[] { 3.5, -2.3 });
|
||||
new double[][] {
|
||||
{ -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 }
|
||||
});
|
||||
|
||||
assertTrue(count > 60);
|
||||
assertTrue(optimum.cost > 0.02);
|
||||
assertTrue(optimum.cost > 0.01);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,12 @@ import org.apache.commons.math.optimization.CostFunction;
|
|||
import org.apache.commons.math.optimization.NelderMead;
|
||||
import org.apache.commons.math.ConvergenceException;
|
||||
import org.apache.commons.math.optimization.PointCostPair;
|
||||
import org.apache.commons.math.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
|
||||
import org.apache.commons.math.random.RandomGenerator;
|
||||
import org.apache.commons.math.random.RandomVectorGenerator;
|
||||
import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
|
||||
import org.apache.commons.math.random.UniformRandomGenerator;
|
||||
|
||||
import junit.framework.*;
|
||||
|
||||
|
@ -33,8 +39,43 @@ public class NelderMeadTest
|
|||
super(name);
|
||||
}
|
||||
|
||||
public void testCostExceptions() throws ConvergenceException {
|
||||
CostFunction wrong =
|
||||
new CostFunction() {
|
||||
public double cost(double[] x) throws CostException {
|
||||
if (x[0] < 0) {
|
||||
throw new CostException("{0}", new Object[] { "oops"});
|
||||
} else if (x[0] > 1) {
|
||||
throw new CostException(new RuntimeException("oops"));
|
||||
} else {
|
||||
return x[0] * (1 - x[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
try {
|
||||
new NelderMead(0.9, 1.9, 0.4, 0.6).minimizes(wrong, 10, new ValueChecker(1.0e-3),
|
||||
new double[] { -0.5 }, new double[] { 0.5 });
|
||||
fail("an exception should have been thrown");
|
||||
} catch (CostException ce) {
|
||||
// expected behavior
|
||||
assertNull(ce.getCause());
|
||||
} catch (Exception e) {
|
||||
fail("wrong exception caught: " + e.getMessage());
|
||||
}
|
||||
try {
|
||||
new NelderMead(0.9, 1.9, 0.4, 0.6).minimizes(wrong, 10, new ValueChecker(1.0e-3),
|
||||
new double[] { 0.5 }, new double[] { 1.5 });
|
||||
fail("an exception should have been thrown");
|
||||
} catch (CostException ce) {
|
||||
// expected behavior
|
||||
assertNotNull(ce.getCause());
|
||||
} catch (Exception e) {
|
||||
fail("wrong exception caught: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public void testRosenbrock()
|
||||
throws CostException, ConvergenceException {
|
||||
throws CostException, ConvergenceException, NotPositiveDefiniteMatrixException {
|
||||
|
||||
CostFunction rosenbrock =
|
||||
new CostFunction() {
|
||||
|
@ -47,15 +88,51 @@ public class NelderMeadTest
|
|||
};
|
||||
|
||||
count = 0;
|
||||
PointCostPair optimum =
|
||||
new NelderMead().minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
|
||||
new double[] { -1.2, 1.0 },
|
||||
new double[] { 3.5, -2.3 });
|
||||
NelderMead nm = new NelderMead();
|
||||
try {
|
||||
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
|
||||
new double[][] {
|
||||
{ -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
|
||||
}, 1, 5384353l);
|
||||
fail("an exception should have been thrown");
|
||||
} catch (ConvergenceException ce) {
|
||||
// expected behavior
|
||||
} catch (Exception e) {
|
||||
fail("wrong exception caught: " + e.getMessage());
|
||||
}
|
||||
|
||||
assertTrue(count < 50);
|
||||
assertEquals(0.0, optimum.cost, 6.0e-4);
|
||||
assertEquals(1.0, optimum.point[0], 0.05);
|
||||
assertEquals(1.0, optimum.point[1], 0.05);
|
||||
count = 0;
|
||||
PointCostPair optimum =
|
||||
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
|
||||
new double[][] {
|
||||
{ -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
|
||||
}, 3, 1642738l);
|
||||
|
||||
assertTrue(count < 200);
|
||||
assertEquals(0.0, optimum.cost, 5.0e-5);
|
||||
assertEquals(1.0, optimum.point[0], 0.01);
|
||||
assertEquals(1.0, optimum.point[1], 0.01);
|
||||
|
||||
PointCostPair[] minima = nm.getMinima();
|
||||
assertEquals(3, minima.length);
|
||||
for (int i = 1; i < minima.length; ++i) {
|
||||
if (minima[i] != null) {
|
||||
assertTrue(minima[i-1].cost <= minima[i].cost);
|
||||
}
|
||||
}
|
||||
|
||||
RandomGenerator rg = new JDKRandomGenerator();
|
||||
rg.setSeed(64453353l);
|
||||
RandomVectorGenerator rvg =
|
||||
new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
|
||||
new double[] { 0.2, 0.2 },
|
||||
new UniformRandomGenerator(rg));
|
||||
optimum =
|
||||
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3), rvg);
|
||||
assertEquals(0.0, optimum.cost, 2.0e-4);
|
||||
optimum =
|
||||
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3), rvg, 3);
|
||||
assertEquals(0.0, optimum.cost, 3.0e-5);
|
||||
|
||||
}
|
||||
|
||||
|
@ -75,10 +152,12 @@ public class NelderMeadTest
|
|||
};
|
||||
|
||||
count = 0;
|
||||
NelderMead nm = new NelderMead();
|
||||
PointCostPair optimum =
|
||||
new NelderMead().minimizes(powell, 200, new ValueChecker(1.0e-3),
|
||||
new double[] { 3.0, -1.0, 0.0, 1.0 },
|
||||
new double[] { 4.0, 0.0, 1.0, 2.0 });
|
||||
nm.minimizes(powell, 200, new ValueChecker(1.0e-3),
|
||||
new double[] { 3.0, -1.0, 0.0, 1.0 },
|
||||
new double[] { 4.0, 0.0, 1.0, 2.0 },
|
||||
1, 1642738l);
|
||||
assertTrue(count < 150);
|
||||
assertEquals(0.0, optimum.cost, 6.0e-4);
|
||||
assertEquals(0.0, optimum.point[0], 0.07);
|
||||
|
|
Loading…
Reference in New Issue