diff --git a/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java b/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java index 1d02e14f7..6ec68c807 100644 --- a/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java +++ b/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java @@ -29,7 +29,11 @@ import org.apache.commons.math.MathException; import org.apache.commons.math.MaxEvaluationsExceededException; import org.apache.commons.math.MaxIterationsExceededException; import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.analysis.MultivariateVectorialFunction; +import org.apache.commons.math.linear.Array2DRowRealMatrix; +import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.optimization.GoalType; +import org.apache.commons.math.optimization.LeastSquaresConverter; import org.apache.commons.math.optimization.OptimizationException; import org.apache.commons.math.optimization.RealPointValuePair; import org.apache.commons.math.optimization.SimpleRealPointChecker; @@ -173,6 +177,86 @@ public class NelderMeadTest { } + @Test + public void testLeastSquares1() + throws FunctionEvaluationException, ConvergenceException { + + final RealMatrix factors = + new Array2DRowRealMatrix(new double[][] { + { 1.0, 0.0 }, + { 0.0, 1.0 } + }, false); + LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { + public double[] value(double[] variables) { + return factors.operate(variables); + } + }, new double[] { 2.0, -3.0 }); + NelderMead optimizer = new NelderMead(); + optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6)); + optimizer.setMaxIterations(200); + RealPointValuePair optimum = + optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 }); + assertEquals( 2.0, optimum.getPointRef()[0], 3.0e-5); + assertEquals(-3.0, optimum.getPointRef()[1], 4.0e-4); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 80); + assertTrue(optimum.getValue() < 1.0e-6); + } + + @Test + public void testLeastSquares2() + throws FunctionEvaluationException, ConvergenceException { + + final RealMatrix factors = + new Array2DRowRealMatrix(new double[][] { + { 1.0, 0.0 }, + { 0.0, 1.0 } + }, false); + LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { + public double[] value(double[] variables) { + return factors.operate(variables); + } + }, new double[] { 2.0, -3.0 }, new double[] { 10.0, 0.1 }); + NelderMead optimizer = new NelderMead(); + optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6)); + optimizer.setMaxIterations(200); + RealPointValuePair optimum = + optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 }); + assertEquals( 2.0, optimum.getPointRef()[0], 5.0e-5); + assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 80); + assertTrue(optimum.getValue() < 1.0e-6); + } + + @Test + public void testLeastSquares3() + throws FunctionEvaluationException, ConvergenceException { + + final RealMatrix factors = + new Array2DRowRealMatrix(new double[][] { + { 1.0, 0.0 }, + { 0.0, 1.0 } + }, false); + LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { + public double[] value(double[] variables) { + return factors.operate(variables); + } + }, new double[] { 2.0, -3.0 }, new Array2DRowRealMatrix(new double [][] { + { 1.0, 1.2 }, { 1.2, 2.0 } + })); + NelderMead optimizer = new NelderMead(); + optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6)); + optimizer.setMaxIterations(200); + RealPointValuePair optimum = + optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 }); + assertEquals( 2.0, optimum.getPointRef()[0], 2.0e-3); + assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 80); + assertTrue(optimum.getValue() < 1.0e-6); + } + @Test(expected = MaxIterationsExceededException.class) public void testMaxIterations() throws MathException { try {