Consistency check to ensure that "TestFunction" implementations are correct.

Call is commented out (it is mostly intended for a one-time visual check).
This commit is contained in:
Gilles Sadowski 2021-08-12 01:59:37 +02:00
parent f29ebd2e13
commit 59e9604dd8
1 changed files with 109 additions and 0 deletions

View File

@ -17,6 +17,10 @@
package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
import java.util.Arrays;
import java.io.PrintWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ParameterContext;
@ -222,6 +226,111 @@ public class SimplexOptimizerTest {
Assertions.assertTrue(nEval < functionEvaluations,
name + ": nEval=" + nEval);
}
/**
* Asserts that the lowest function value (along a line starting at
* {@link #start} is reached at the {@link #optimum}.
*
* @param numPoints Number of points at which to evaluate the function.
* @param plot Whether to generate a file (for visual debugging).
*/
public void checkAlongLine(int numPoints,
boolean plot) {
if (plot) {
final String name = createPlotBasename(function, start, optimum);
try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name)))) {
checkAlongLine(numPoints, out);
} catch (IOException e) {
Assertions.fail(e.getMessage());
}
} else {
checkAlongLine(numPoints, null);
}
}
/**
* Computes the values of the function along the straight line between
* {@link #startPoint} and {@link #optimum} and asserts that the value
* at the latter is smaller than at any other points along the line.
* <p>
* If the {@code output} stream is not {@code null}, two columns are
* printed:
* <ol>
* <li>parameter in the {@code [0, 1]} interval (0 at {@link #startPoint}
* and 1 at {@link #optimum}),</li>
* <li>function value at {@code t * (optimum - startPoint)}.</li>
* </ol>
*
* @param numPoints Number of points to evaluate between {@link #start}
* and {@link #optimum}.
* @param output Output stream.
*/
private void checkAlongLine(int numPoints,
PrintWriter output) {
final double delta = 1d / numPoints;
final int dim = start.length;
final double[] dir = new double[dim];
for (int i = 0; i < dim; i++) {
dir[i] = optimum[i] - start[i];
}
double[] minPoint = null;
double minValue = Double.POSITIVE_INFINITY;
int count = 0;
while (count <= numPoints) {
final double[] p = new double[dim];
final double t = count * delta;
for (int i = 0; i < dim; i++) {
p[i] = start[i] + t * dir[i];
}
final double value = function.value(p);
if (value <= minValue) {
minValue = value;
minPoint = p;
}
if (output != null) {
output.println(t + " " + value);
}
++count;
}
final double tol = 1e-15;
Assertions.assertArrayEquals(optimum, minPoint, tol,
"Minimum: f(" + Arrays.toString(minPoint) + ")=" + minValue);
}
/**
* Generates a string suitable as a file name.
* <p>
* Brackets are removed; space, slash, "=" sign and comma
* characters are converted to underscores.
*
* @param f Function.
* @param start Start point.
* @param end End point.
* @return a string.
*/
private static String createPlotBasename(MultivariateFunction f,
double[] start,
double[] end) {
final String s = f.toString() + "__" +
Arrays.toString(start) + "__" +
Arrays.toString(end) + ".dat";
final String repl = "_";
return s
.replaceAll("\\[", "")
.replaceAll("\\]", "")
.replaceAll("=", repl)
.replaceAll(",\\s+", repl)
.replaceAll(",", repl)
.replaceAll("\\s", repl)
.replaceAll("/", repl);
}
}
/**