From 0d6a91f69853e57e7811438dde18270eaf161000 Mon Sep 17 00:00:00 2001 From: Gilles Sadowski Date: Mon, 25 Oct 2010 09:42:33 +0000 Subject: [PATCH] MATH-428 Refactoring of "DirectSearchOptimizer" to separate the optimization and simplex management aspects. Old classes are deprecated. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1027007 13f79535-47bb-0310-9956-ffa450edef68 --- .../MathIllegalArgumentException.java | 16 +- .../math/optimization/RealPointValuePair.java | 6 +- .../optimization/direct/AbstractSimplex.java | 340 ++++++++++++++++++ .../direct/DirectSearchOptimizer.java | 1 + .../optimization/direct/MultiDirectional.java | 1 + .../direct/MultiDirectionalSimplex.java | 195 ++++++++++ .../math/optimization/direct/NelderMead.java | 1 + .../direct/NelderMeadSimplex.java | 253 +++++++++++++ .../optimization/direct/SimplexOptimizer.java | 161 +++++++++ src/site/xdoc/changes.xml | 56 +-- .../SimplexOptimizerMultiDirectionalTest.java | 189 ++++++++++ .../SimplexOptimizerNelderMeadTest.java | 262 ++++++++++++++ 12 files changed, 1448 insertions(+), 33 deletions(-) create mode 100644 src/main/java/org/apache/commons/math/optimization/direct/AbstractSimplex.java create mode 100644 src/main/java/org/apache/commons/math/optimization/direct/MultiDirectionalSimplex.java create mode 100644 src/main/java/org/apache/commons/math/optimization/direct/NelderMeadSimplex.java create mode 100644 src/main/java/org/apache/commons/math/optimization/direct/SimplexOptimizer.java create mode 100644 src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerMultiDirectionalTest.java create mode 100644 src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerNelderMeadTest.java diff --git a/src/main/java/org/apache/commons/math/exception/MathIllegalArgumentException.java b/src/main/java/org/apache/commons/math/exception/MathIllegalArgumentException.java index 7e8586f41..8e5d45bd2 100644 --- 
a/src/main/java/org/apache/commons/math/exception/MathIllegalArgumentException.java +++ b/src/main/java/org/apache/commons/math/exception/MathIllegalArgumentException.java @@ -24,9 +24,9 @@ import org.apache.commons.math.exception.util.Localizable; /** * Base class for all preconditions violation exceptions. - * This class is not intended to be instantiated directly: it should serve - * as a base class to create all the exceptions that share the semantics of - * the standard {@link IllegalArgumentException}, but must also provide a + * In most cases, this class should not be instantiated directly: it should + * serve as a base class to create all the exceptions that share the semantics + * of the standard {@link IllegalArgumentException}, but must also provide a * localized message. * * @since 2.2 @@ -56,9 +56,9 @@ public class MathIllegalArgumentException extends IllegalArgumentException { * @param general Message pattern explaining the cause of the error. * @param args Arguments. */ - protected MathIllegalArgumentException(Localizable specific, - Localizable general, - Object ... args) { + public MathIllegalArgumentException(Localizable specific, + Localizable general, + Object ... args) { this.specific = specific; this.general = general; arguments = ArgUtils.flatten(args); @@ -67,8 +67,8 @@ public class MathIllegalArgumentException extends IllegalArgumentException { * @param general Message pattern explaining the cause of the error. * @param args Arguments. */ - protected MathIllegalArgumentException(Localizable general, - Object ... args) { + public MathIllegalArgumentException(Localizable general, + Object ... 
args) { this(null, general, args); } diff --git a/src/main/java/org/apache/commons/math/optimization/RealPointValuePair.java b/src/main/java/org/apache/commons/math/optimization/RealPointValuePair.java index 0c901c75f..4e545d9e0 100644 --- a/src/main/java/org/apache/commons/math/optimization/RealPointValuePair.java +++ b/src/main/java/org/apache/commons/math/optimization/RealPointValuePair.java @@ -31,10 +31,8 @@ import java.io.Serializable; public class RealPointValuePair implements Serializable { /** Serializable version identifier. */ private static final long serialVersionUID = 1003888396256744753L; - /** Point coordinates. */ private final double[] point; - /** Value of the objective function at the point. */ private final double value; @@ -45,7 +43,7 @@ public class RealPointValuePair implements Serializable { */ public RealPointValuePair(final double[] point, final double value) { this.point = (point == null) ? null : point.clone(); - this.value = value; + this.value = value; } /** Build a point/objective function value pair. @@ -60,7 +58,7 @@ public class RealPointValuePair implements Serializable { this.point = copyArray ? ((point == null) ? null : point.clone()) : point; - this.value = value; + this.value = value; } /** Get the point. diff --git a/src/main/java/org/apache/commons/math/optimization/direct/AbstractSimplex.java b/src/main/java/org/apache/commons/math/optimization/direct/AbstractSimplex.java new file mode 100644 index 000000000..f0a73c74d --- /dev/null +++ b/src/main/java/org/apache/commons/math/optimization/direct/AbstractSimplex.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.direct; + +import java.util.Arrays; +import java.util.Comparator; + +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.exception.NotStrictlyPositiveException; +import org.apache.commons.math.exception.DimensionMismatchException; +import org.apache.commons.math.exception.ZeroException; +import org.apache.commons.math.exception.OutOfRangeException; +import org.apache.commons.math.exception.NullArgumentException; +import org.apache.commons.math.exception.MathIllegalArgumentException; +import org.apache.commons.math.exception.util.LocalizedFormats; +import org.apache.commons.math.optimization.RealPointValuePair; + +/** + * This class implements the simplex concept. + * It is intended to be used in conjunction with {@link SimplexOptimizer}. + *
+ * The initial configuration of the simplex is set by the constructors + * {@link #AbstractSimplex(double[])} or {@link #AbstractSimplex(double[][])}. + * The other {@link #AbstractSimplex(int) constructor} will set all steps + * to 1, thus building a default configuration from a unit hypercube. + *
+ * Users must call the {@link #build(double[]) build} method in order + * to create the data structure that will be acted on by the other methods of + * this class. + * + * @see SimplexOptimizer + * @version $Revision$ $Date$ + * @since 3.0 + */ +public abstract class AbstractSimplex { + /** Simplex. */ + private RealPointValuePair[] simplex; + /** Start simplex configuration. */ + private double[][] startConfiguration; + /** Simplex dimension (must be equal to {@code simplex.length - 1}). */ + private final int dimension; + + /** + * Default constructor. + * Build a unit hypercube. + * + * @param n Dimension of the simplex. + */ + protected AbstractSimplex(int n) { + this(createUnitHypercubeSteps(n)); + } + + /** + * The start configuration for simplex is built from a box parallel to + * the canonical axes of the space. The simplex is the subset of vertices + * of a box parallel to the canonical axes. It is built as the path followed + * while traveling from one vertex of the box to the diagonally opposite + * vertex moving only along the box edges. The first vertex of the box will + * be located at the start point of the optimization. + * As an example, in dimension 3 a simplex has 4 vertices. Setting the + * steps to (1, 10, 2) and the start point to (1, 1, 1) would imply the + * start simplex would be: { (1, 1, 1), (2, 1, 1), (2, 11, 1), (2, 11, 3) }. + * The first vertex would be set to the start point at (1, 1, 1) and the + * last vertex would be set to the diagonally opposite vertex at (2, 11, 3). + * + * @param steps Steps along the canonical axes representing box edges. They + * may be negative but not zero. + * @throws NullArgumentException if {@code steps} is {@code null}. + * @throws ZeroException if one of the steps is zero. 
+ */ + protected AbstractSimplex(final double[] steps) { + if (steps == null) { + throw new NullArgumentException(); + } + if (steps.length == 0) { + throw new ZeroException(); + } + dimension = steps.length; + + // Only the relative position of the n final vertices with respect + // to the first one are stored. + startConfiguration = new double[dimension][dimension]; + for (int i = 0; i < dimension; i++) { + final double[] vertexI = startConfiguration[i]; + for (int j = 0; j < i + 1; j++) { + if (steps[j] == 0) { + throw new ZeroException(LocalizedFormats.EQUAL_VERTICES_IN_SIMPLEX); + } + System.arraycopy(steps, 0, vertexI, 0, j + 1); + } + } + } + + /** + * The real initial simplex will be set up by moving the reference + * simplex such that its first point is located at the start point of the + * optimization. + * + * @param referenceSimplex Reference simplex. + * @throws NotStrictlyPositiveException if the reference simplex does not + * contain at least one point. + * @throws DimensionMismatchException if there is a dimension mismatch + * in the reference simplex. + * @throws IllegalArgumentException if one of its vertices is duplicated. + */ + protected AbstractSimplex(final double[][] referenceSimplex) { + if (referenceSimplex.length <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.SIMPLEX_NEED_ONE_POINT, + referenceSimplex.length); + } + dimension = referenceSimplex.length - 1; + + // Only the relative position of the n final vertices with respect + // to the first one are stored. + startConfiguration = new double[dimension][dimension]; + final double[] ref0 = referenceSimplex[0]; + + // Loop over vertices. + for (int i = 0; i < referenceSimplex.length; i++) { + final double[] refI = referenceSimplex[i]; + + // Safety checks. 
+ if (refI.length != dimension) { + throw new DimensionMismatchException(refI.length, dimension); + } + for (int j = 0; j < i; j++) { + final double[] refJ = referenceSimplex[j]; + boolean allEquals = true; + for (int k = 0; k < dimension; k++) { + if (refI[k] != refJ[k]) { + allEquals = false; + break; + } + } + if (allEquals) { + throw new MathIllegalArgumentException(LocalizedFormats.EQUAL_VERTICES_IN_SIMPLEX, + i, j); + } + } + + // Store vertex i position relative to vertex 0 position. + if (i > 0) { + final double[] confI = startConfiguration[i - 1]; + for (int k = 0; k < dimension; k++) { + confI[k] = refI[k] - ref0[k]; + } + } + } + } + + /** + * Get simplex dimension. + * + * @return the dimension of the simplex. + */ + public int getDimension() { + return dimension; + } + + /** + * Get simplex size. + * After calling the {@link #build(double[]) build} method, this method will + * be equivalent to {@code getDimension() + 1}. + * + * @return the size of the simplex. + */ + public int getSize() { + return simplex.length; + } + + /** + * Compute the next simplex of the algorithm. + * + * @param evaluationFunction Evaluation function. + * @param comparator Comparator to use to sort simplex vertices from best + * to worst. + * @throws FunctionEvaluationException if the function cannot be evaluated + * at some point. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the algorithm fails to converge. + */ + public abstract void iterate(final MultivariateRealFunction evaluationFunction, + final Comparator comparator) + throws FunctionEvaluationException; + + /** + * Build an initial simplex. + * + * @param startPoint First point of the simplex. + * @throws DimensionMismatchException if the start point does not match + * simplex dimension. + */ + public void build(final double[] startPoint) { + if (dimension != startPoint.length) { + throw new DimensionMismatchException(dimension, startPoint.length); + } + + // Set first vertex. 
+ simplex = new RealPointValuePair[dimension + 1]; + simplex[0] = new RealPointValuePair(startPoint, Double.NaN); + + // Set remaining vertices. + for (int i = 0; i < dimension; i++) { + final double[] confI = startConfiguration[i]; + final double[] vertexI = new double[dimension]; + for (int k = 0; k < dimension; k++) { + vertexI[k] = startPoint[k] + confI[k]; + } + simplex[i + 1] = new RealPointValuePair(vertexI, Double.NaN); + } + } + + /** + * Evaluate all the non-evaluated points of the simplex. + * + * @param evaluationFunction Evaluation function. + * @param comparator Comparator to use to sort simplex vertices from best to worst. + * @throws FunctionEvaluationException if no value can be computed for the parameters. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the maximal number of evaluations is exceeded. + */ + public void evaluate(final MultivariateRealFunction evaluationFunction, + final Comparator comparator) + throws FunctionEvaluationException { + + // Evaluate the objective function at all non-evaluated simplex points. + for (int i = 0; i < simplex.length; i++) { + final RealPointValuePair vertex = simplex[i]; + final double[] point = vertex.getPointRef(); + if (Double.isNaN(vertex.getValue())) { + simplex[i] = new RealPointValuePair(point, evaluationFunction.value(point), false); + } + } + + // Sort the simplex from best to worst. + Arrays.sort(simplex, comparator); + } + + /** + * Replace the worst point of the simplex by a new point. + * + * @param pointValuePair Point to insert. + * @param comparator Comparator to use for sorting the simplex vertices + * from best to worst. 
+ */ + protected void replaceWorstPoint(RealPointValuePair pointValuePair, + final Comparator comparator) { + for (int i = 0; i < dimension; i++) { + if (comparator.compare(simplex[i], pointValuePair) > 0) { + RealPointValuePair tmp = simplex[i]; + simplex[i] = pointValuePair; + pointValuePair = tmp; + } + } + simplex[dimension] = pointValuePair; + } + + /** + * Get the points of the simplex. + * + * @return all the simplex points. + */ + public RealPointValuePair[] getPoints() { + final RealPointValuePair[] copy = new RealPointValuePair[simplex.length]; + System.arraycopy(simplex, 0, copy, 0, simplex.length); + return copy; + } + + /** + * Get the simplex point stored at the requested {@code index}. + * + * @param index Location. + * @return the point at location {@code index}. + */ + public RealPointValuePair getPoint(int index) { + if (index < 0 || + index >= simplex.length) { + throw new OutOfRangeException(index, 0, simplex.length - 1); + } + return simplex[index]; + } + + /** + * Store a new point at location {@code index}. + * Note that no deep-copy of {@code point} is performed. + * + * @param index Location. + * @param point New value. + */ + protected void setPoint(int index, RealPointValuePair point) { + if (index < 0 || + index >= simplex.length) { + throw new OutOfRangeException(index, 0, simplex.length - 1); + } + simplex[index] = point; + } + + /** + * Replace all points. + * Note that no deep-copy of {@code points} is performed. + * + * @param points New Points. + */ + protected void setPoints(RealPointValuePair[] points) { + if (points.length != simplex.length) { + throw new DimensionMismatchException(points.length, simplex.length); + } + simplex = points; + } + + /** + * Create steps for a unit hypercube. + * + * @param n Dimension of the hypercube. + * @return unit steps. 
+ */ + private static double[] createUnitHypercubeSteps(int n) { + final double[] steps = new double[n]; + for (int i = 0; i < n; i++) { + steps[i] = 1; + } + return steps; + } +} diff --git a/src/main/java/org/apache/commons/math/optimization/direct/DirectSearchOptimizer.java b/src/main/java/org/apache/commons/math/optimization/direct/DirectSearchOptimizer.java index b3d183b6a..17e378140 100644 --- a/src/main/java/org/apache/commons/math/optimization/direct/DirectSearchOptimizer.java +++ b/src/main/java/org/apache/commons/math/optimization/direct/DirectSearchOptimizer.java @@ -80,6 +80,7 @@ import org.apache.commons.math.optimization.SimpleScalarValueChecker; * @see MultiDirectional * @version $Revision$ $Date$ * @since 1.2 + * @deprecated in 2.2 (to be removed in 3.0). Please use {@link SimplexOptimizer} instead. */ public abstract class DirectSearchOptimizer extends BaseAbstractScalarOptimizer { diff --git a/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java b/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java index c279ddad1..90e0f72e5 100644 --- a/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java +++ b/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectional.java @@ -29,6 +29,7 @@ import org.apache.commons.math.optimization.MultivariateRealOptimizer; * * @version $Revision$ $Date$ * @see NelderMead + * @deprecated in 2.2 (to be removed in 3.0). Please use {@link MultiDirectionalSimplex} instead. 
* @since 1.2 */ public class MultiDirectional extends DirectSearchOptimizer diff --git a/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectionalSimplex.java b/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectionalSimplex.java new file mode 100644 index 000000000..278d106e8 --- /dev/null +++ b/src/main/java/org/apache/commons/math/optimization/direct/MultiDirectionalSimplex.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.direct; + +import java.util.Comparator; + +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.optimization.RealPointValuePair; + +/** + * This class implements the multi-directional direct search method. + * + * @version $Revision$ $Date$ + * @since 3.0 + */ +public class MultiDirectionalSimplex extends AbstractSimplex { + /** Default value for {@link #khi}: {@value}. */ + private static final double DEFAULT_KHI = 2; + /** Default value for {@link #gamma}: {@value}. */ + private static final double DEFAULT_GAMMA = 0.5; + /** Expansion coefficient. 
*/ + private final double khi; + /** Contraction coefficient. */ + private final double gamma; + + /** + * Build a multi-directional simplex with default coefficients. + * The default values are 2.0 for khi and 0.5 for gamma. + * + * @param n Dimension of the simplex. + */ + public MultiDirectionalSimplex(final int n) { + this(n, DEFAULT_KHI, DEFAULT_GAMMA); + } + + /** + * Build a multi-directional simplex with specified coefficients. + * + * @param n Dimension of the simplex. See + * {@link AbstractSimplex#AbstractSimplex(int)}. + * @param khi Expansion coefficient. + * @param gamma Contraction coefficient. + */ + public MultiDirectionalSimplex(final int n, + final double khi, final double gamma) { + super(n); + + this.khi = khi; + this.gamma = gamma; + } + + /** + * Build a multi-directional simplex with default coefficients. + * The default values are 2.0 for khi and 0.5 for gamma. + * + * @param steps Steps along the canonical axes representing box edges. + * They may be negative but not zero. See {@link AbstractSimplex#AbstractSimplex(double[])}. + */ + public MultiDirectionalSimplex(final double[] steps) { + this(steps, DEFAULT_KHI, DEFAULT_GAMMA); + } + + /** + * Build a multi-directional simplex with specified coefficients. + * + * @param steps Steps along the canonical axes representing box edges. + * They may be negative but not zero. See + * {@link AbstractSimplex#AbstractSimplex(double[])}. + * @param khi Expansion coefficient. + * @param gamma Contraction coefficient. + */ + public MultiDirectionalSimplex(final double[] steps, + final double khi, final double gamma) { + super(steps); + + this.khi = khi; + this.gamma = gamma; + } + + /** + * Build a multi-directional simplex with default coefficients. + * The default values are 2.0 for khi and 0.5 for gamma. + * + * @param referenceSimplex Reference simplex. See + * {@link AbstractSimplex#AbstractSimplex(double[][])}. 
+ */ + public MultiDirectionalSimplex(final double[][] referenceSimplex) { + this(referenceSimplex, DEFAULT_KHI, DEFAULT_GAMMA); + } + + /** + * Build a multi-directional simplex with specified coefficients. + * + * @param referenceSimplex Reference simplex. See + * {@link AbstractSimplex#AbstractSimplex(double[][])}. + * @param khi Expansion coefficient. + * @param gamma Contraction coefficient. + * @throws org.apache.commons.math.exception.NotStrictlyPositiveException + * if the reference simplex does not contain at least one point. + * @throws org.apache.commons.math.exception.DimensionMismatchException + * if there is a dimension mismatch in the reference simplex. + */ + public MultiDirectionalSimplex(final double[][] referenceSimplex, + final double khi, final double gamma) { + super(referenceSimplex); + + this.khi = khi; + this.gamma = gamma; + } + + /** {@inheritDoc} */ + @Override + public void iterate(final MultivariateRealFunction evaluationFunction, + final Comparator comparator) + throws FunctionEvaluationException { + // Save the original simplex. + final RealPointValuePair[] original = getPoints(); + final RealPointValuePair best = original[0]; + + // Perform a reflection step. + final RealPointValuePair reflected = evaluateNewSimplex(evaluationFunction, + original, 1, comparator); + if (comparator.compare(reflected, best) < 0) { + // Compute the expanded simplex. + final RealPointValuePair[] reflectedSimplex = getPoints(); + final RealPointValuePair expanded = evaluateNewSimplex(evaluationFunction, + original, khi, comparator); + if (comparator.compare(reflected, expanded) <= 0) { + // Keep the reflected simplex. + setPoints(reflectedSimplex); + } + // Keep the expanded simplex. + return; + } + + // Compute the contracted simplex. + final RealPointValuePair contracted = evaluateNewSimplex(evaluationFunction, + original, gamma, comparator); + } + + /** + * Compute and evaluate a new simplex. + * + * @param evaluationFunction Evaluation function. 
+ * @param original Original simplex (to be preserved). + * @param coeff Linear coefficient. + * @param comparator Comparator to use to sort simplex vertices from best + * to poorest. + * @return the best point in the transformed simplex. + * @throws FunctionEvaluationException if the function cannot be + * evaluated at some point. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the maximal number of evaluations is exceeded. + */ + private RealPointValuePair evaluateNewSimplex(final MultivariateRealFunction evaluationFunction, + final RealPointValuePair[] original, + final double coeff, + final Comparator comparator) + throws FunctionEvaluationException { + final double[] xSmallest = original[0].getPointRef(); + // Perform a linear transformation on all the simplex points, + // except the first one. + setPoint(0, original[0]); + final int dim = getDimension(); + for (int i = 1; i < getSize(); i++) { + final double[] xOriginal = original[i].getPointRef(); + final double[] xTransformed = new double[dim]; + for (int j = 0; j < dim; j++) { + xTransformed[j] = xSmallest[j] + coeff * (xSmallest[j] - xOriginal[j]); + } + setPoint(i, new RealPointValuePair(xTransformed, Double.NaN, false)); + } + + // Evaluate the simplex. + evaluate(evaluationFunction, comparator); + + return getPoint(0); + } +} diff --git a/src/main/java/org/apache/commons/math/optimization/direct/NelderMead.java b/src/main/java/org/apache/commons/math/optimization/direct/NelderMead.java index ff2c26239..8e2f37613 100644 --- a/src/main/java/org/apache/commons/math/optimization/direct/NelderMead.java +++ b/src/main/java/org/apache/commons/math/optimization/direct/NelderMead.java @@ -29,6 +29,7 @@ import org.apache.commons.math.optimization.MultivariateRealOptimizer; * @version $Revision$ $Date$ * @see MultiDirectional * @since 1.2 + * @deprecated in 2.2 (to be removed in 3.0). Please use {@link NelderMeadSimplex} instead. 
*/ public class NelderMead extends DirectSearchOptimizer implements MultivariateRealOptimizer { diff --git a/src/main/java/org/apache/commons/math/optimization/direct/NelderMeadSimplex.java b/src/main/java/org/apache/commons/math/optimization/direct/NelderMeadSimplex.java new file mode 100644 index 000000000..0024ce763 --- /dev/null +++ b/src/main/java/org/apache/commons/math/optimization/direct/NelderMeadSimplex.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.direct; + +import java.util.Comparator; + +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.optimization.RealPointValuePair; +import org.apache.commons.math.analysis.MultivariateRealFunction; + +/** + * This class implements the Nelder-Mead simplex algorithm. + * + * @version $Revision$ $Date$ + * @since 3.0 + */ +public class NelderMeadSimplex extends AbstractSimplex { + /** Default value for {@link #rho}: {@value}. */ + private static final double DEFAULT_RHO = 1; + /** Default value for {@link #khi}: {@value}. */ + private static final double DEFAULT_KHI = 2; + /** Default value for {@link #gamma}: {@value}. 
*/ + private static final double DEFAULT_GAMMA = 0.5; + /** Default value for {@link #sigma}: {@value}. */ + private static final double DEFAULT_SIGMA = 0.5; + /** Reflection coefficient. */ + private final double rho; + /** Expansion coefficient. */ + private final double khi; + /** Contraction coefficient. */ + private final double gamma; + /** Shrinkage coefficient. */ + private final double sigma; + + /** + * Build a Nelder-Mead simplex with default coefficients. + * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5 + * for both gamma and sigma. + * + * @param n Dimension of the simplex. + */ + public NelderMeadSimplex(final int n) { + this(n, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA); + } + + /** + * Build a Nelder-Mead simplex with specified coefficients. + * + * @param n Dimension of the simplex. See + * {@link AbstractSimplex#AbstractSimplex(int)}. + * @param rho Reflection coefficient. + * @param khi Expansion coefficient. + * @param gamma Contraction coefficient. + * @param sigma Shrinkage coefficient. + */ + public NelderMeadSimplex(final int n, + final double rho, final double khi, + final double gamma, final double sigma) { + super(n); + + this.rho = rho; + this.khi = khi; + this.gamma = gamma; + this.sigma = sigma; + } + + /** + * Build a Nelder-Mead simplex with default coefficients. + * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5 + * for both gamma and sigma. + * + * @param steps Steps along the canonical axes representing box edges. + * They may be negative but not zero. See {@link AbstractSimplex#AbstractSimplex(double[])}. + */ + public NelderMeadSimplex(final double[] steps) { + this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA); + } + + /** + * Build a Nelder-Mead simplex with specified coefficients. + * + * @param steps Steps along the canonical axes representing box edges. + * They may be negative but not zero. See + * {@link AbstractSimplex#AbstractSimplex(double[])}. + * @param rho Reflection coefficient. 
+ * @param khi Expansion coefficient. + * @param gamma Contraction coefficient. + * @param sigma Shrinkage coefficient. + * @throws IllegalArgumentException if one of the steps is zero. + */ + public NelderMeadSimplex(final double[] steps, + final double rho, final double khi, + final double gamma, final double sigma) { + super(steps); + + this.rho = rho; + this.khi = khi; + this.gamma = gamma; + this.sigma = sigma; + } + + /** + * Build a Nelder-Mead simplex with default coefficients. + * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5 + * for both gamma and sigma. + * + * @param referenceSimplex Reference simplex. See + * {@link AbstractSimplex#AbstractSimplex(double[][])}. + */ + public NelderMeadSimplex(final double[][] referenceSimplex) { + this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA); + } + + /** + * Build a Nelder-Mead simplex with specified coefficients. + * + * @param referenceSimplex Reference simplex. See + * {@link AbstractSimplex#AbstractSimplex(double[][])}. + * @param rho Reflection coefficient. + * @param khi Expansion coefficient. + * @param gamma Contraction coefficient. + * @param sigma Shrinkage coefficient. + * @throws org.apache.commons.math.exception.NotStrictlyPositiveException + * if the reference simplex does not contain at least one point. + * @throws org.apache.commons.math.exception.DimensionMismatchException + * if there is a dimension mismatch in the reference simplex. + */ + public NelderMeadSimplex(final double[][] referenceSimplex, + final double rho, final double khi, + final double gamma, final double sigma) { + super(referenceSimplex); + + this.rho = rho; + this.khi = khi; + this.gamma = gamma; + this.sigma = sigma; + } + + /** {@inheritDoc} */ + @Override + public void iterate(final MultivariateRealFunction evaluationFunction, + final Comparator comparator) + throws FunctionEvaluationException { + + // The simplex has n + 1 points if dimension is n. 
+ final int n = getDimension(); + + // Interesting values. + final RealPointValuePair best = getPoint(0); + final RealPointValuePair secondBest = getPoint(n - 1); + final RealPointValuePair worst = getPoint(n); + final double[] xWorst = worst.getPointRef(); + + // Compute the centroid of the best vertices (dismissing the worst + // point at index n). + final double[] centroid = new double[n]; + for (int i = 0; i < n; i++) { + final double[] x = getPoint(i).getPointRef(); + for (int j = 0; j < n; j++) { + centroid[j] += x[j]; + } + } + final double scaling = 1.0 / n; + for (int j = 0; j < n; j++) { + centroid[j] *= scaling; + } + + // compute the reflection point + final double[] xR = new double[n]; + for (int j = 0; j < n; j++) { + xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]); + } + final RealPointValuePair reflected + = new RealPointValuePair(xR, evaluationFunction.value(xR), false); + + if (comparator.compare(best, reflected) <= 0 && + comparator.compare(reflected, secondBest) < 0) { + // Accept the reflected point. + replaceWorstPoint(reflected, comparator); + } else if (comparator.compare(reflected, best) < 0) { + // Compute the expansion point. + final double[] xE = new double[n]; + for (int j = 0; j < n; j++) { + xE[j] = centroid[j] + khi * (xR[j] - centroid[j]); + } + final RealPointValuePair expanded + = new RealPointValuePair(xE, evaluationFunction.value(xE), false); + + if (comparator.compare(expanded, reflected) < 0) { + // Accept the expansion point. + replaceWorstPoint(expanded, comparator); + } else { + // Accept the reflected point. + replaceWorstPoint(reflected, comparator); + } + } else { + if (comparator.compare(reflected, worst) < 0) { + // Perform an outside contraction. 
+ final double[] xC = new double[n]; + for (int j = 0; j < n; j++) { + xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]); + } + final RealPointValuePair outContracted + = new RealPointValuePair(xC, evaluationFunction.value(xC), false); + if (comparator.compare(outContracted, reflected) <= 0) { + // Accept the contraction point. + replaceWorstPoint(outContracted, comparator); + return; + } + } else { + // Perform an inside contraction. + final double[] xC = new double[n]; + for (int j = 0; j < n; j++) { + xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]); + } + final RealPointValuePair inContracted + = new RealPointValuePair(xC, evaluationFunction.value(xC), false); + + if (comparator.compare(inContracted, worst) < 0) { + // Accept the contraction point. + replaceWorstPoint(inContracted, comparator); + return; + } + } + + // Perform a shrink. + final double[] xSmallest = getPoint(0).getPointRef(); + for (int i = 1; i <= n; i++) { + final double[] x = getPoint(i).getPoint(); + for (int j = 0; j < n; j++) { + x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]); + } + setPoint(i, new RealPointValuePair(x, Double.NaN, false)); + } + evaluate(evaluationFunction, comparator); + } + } +} diff --git a/src/main/java/org/apache/commons/math/optimization/direct/SimplexOptimizer.java b/src/main/java/org/apache/commons/math/optimization/direct/SimplexOptimizer.java new file mode 100644 index 000000000..6fdf209bb --- /dev/null +++ b/src/main/java/org/apache/commons/math/optimization/direct/SimplexOptimizer.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.direct; + +import java.util.Comparator; + +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.exception.NullArgumentException; +import org.apache.commons.math.optimization.GoalType; +import org.apache.commons.math.optimization.ConvergenceChecker; +import org.apache.commons.math.optimization.RealPointValuePair; +import org.apache.commons.math.optimization.SimpleScalarValueChecker; + +/** + * This class implements simplex-based direct search optimization. + * + *

+ * Direct search methods only use objective function values; they do + * not need derivatives, nor do they try to compute approximations + * of the derivatives. According to a 1996 paper by Margaret H. Wright + * (Direct + * Search Methods: Once Scorned, Now Respectable), they are used + * when the computation of the derivative is either impossible (noisy + * functions, unpredictable discontinuities) or difficult (complexity, + * computation cost). In the first case, rather than an optimum, a + * not too bad point is desired. In the latter case, an + * optimum is desired but cannot be reasonably found. In all cases + * direct search methods can be useful. + *

+ *

+ * Simplex-based direct search methods are based on comparison of + * the objective function values at the vertices of a simplex (which is a + * set of n+1 points in dimension n) that is updated by the algorithm's + * steps. + *

+ *

+ * The {@link #setSimplex(AbstractSimplex) setSimplex} method must + * be called prior to calling the {@code optimize} method. + *

+ *

+ * Each call to {@link #optimize(MultivariateRealFunction,GoalType,double[]) + * optimize} will re-use the start configuration of the current simplex and + * move it such that its first vertex is at the provided start point of the + * optimization. If the {@code optimize} method is called to solve a different + * problem and the number of parameters changes, the simplex must be + * re-initialized to one with the appropriate dimensions. + *

+ *

+ * If {@link #setConvergenceChecker(ConvergenceChecker)} is not called, + * a default {@link SimpleScalarValueChecker} is used. + *

+ *

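+ * Convergence is checked by passing each vertex of the previous + * simplex, together with the vertex at the same index in the current + * simplex, to the convergence checker; the search stops only when every + * such pair is reported as converged. + *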

+ * + * @see AbstractSimplex + * @version $Revision$ $Date$ + * @since 3.0 + */ +public class SimplexOptimizer + extends BaseAbstractScalarOptimizer { + /** Simplex. */ + private AbstractSimplex simplex; + + /** + * Default constructor. + */ + public SimplexOptimizer() { + setConvergenceChecker(new SimpleScalarValueChecker()); + } + + /** + * @param rel Relative threshold. + * @param abs Absolute threshold. + */ + public SimplexOptimizer(double rel, double abs) { + setConvergenceChecker(new SimpleScalarValueChecker(rel, abs)); + } + + /** + * Set the simplex algorithm. + * + * @param simplex Simplex. + */ + public void setSimplex(AbstractSimplex simplex) { + this.simplex = simplex; + } + + /** {@inheritDoc} */ + @Override + protected RealPointValuePair doOptimize() + throws FunctionEvaluationException { + if (simplex == null) { + throw new NullArgumentException(); + } + + // Indirect call to "computeObjectiveValue" in order to update the + // evaluations counter. + final MultivariateRealFunction evalFunc + = new MultivariateRealFunction() { + public double value(double[] point) + throws FunctionEvaluationException { + return computeObjectiveValue(point); + } + }; + + final boolean isMinim = getGoalType() == GoalType.MINIMIZE; + final Comparator comparator + = new Comparator() { + public int compare(final RealPointValuePair o1, + final RealPointValuePair o2) { + final double v1 = o1.getValue(); + final double v2 = o2.getValue(); + return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1); + } + }; + + // Initialize search. 
+ simplex.build(getStartPoint()); + simplex.evaluate(evalFunc, comparator); + + RealPointValuePair[] previous = null; + int iteration = 0; + final ConvergenceChecker checker = getConvergenceChecker(); + while (true) { + if (iteration > 0) { + boolean converged = true; + for (int i = 0; i < simplex.getSize(); i++) { + converged &= checker.converged(iteration, previous[i], simplex.getPoint(i)); + } + if (converged) { + // We have found an optimum. + return simplex.getPoint(0); + } + } + + // We still need to search. + previous = simplex.getPoints(); + simplex.iterate(evalFunc, comparator); + ++iteration; + } + } +} diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index 7c9143623..be8ab0ee9 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -52,6 +52,13 @@ The type attribute can be add,update,fix,remove. If the output is not quite correct, check for invisible trailing spaces! --> + + Class "DirectSearchOptimizer" (and subclasses "NelderMead" + and "MultiDirectional") was refactored into new classes: + "SimplexOptimizer" and "AbstractSimplex" (and subclasses + "NelderMeadSimplex" and "MultiDirectionalSimplex"). The old + classes were deprecated and removed. + Replaced old exceptions. @@ -86,43 +93,50 @@ The type attribute can be add,update,fix,remove. - Fixed k-means++ to add several strategies to deal with empty clusters that may appear - during iterations + Fixed k-means++ to add several strategies to deal with empty clusters that + may appear during iterations. Improved Percentile performance by using a selection algorithm instead of a complete sort, and by allowing caching data array and pivots when several - different percentiles are desired + different percentiles are desired. - Fixed an error preventing zero length vectors to be built by some constructors + Fixed an error preventing zero length vectors to be built by some constructors. 
- Fixed an error preventing ODE solvers to be restarted after they have been stopped by a discrete event + Fixed an error preventing ODE solvers to be restarted after they have + been stopped by a discrete event. - Added new random number generators from the Well Equidistributed Long-period Linear (WELL). + Added new random number generators from the Well Equidistributed + Long-period Linear (WELL). - Made intercept / no intercept configurable in multiple regression classes. By default, regression - models are estimated with an intercept term. When the "noIntercept" property is set to - true, regression models are estimated without intercepts. + Made intercept / no intercept configurable in multiple regression + classes. By default, regression models are estimated with an intercept + term. When the "noIntercept" property is set to true, regression models + are estimated without intercepts. - Fixed lost cause in MathRuntimeException.createInternalError. Note that the message is still the default - message for internal errors asking to report a bug to commons-math JIRA tracker. In order to retrieve - the message from the root cause, one has to get the cause itself by getCause(). + Fixed lost cause in MathRuntimeException.createInternalError. Note that + the message is still the default message for internal errors asking to + report a bug to commons-math JIRA tracker. In order to retrieve the + message from the root cause, one has to get the cause itself by getCause(). - Modified multiple regression newSample methods to ensure that by default in all cases, - regression models are estimated with intercept terms. 
Prior to the fix for this issue, - newXSampleData(double[][]), newSampleData(double[], double[][]) and - newSampleData(double[], double[][], double[][]) all required columns of "1's" to be inserted - into the x[][] arrays to create a model with an intercept term; while newSampleData(double[], int, int) - created a model including an intercept term without requiring the unitary column. All methods have - been changed to eliminate the need for users to add unitary columns to specify regression models. - Users of OLSMultipleLinearRegression or GLSMultipleLinearRegression versions 2.0 or 2.1 should either - verify that their code either does not use the first set of data loading methods above or set the noIntercept + Modified multiple regression newSample methods to ensure that by default + in all cases, regression models are estimated with intercept terms. + Prior to the fix for this issue, newXSampleData(double[][]), + newSampleData(double[], double[][]) and newSampleData(double[], double[][], double[][]) + all required columns of "1's" to be inserted into the x[][] arrays to + create a model with an intercept term; while newSampleData(double[], int, int) + created a model including an intercept term without requiring the + unitary column. All methods have been changed to eliminate the need for + users to add unitary columns to specify regression models. + Users of OLSMultipleLinearRegression or GLSMultipleLinearRegression + versions 2.0 or 2.1 should either verify that their code either does + not use the first set of data loading methods above or set the noIntercept property on estimated models to get the previous behavior. 
diff --git a/src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerMultiDirectionalTest.java b/src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerMultiDirectionalTest.java new file mode 100644 index 000000000..55a3a372b --- /dev/null +++ b/src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerMultiDirectionalTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.commons.math.optimization.direct; + +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.optimization.GoalType; +import org.apache.commons.math.optimization.RealPointValuePair; +import org.apache.commons.math.optimization.SimpleScalarValueChecker; +import org.apache.commons.math.util.FastMath; +import org.junit.Assert; +import org.junit.Test; + +public class SimplexOptimizerMultiDirectionalTest { + @Test + public void testMinimizeMaximize() throws FunctionEvaluationException { + // the following function has 4 local extrema: + final double xM = -3.841947088256863675365; + final double yM = -1.391745200270734924416; + final double xP = 0.2286682237349059125691; + final double yP = -yM; + final double valueXmYm = 0.2373295333134216789769; // local maximum + final double valueXmYp = -valueXmYm; // local minimum + final double valueXpYm = -0.7290400707055187115322; // global minimum + final double valueXpYp = -valueXpYm; // global maximum + MultivariateRealFunction fourExtrema = new MultivariateRealFunction() { + private static final long serialVersionUID = -7039124064449091152L; + public double value(double[] variables) throws FunctionEvaluationException { + final double x = variables[0]; + final double y = variables[1]; + return ((x == 0) || (y == 0)) ? 
0 : + (FastMath.atan(x) * FastMath.atan(x + 2) * FastMath.atan(y) * FastMath.atan(y) / (x * y)); + } + }; + + SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30); + optimizer.setMaxEvaluations(200); + optimizer.setSimplex(new MultiDirectionalSimplex(new double[] { 0.2, 0.2 })); + RealPointValuePair optimum; + + // minimization + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3, 0 }); + Assert.assertEquals(xM, optimum.getPoint()[0], 4e-6); + Assert.assertEquals(yP, optimum.getPoint()[1], 3e-6); + Assert.assertEquals(valueXmYp, optimum.getValue(), 8e-13); + Assert.assertTrue(optimizer.getEvaluations() > 120); + Assert.assertTrue(optimizer.getEvaluations() < 150); + + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { 1, 0 }); + Assert.assertEquals(xP, optimum.getPoint()[0], 2e-8); + Assert.assertEquals(yM, optimum.getPoint()[1], 3e-6); + Assert.assertEquals(valueXpYm, optimum.getValue(), 2e-12); + Assert.assertTrue(optimizer.getEvaluations() > 120); + Assert.assertTrue(optimizer.getEvaluations() < 150); + + // maximization + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 }); + Assert.assertEquals(xM, optimum.getPoint()[0], 7e-7); + Assert.assertEquals(yM, optimum.getPoint()[1], 3e-7); + Assert.assertEquals(valueXmYm, optimum.getValue(), 2e-14); + Assert.assertTrue(optimizer.getEvaluations() > 120); + Assert.assertTrue(optimizer.getEvaluations() < 150); + + optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1e-15, 1e-30)); + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { 1, 0 }); + Assert.assertEquals(xP, optimum.getPoint()[0], 2e-8); + Assert.assertEquals(yP, optimum.getPoint()[1], 3e-6); + Assert.assertEquals(valueXpYp, optimum.getValue(), 2e-12); + Assert.assertTrue(optimizer.getEvaluations() > 180); + Assert.assertTrue(optimizer.getEvaluations() < 220); + } + + @Test + public void testRosenbrock() throws 
FunctionEvaluationException { + MultivariateRealFunction rosenbrock = + new MultivariateRealFunction() { + private static final long serialVersionUID = -9044950469615237490L; + public double value(double[] x) throws FunctionEvaluationException { + ++count; + double a = x[1] - x[0] * x[0]; + double b = 1.0 - x[0]; + return 100 * a * a + b * b; + } + }; + + count = 0; + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3); + optimizer.setMaxEvaluations(100); + optimizer.setSimplex(new MultiDirectionalSimplex(new double[][] { + { -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 } + })); + RealPointValuePair optimum = + optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1 }); + + Assert.assertEquals(count, optimizer.getEvaluations()); + Assert.assertTrue(optimizer.getEvaluations() > 50); + Assert.assertTrue(optimizer.getEvaluations() < 100); + Assert.assertTrue(optimum.getValue() > 1e-2); + } + + @Test + public void testPowell() throws FunctionEvaluationException { + MultivariateRealFunction powell = + new MultivariateRealFunction() { + private static final long serialVersionUID = -832162886102041840L; + public double value(double[] x) throws FunctionEvaluationException { + ++count; + double a = x[0] + 10 * x[1]; + double b = x[2] - x[3]; + double c = x[1] - 2 * x[2]; + double d = x[0] - x[3]; + return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d; + } + }; + + count = 0; + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3); + optimizer.setMaxEvaluations(1000); + optimizer.setSimplex(new MultiDirectionalSimplex(4)); + RealPointValuePair optimum = + optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3, -1, 0, 1 }); + Assert.assertEquals(count, optimizer.getEvaluations()); + Assert.assertTrue(optimizer.getEvaluations() > 800); + Assert.assertTrue(optimizer.getEvaluations() < 900); + Assert.assertTrue(optimum.getValue() > 1e-2); + } + + @Test + public void testMath283() throws FunctionEvaluationException { + // fails because 
MultiDirectional.iterateSimplex is looping forever + // the while(true) should be replaced with a convergence check + SimplexOptimizer optimizer = new SimplexOptimizer(); + optimizer.setMaxEvaluations(1000); + optimizer.setSimplex(new MultiDirectionalSimplex(2)); + final Gaussian2D function = new Gaussian2D(0, 0, 1); + RealPointValuePair estimate = optimizer.optimize(function, + GoalType.MAXIMIZE, function.getMaximumPosition()); + final double EPSILON = 1e-5; + final double expectedMaximum = function.getMaximum(); + final double actualMaximum = estimate.getValue(); + Assert.assertEquals(expectedMaximum, actualMaximum, EPSILON); + + final double[] expectedPosition = function.getMaximumPosition(); + final double[] actualPosition = estimate.getPoint(); + Assert.assertEquals(expectedPosition[0], actualPosition[0], EPSILON ); + Assert.assertEquals(expectedPosition[1], actualPosition[1], EPSILON ); + } + + private static class Gaussian2D implements MultivariateRealFunction { + private final double[] maximumPosition; + private final double std; + + public Gaussian2D(double xOpt, double yOpt, double std) { + maximumPosition = new double[] { xOpt, yOpt }; + this.std = std; + } + + public double getMaximum() { + return value(maximumPosition); + } + + public double[] getMaximumPosition() { + return maximumPosition.clone(); + } + + public double value(double[] point) { + final double x = point[0], y = point[1]; + final double twoS2 = 2.0 * std * std; + return 1.0 / (twoS2 * FastMath.PI) * FastMath.exp(-(x * x + y * y) / twoS2); + } + } + + private int count; +} diff --git a/src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerNelderMeadTest.java b/src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerNelderMeadTest.java new file mode 100644 index 000000000..8610b5b61 --- /dev/null +++ b/src/test/java/org/apache/commons/math/optimization/direct/SimplexOptimizerNelderMeadTest.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.direct; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.exception.TooManyEvaluationsException; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.analysis.MultivariateVectorialFunction; +import org.apache.commons.math.linear.Array2DRowRealMatrix; +import org.apache.commons.math.linear.RealMatrix; +import org.apache.commons.math.optimization.GoalType; +import org.apache.commons.math.optimization.LeastSquaresConverter; +import org.apache.commons.math.optimization.RealPointValuePair; +import org.apache.commons.math.optimization.SimpleScalarValueChecker; +import org.junit.Test; +import org.junit.Ignore; + +public class SimplexOptimizerNelderMeadTest { + @Test + public void testMinimizeMaximize() + throws FunctionEvaluationException { + + // the following function has 4 local extrema: + final double xM = -3.841947088256863675365; + final double yM = 
-1.391745200270734924416; + final double xP = 0.2286682237349059125691; + final double yP = -yM; + final double valueXmYm = 0.2373295333134216789769; // local maximum + final double valueXmYp = -valueXmYm; // local minimum + final double valueXpYm = -0.7290400707055187115322; // global minimum + final double valueXpYp = -valueXpYm; // global maximum + MultivariateRealFunction fourExtrema = new MultivariateRealFunction() { + private static final long serialVersionUID = -7039124064449091152L; + public double value(double[] variables) throws FunctionEvaluationException { + final double x = variables[0]; + final double y = variables[1]; + return (x == 0 || y == 0) ? 0 : + (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y)); + } + }; + + SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30); + optimizer.setMaxEvaluations(100); + optimizer.setSimplex(new NelderMeadSimplex(new double[] { 0.2, 0.2 })); + RealPointValuePair optimum; + + // minimization + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3, 0 }); + assertEquals(xM, optimum.getPoint()[0], 2e-7); + assertEquals(yP, optimum.getPoint()[1], 2e-5); + assertEquals(valueXmYp, optimum.getValue(), 6e-12); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 90); + + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { 1, 0 }); + assertEquals(xP, optimum.getPoint()[0], 5e-6); + assertEquals(yM, optimum.getPoint()[1], 6e-6); + assertEquals(valueXpYm, optimum.getValue(), 1e-11); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 90); + + // maximization + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3, 0 }); + assertEquals(xM, optimum.getPoint()[0], 1e-5); + assertEquals(yM, optimum.getPoint()[1], 3e-6); + assertEquals(valueXmYm, optimum.getValue(), 3e-12); + assertTrue(optimizer.getEvaluations() > 60); + 
assertTrue(optimizer.getEvaluations() < 90); + + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { 1, 0 }); + assertEquals(xP, optimum.getPoint()[0], 4e-6); + assertEquals(yP, optimum.getPoint()[1], 5e-6); + assertEquals(valueXpYp, optimum.getValue(), 7e-12); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 90); + } + + @Test + public void testRosenbrock() + throws FunctionEvaluationException { + + Rosenbrock rosenbrock = new Rosenbrock(); + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3); + optimizer.setMaxEvaluations(100); + optimizer.setSimplex(new NelderMeadSimplex(new double[][] { + { -1.2, 1 }, { 0.9, 1.2 } , { 3.5, -2.3 } + })); + RealPointValuePair optimum = + optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1 }); + + assertEquals(rosenbrock.getCount(), optimizer.getEvaluations()); + assertTrue(optimizer.getEvaluations() > 40); + assertTrue(optimizer.getEvaluations() < 50); + assertTrue(optimum.getValue() < 8e-4); + } + + @Test + public void testPowell() + throws FunctionEvaluationException { + + Powell powell = new Powell(); + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3); + optimizer.setMaxEvaluations(200); + optimizer.setSimplex(new NelderMeadSimplex(4)); + RealPointValuePair optimum = + optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3, -1, 0, 1 }); + assertEquals(powell.getCount(), optimizer.getEvaluations()); + assertTrue(optimizer.getEvaluations() > 110); + assertTrue(optimizer.getEvaluations() < 130); + assertTrue(optimum.getValue() < 2e-3); + } + + @Test + public void testLeastSquares1() + throws FunctionEvaluationException { + + final RealMatrix factors = + new Array2DRowRealMatrix(new double[][] { + { 1, 0 }, + { 0, 1 } + }, false); + LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { + public double[] value(double[] variables) { + return factors.operate(variables); + } + }, new 
double[] { 2.0, -3.0 }); + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6); + optimizer.setMaxEvaluations(200); + optimizer.setSimplex(new NelderMeadSimplex(2)); + RealPointValuePair optimum = + optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10, 10 }); + assertEquals( 2, optimum.getPointRef()[0], 3e-5); + assertEquals(-3, optimum.getPointRef()[1], 4e-4); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 80); + assertTrue(optimum.getValue() < 1.0e-6); + } + + @Test + public void testLeastSquares2() + throws FunctionEvaluationException { + + final RealMatrix factors = + new Array2DRowRealMatrix(new double[][] { + { 1, 0 }, + { 0, 1 } + }, false); + LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { + public double[] value(double[] variables) { + return factors.operate(variables); + } + }, new double[] { 2, -3 }, new double[] { 10, 0.1 }); + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6); + optimizer.setMaxEvaluations(200); + optimizer.setSimplex(new NelderMeadSimplex(2)); + RealPointValuePair optimum = + optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10, 10 }); + assertEquals( 2, optimum.getPointRef()[0], 5e-5); + assertEquals(-3, optimum.getPointRef()[1], 8e-4); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 80); + assertTrue(optimum.getValue() < 1e-6); + } + + @Test + public void testLeastSquares3() + throws FunctionEvaluationException { + + final RealMatrix factors = + new Array2DRowRealMatrix(new double[][] { + { 1, 0 }, + { 0, 1 } + }, false); + LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { + public double[] value(double[] variables) { + return factors.operate(variables); + } + }, new double[] { 2, -3 }, new Array2DRowRealMatrix(new double [][] { + { 1, 1.2 }, { 1.2, 2 } + })); + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6); + 
optimizer.setMaxEvaluations(200); + optimizer.setSimplex(new NelderMeadSimplex(2)); + RealPointValuePair optimum = + optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10, 10 }); + assertEquals( 2, optimum.getPointRef()[0], 2e-3); + assertEquals(-3, optimum.getPointRef()[1], 8e-4); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 80); + assertTrue(optimum.getValue() < 1e-6); + } + + @Test(expected = TooManyEvaluationsException.class) + public void testMaxIterations() throws FunctionEvaluationException { + Powell powell = new Powell(); + SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3); + optimizer.setMaxEvaluations(20); + optimizer.setSimplex(new NelderMeadSimplex(4)); + optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3, -1, 0, 1 }); + } + + private static class Rosenbrock implements MultivariateRealFunction { + private int count; + + public Rosenbrock() { + count = 0; + } + + public double value(double[] x) throws FunctionEvaluationException { + ++count; + double a = x[1] - x[0] * x[0]; + double b = 1.0 - x[0]; + return 100 * a * a + b * b; + } + + public int getCount() { + return count; + } + } + + private static class Powell implements MultivariateRealFunction { + private int count; + + public Powell() { + count = 0; + } + + public double value(double[] x) throws FunctionEvaluationException { + ++count; + double a = x[0] + 10 * x[1]; + double b = x[2] - x[3]; + double c = x[1] - 2 * x[2]; + double d = x[0] - x[3]; + return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d; + } + + public int getCount() { + return count; + } + } +}