improved test coverage

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@574082 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2007-09-09 21:39:00 +00:00
parent 5ae04469dc
commit c5cb64a7e3
2 changed files with 130 additions and 15 deletions

View File

@ -33,6 +33,41 @@ public class MultiDirectionalTest
super(name);
}
public void testCostExceptions() throws ConvergenceException {
CostFunction wrong =
new CostFunction() {
public double cost(double[] x) throws CostException {
if (x[0] < 0) {
throw new CostException("{0}", new Object[] { "oops"});
} else if (x[0] > 1) {
throw new CostException(new RuntimeException("oops"));
} else {
return x[0] * (1 - x[0]);
}
}
};
try {
new MultiDirectional(1.9, 0.4).minimizes(wrong, 10, new ValueChecker(1.0e-3),
new double[] { -0.5 }, new double[] { 0.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
try {
new MultiDirectional(1.9, 0.4).minimizes(wrong, 10, new ValueChecker(1.0e-3),
new double[] { 0.5 }, new double[] { 1.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNotNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
}
public void testRosenbrock()
throws CostException, ConvergenceException {
@ -49,11 +84,12 @@ public class MultiDirectionalTest
count = 0;
PointCostPair optimum =
new MultiDirectional().minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
new double[] { -1.2, 1.0 },
new double[] { 3.5, -2.3 });
new double[][] {
{ -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 }
});
assertTrue(count > 60);
assertTrue(optimum.cost > 0.02);
assertTrue(optimum.cost > 0.01);
}

View File

@ -23,6 +23,12 @@ import org.apache.commons.math.optimization.CostFunction;
import org.apache.commons.math.optimization.NelderMead;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.optimization.PointCostPair;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.random.RandomGenerator;
import org.apache.commons.math.random.RandomVectorGenerator;
import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
import org.apache.commons.math.random.UniformRandomGenerator;
import junit.framework.*;
@ -33,8 +39,43 @@ public class NelderMeadTest
super(name);
}
public void testCostExceptions() throws ConvergenceException {
CostFunction wrong =
new CostFunction() {
public double cost(double[] x) throws CostException {
if (x[0] < 0) {
throw new CostException("{0}", new Object[] { "oops"});
} else if (x[0] > 1) {
throw new CostException(new RuntimeException("oops"));
} else {
return x[0] * (1 - x[0]);
}
}
};
try {
new NelderMead(0.9, 1.9, 0.4, 0.6).minimizes(wrong, 10, new ValueChecker(1.0e-3),
new double[] { -0.5 }, new double[] { 0.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
try {
new NelderMead(0.9, 1.9, 0.4, 0.6).minimizes(wrong, 10, new ValueChecker(1.0e-3),
new double[] { 0.5 }, new double[] { 1.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNotNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
}
public void testRosenbrock()
throws CostException, ConvergenceException {
throws CostException, ConvergenceException, NotPositiveDefiniteMatrixException {
CostFunction rosenbrock =
new CostFunction() {
@ -47,15 +88,51 @@ public class NelderMeadTest
};
count = 0;
PointCostPair optimum =
new NelderMead().minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
new double[] { -1.2, 1.0 },
new double[] { 3.5, -2.3 });
NelderMead nm = new NelderMead();
try {
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
new double[][] {
{ -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
}, 1, 5384353l);
fail("an exception should have been thrown");
} catch (ConvergenceException ce) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
assertTrue(count < 50);
assertEquals(0.0, optimum.cost, 6.0e-4);
assertEquals(1.0, optimum.point[0], 0.05);
assertEquals(1.0, optimum.point[1], 0.05);
count = 0;
PointCostPair optimum =
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3),
new double[][] {
{ -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
}, 3, 1642738l);
assertTrue(count < 200);
assertEquals(0.0, optimum.cost, 5.0e-5);
assertEquals(1.0, optimum.point[0], 0.01);
assertEquals(1.0, optimum.point[1], 0.01);
PointCostPair[] minima = nm.getMinima();
assertEquals(3, minima.length);
for (int i = 1; i < minima.length; ++i) {
if (minima[i] != null) {
assertTrue(minima[i-1].cost <= minima[i].cost);
}
}
RandomGenerator rg = new JDKRandomGenerator();
rg.setSeed(64453353l);
RandomVectorGenerator rvg =
new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
new double[] { 0.2, 0.2 },
new UniformRandomGenerator(rg));
optimum =
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3), rvg);
assertEquals(0.0, optimum.cost, 2.0e-4);
optimum =
nm.minimizes(rosenbrock, 100, new ValueChecker(1.0e-3), rvg, 3);
assertEquals(0.0, optimum.cost, 3.0e-5);
}
@ -75,10 +152,12 @@ public class NelderMeadTest
};
count = 0;
NelderMead nm = new NelderMead();
PointCostPair optimum =
new NelderMead().minimizes(powell, 200, new ValueChecker(1.0e-3),
new double[] { 3.0, -1.0, 0.0, 1.0 },
new double[] { 4.0, 0.0, 1.0, 2.0 });
nm.minimizes(powell, 200, new ValueChecker(1.0e-3),
new double[] { 3.0, -1.0, 0.0, 1.0 },
new double[] { 4.0, 0.0, 1.0, 2.0 },
1, 1642738l);
assertTrue(count < 150);
assertEquals(0.0, optimum.cost, 6.0e-4);
assertEquals(0.0, optimum.point[0], 0.07);