diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/noderiv/SimplexOptimizerTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/noderiv/SimplexOptimizerTest.java index 25dd7c897..732ba42de 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/noderiv/SimplexOptimizerTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/noderiv/SimplexOptimizerTest.java @@ -17,6 +17,10 @@ package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv; import java.util.Arrays; +import java.io.PrintWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ParameterContext; @@ -222,6 +226,111 @@ public class SimplexOptimizerTest { Assertions.assertTrue(nEval < functionEvaluations, name + ": nEval=" + nEval); } + + /** + * Asserts that the lowest function value (along a line starting at + * {@link #start} is reached at the {@link #optimum}. + * + * @param numPoints Number of points at which to evaluate the function. + * @param plot Whether to generate a file (for visual debugging). + */ + public void checkAlongLine(int numPoints, + boolean plot) { + if (plot) { + final String name = createPlotBasename(function, start, optimum); + try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name)))) { + checkAlongLine(numPoints, out); + } catch (IOException e) { + Assertions.fail(e.getMessage()); + } + } else { + checkAlongLine(numPoints, null); + } + } + + /** + * Computes the values of the function along the straight line between + * {@link #startPoint} and {@link #optimum} and asserts that the value + * at the latter is smaller than at any other points along the line. + *

+ * If the {@code output} stream is not {@code null}, two columns are + * printed: + *

    + *
  1. parameter in the {@code [0, 1]} interval (0 at {@link #startPoint} + * and 1 at {@link #optimum}),
  2. + *
  3. function value at {@code t * (optimum - startPoint)}.
  4. + *
+ * + * @param numPoints Number of points to evaluate between {@link #start} + * and {@link #optimum}. + * @param output Output stream. + */ + private void checkAlongLine(int numPoints, + PrintWriter output) { + final double delta = 1d / numPoints; + + final int dim = start.length; + final double[] dir = new double[dim]; + for (int i = 0; i < dim; i++) { + dir[i] = optimum[i] - start[i]; + } + + double[] minPoint = null; + double minValue = Double.POSITIVE_INFINITY; + int count = 0; + while (count <= numPoints) { + final double[] p = new double[dim]; + final double t = count * delta; + for (int i = 0; i < dim; i++) { + p[i] = start[i] + t * dir[i]; + } + + final double value = function.value(p); + if (value <= minValue) { + minValue = value; + minPoint = p; + } + + if (output != null) { + output.println(t + " " + value); + } + + ++count; + } + + final double tol = 1e-15; + Assertions.assertArrayEquals(optimum, minPoint, tol, + "Minimum: f(" + Arrays.toString(minPoint) + ")=" + minValue); + } + + /** + * Generates a string suitable as a file name. + *

+ * Brackets are removed; space, slash, "=" sign and comma + * characters are converted to underscores. + * + * @param f Function. + * @param start Start point. + * @param end End point. + * @return a string. + */ + private static String createPlotBasename(MultivariateFunction f, + double[] start, + double[] end) { + final String s = f.toString() + "__" + + Arrays.toString(start) + "__" + + Arrays.toString(end) + ".dat"; + + final String repl = "_"; + return s + .replaceAll("\\[", "") + .replaceAll("\\]", "") + .replaceAll("=", repl) + .replaceAll(",\\s+", repl) + .replaceAll(",", repl) + .replaceAll("\\s", repl) + .replaceAll("/", repl); + } } /**