Drop repeated tests inside "SimplexOptimizerTest.Task" class.

Explicitly specify the initial simplex side (as test parameter).
This commit is contained in:
Gilles Sadowski 2021-08-11 19:58:28 +02:00
parent 9b7a2c8edc
commit 6a7b4ccbe3
1 changed files with 57 additions and 73 deletions

View File

@ -17,7 +17,6 @@
package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv; package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
import java.util.Arrays; import java.util.Arrays;
import org.opentest4j.AssertionFailedError;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ParameterContext; import org.junit.jupiter.api.extension.ParameterContext;
@ -83,6 +82,7 @@ public class SimplexOptimizerTest {
@ParameterizedTest @ParameterizedTest
@CsvFileSource(resources = NELDER_MEAD_INPUT_FILE) @CsvFileSource(resources = NELDER_MEAD_INPUT_FILE)
void testFunctionWithNelderMead(@AggregateWith(TaskAggregator.class) Task task) { void testFunctionWithNelderMead(@AggregateWith(TaskAggregator.class) Task task) {
// task.checkAlongLine(1000, true);
task.run(new NelderMeadTransform()); task.run(new NelderMeadTransform());
} }
@ -106,8 +106,6 @@ public class SimplexOptimizerTest {
private static final int FUNC_EVAL_DEBUG = 80000; private static final int FUNC_EVAL_DEBUG = 80000;
/** Default convergence criterion. */ /** Default convergence criterion. */
private static final double CONVERGENCE_CHECK = 1e-9; private static final double CONVERGENCE_CHECK = 1e-9;
/** Default simplex size. */
private static final double SIDE_LENGTH = 1;
/** Default cooling factor. */ /** Default cooling factor. */
private static final double SA_COOL_FACTOR = 0.5; private static final double SA_COOL_FACTOR = 0.5;
/** Default acceptance probability at beginning of SA. */ /** Default acceptance probability at beginning of SA. */
@ -115,7 +113,7 @@ public class SimplexOptimizerTest {
/** Default acceptance probability at end of SA. */ /** Default acceptance probability at end of SA. */
private static final double SA_END_PROB = 1e-20; private static final double SA_END_PROB = 1e-20;
/** Function. */ /** Function. */
private final MultivariateFunction f; private final MultivariateFunction function;
/** Initial value. */ /** Initial value. */
private final double[] start; private final double[] start;
/** Optimum. */ /** Optimum. */
@ -124,46 +122,45 @@ public class SimplexOptimizerTest {
private final double pointTolerance; private final double pointTolerance;
/** Allowed function evaluations. */ /** Allowed function evaluations. */
private final int functionEvaluations; private final int functionEvaluations;
/** Repeats on failure. */ /** Side length of initial simplex. */
private final int repeatsOnFailure; private final double simplexSideLength;
/** Range of random noise. */ /** Range of random noise. */
private final double jitter; private final double jitter;
/** Whether to perform simulated annealing. */ /** Whether to perform simulated annealing. */
private final boolean withSA; private final boolean withSA;
/** /**
* @param f Test function. * @param function Test function.
* @param start Start point. * @param start Start point.
* @param optimum Optimum. * @param optimum Optimum.
* @param pointTolerance Allowed distance between result and * @param pointTolerance Allowed distance between result and
* {@code optimum}. * {@code optimum}.
* @param functionEvaluations Allowed number of function evaluations. * @param functionEvaluations Allowed number of function evaluations.
* @param repeatsOnFailure Maximum number of times to rerun when an * @param simplexSideLength Side length of initial simplex.
* {@link AssertionFailedError} is thrown.
* @param jitter Size of random jitter. * @param jitter Size of random jitter.
* @param withSA Whether to perform simulated annealing. * @param withSA Whether to perform simulated annealing.
*/ */
Task(MultivariateFunction f, Task(MultivariateFunction function,
double[] start, double[] start,
double[] optimum, double[] optimum,
double pointTolerance, double pointTolerance,
int functionEvaluations, int functionEvaluations,
int repeatsOnFailure, double simplexSideLength,
double jitter, double jitter,
boolean withSA) { boolean withSA) {
this.f = f; this.function = function;
this.start = start; this.start = start;
this.optimum = optimum; this.optimum = optimum;
this.pointTolerance = pointTolerance; this.pointTolerance = pointTolerance;
this.functionEvaluations = functionEvaluations; this.functionEvaluations = functionEvaluations;
this.repeatsOnFailure = repeatsOnFailure; this.simplexSideLength = simplexSideLength;
this.jitter = jitter; this.jitter = jitter;
this.withSA = withSA; this.withSA = withSA;
} }
@Override @Override
public String toString() { public String toString() {
return f.toString(); return function.toString();
} }
/** /**
@ -176,67 +173,54 @@ public class SimplexOptimizerTest {
// required by the current code. // required by the current code.
final int maxEval = Math.max(functionEvaluations, FUNC_EVAL_DEBUG); final int maxEval = Math.max(functionEvaluations, FUNC_EVAL_DEBUG);
int currentRetry = -1; final String name = function.toString();
AssertionFailedError lastFailure = null; final int dim = start.length;
while (currentRetry++ <= repeatsOnFailure) {
try {
final String name = f.toString();
final int dim = start.length;
final SimulatedAnnealing sa; final SimulatedAnnealing sa;
final PopulationSize popSize; final PopulationSize popSize;
if (withSA) { if (withSA) {
final SimulatedAnnealing.CoolingSchedule coolSched = final SimulatedAnnealing.CoolingSchedule coolSched =
SimulatedAnnealing.CoolingSchedule.decreasingExponential(SA_COOL_FACTOR); SimulatedAnnealing.CoolingSchedule.decreasingExponential(SA_COOL_FACTOR);
sa = new SimulatedAnnealing(dim, sa = new SimulatedAnnealing(dim,
SA_START_PROB, SA_START_PROB,
SA_END_PROB, SA_END_PROB,
coolSched, coolSched,
RandomSource.KISS.create()); RandomSource.KISS.create());
popSize = new PopulationSize(dim); popSize = new PopulationSize(dim);
} else { } else {
sa = null; sa = null;
popSize = null; popSize = null;
}
final SimplexOptimizer optim = new SimplexOptimizer(-1, CONVERGENCE_CHECK);
final Simplex initialSimplex =
Simplex.alongAxes(OptimTestUtils.point(dim,
SIDE_LENGTH,
jitter));
final double[] startPoint = OptimTestUtils.point(start, jitter);
final PointValuePair result =
optim.optimize(new MaxEval(maxEval),
new ObjectiveFunction(f),
GoalType.MINIMIZE,
new InitialGuess(startPoint),
initialSimplex,
factory,
sa,
popSize);
final double[] endPoint = result.getPoint();
final double funcValue = result.getValue();
final double dist = MathArrays.distance(optimum, endPoint);
Assertions.assertEquals(0d, dist, pointTolerance,
name + ": distance to optimum" +
" f(" + Arrays.toString(endPoint) + ")=" +
funcValue);
final int nEval = optim.getEvaluations();
Assertions.assertTrue(nEval < functionEvaluations,
name + ": nEval=" + nEval);
break; // Assertions passed: Retry not neccessary.
} catch (AssertionFailedError e) {
if (currentRetry >= repeatsOnFailure) {
// Allowed repeats have been exhausted: Bail out.
throw e;
}
}
} }
final SimplexOptimizer optim = new SimplexOptimizer(-1, CONVERGENCE_CHECK);
final Simplex initialSimplex =
Simplex.alongAxes(OptimTestUtils.point(dim,
simplexSideLength,
jitter));
final double[] startPoint = OptimTestUtils.point(start, jitter);
final PointValuePair result =
optim.optimize(new MaxEval(maxEval),
new ObjectiveFunction(function),
GoalType.MINIMIZE,
new InitialGuess(startPoint),
initialSimplex,
factory,
sa,
popSize);
final double[] endPoint = result.getPoint();
final double funcValue = result.getValue();
final double dist = MathArrays.distance(optimum, endPoint);
Assertions.assertEquals(0d, dist, pointTolerance,
name + ": distance to optimum" +
" f(" + Arrays.toString(endPoint) + ")=" +
funcValue);
final int nEval = optim.getEvaluations();
Assertions.assertTrue(nEval < functionEvaluations,
name + ": nEval=" + nEval);
} }
} }
@ -257,7 +241,7 @@ public class SimplexOptimizerTest {
final double[] optimum = toArrayOfDoubles(a.getString(index++), dim); final double[] optimum = toArrayOfDoubles(a.getString(index++), dim);
final double pointTol = a.getDouble(index++); final double pointTol = a.getDouble(index++);
final int funcEval = a.getInteger(index++); final int funcEval = a.getInteger(index++);
final int repeat = a.getInteger(index++); final double sideLength = a.getDouble(index++);
final double jitter = a.getDouble(index++); final double jitter = a.getDouble(index++);
final boolean withSA = a.getBoolean(index++); final boolean withSA = a.getBoolean(index++);
@ -266,7 +250,7 @@ public class SimplexOptimizerTest {
optimum, optimum,
pointTol, pointTol,
funcEval, funcEval,
repeat, sideLength,
jitter, jitter,
withSA); withSA);
} }