Consistency check to ensure that "TestFunction" implementations are correct.
Call is commented out (it is mostly intended for a one-time visual check).
This commit is contained in:
parent
f29ebd2e13
commit
59e9604dd8
|
@ -17,6 +17,10 @@
|
|||
package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.io.PrintWriter;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ParameterContext;
|
||||
|
@ -222,6 +226,111 @@ public class SimplexOptimizerTest {
|
|||
Assertions.assertTrue(nEval < functionEvaluations,
|
||||
name + ": nEval=" + nEval);
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts that the lowest function value (along a line starting at
|
||||
* {@link #start} is reached at the {@link #optimum}.
|
||||
*
|
||||
* @param numPoints Number of points at which to evaluate the function.
|
||||
* @param plot Whether to generate a file (for visual debugging).
|
||||
*/
|
||||
public void checkAlongLine(int numPoints,
|
||||
boolean plot) {
|
||||
if (plot) {
|
||||
final String name = createPlotBasename(function, start, optimum);
|
||||
try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name)))) {
|
||||
checkAlongLine(numPoints, out);
|
||||
} catch (IOException e) {
|
||||
Assertions.fail(e.getMessage());
|
||||
}
|
||||
} else {
|
||||
checkAlongLine(numPoints, null);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the values of the function along the straight line between
|
||||
* {@link #startPoint} and {@link #optimum} and asserts that the value
|
||||
* at the latter is smaller than at any other points along the line.
|
||||
* <p>
|
||||
* If the {@code output} stream is not {@code null}, two columns are
|
||||
* printed:
|
||||
* <ol>
|
||||
* <li>parameter in the {@code [0, 1]} interval (0 at {@link #startPoint}
|
||||
* and 1 at {@link #optimum}),</li>
|
||||
* <li>function value at {@code t * (optimum - startPoint)}.</li>
|
||||
* </ol>
|
||||
*
|
||||
* @param numPoints Number of points to evaluate between {@link #start}
|
||||
* and {@link #optimum}.
|
||||
* @param output Output stream.
|
||||
*/
|
||||
private void checkAlongLine(int numPoints,
|
||||
PrintWriter output) {
|
||||
final double delta = 1d / numPoints;
|
||||
|
||||
final int dim = start.length;
|
||||
final double[] dir = new double[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
dir[i] = optimum[i] - start[i];
|
||||
}
|
||||
|
||||
double[] minPoint = null;
|
||||
double minValue = Double.POSITIVE_INFINITY;
|
||||
int count = 0;
|
||||
while (count <= numPoints) {
|
||||
final double[] p = new double[dim];
|
||||
final double t = count * delta;
|
||||
for (int i = 0; i < dim; i++) {
|
||||
p[i] = start[i] + t * dir[i];
|
||||
}
|
||||
|
||||
final double value = function.value(p);
|
||||
if (value <= minValue) {
|
||||
minValue = value;
|
||||
minPoint = p;
|
||||
}
|
||||
|
||||
if (output != null) {
|
||||
output.println(t + " " + value);
|
||||
}
|
||||
|
||||
++count;
|
||||
}
|
||||
|
||||
final double tol = 1e-15;
|
||||
Assertions.assertArrayEquals(optimum, minPoint, tol,
|
||||
"Minimum: f(" + Arrays.toString(minPoint) + ")=" + minValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a string suitable as a file name.
|
||||
* <p>
|
||||
* Brackets are removed; space, slash, "=" sign and comma
|
||||
* characters are converted to underscores.
|
||||
*
|
||||
* @param f Function.
|
||||
* @param start Start point.
|
||||
* @param end End point.
|
||||
* @return a string.
|
||||
*/
|
||||
private static String createPlotBasename(MultivariateFunction f,
|
||||
double[] start,
|
||||
double[] end) {
|
||||
final String s = f.toString() + "__" +
|
||||
Arrays.toString(start) + "__" +
|
||||
Arrays.toString(end) + ".dat";
|
||||
|
||||
final String repl = "_";
|
||||
return s
|
||||
.replaceAll("\\[", "")
|
||||
.replaceAll("\\]", "")
|
||||
.replaceAll("=", repl)
|
||||
.replaceAll(",\\s+", repl)
|
||||
.replaceAll(",", repl)
|
||||
.replaceAll("\\s", repl)
|
||||
.replaceAll("/", repl);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue