diff --git a/NOTICE.txt b/NOTICE.txt index fc05a09c8..d1f0fc0a3 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -5,6 +5,14 @@ This product includes software developed by The Apache Software Foundation (http://www.apache.org/). =============================================================================== + +The BracketFinder (package org.apache.commons.math.optimization.univariate) +and PowellOptimizer (package org.apache.commons.math.optimization.general) +classes are based on the Python code in module "optimize.py" (version 0.5) +developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) +Copyright © 2003-2009 SciPy Developers. +=============================================================================== + The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, RelationShip, SimplexSolver and SimplexTableau classes in package org.apache.commons.math.optimization.linear include software developed by diff --git a/src/main/java/org/apache/commons/math/optimization/general/PowellOptimizer.java b/src/main/java/org/apache/commons/math/optimization/general/PowellOptimizer.java new file mode 100644 index 000000000..9ef5d598b --- /dev/null +++ b/src/main/java/org/apache/commons/math/optimization/general/PowellOptimizer.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.general; + +import org.apache.commons.math.ConvergenceException; +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.analysis.UnivariateRealFunction; +import org.apache.commons.math.optimization.univariate.AbstractUnivariateRealOptimizer; +import org.apache.commons.math.optimization.univariate.BrentOptimizer; +import org.apache.commons.math.optimization.univariate.BracketFinder; +import org.apache.commons.math.optimization.GoalType; +import org.apache.commons.math.optimization.MultivariateRealOptimizer; +import org.apache.commons.math.optimization.OptimizationException; +import org.apache.commons.math.optimization.RealPointValuePair; +import org.apache.commons.math.optimization.SimpleScalarValueChecker; + +/** + * Powell algorithm. + * This code is translated and adapted from the Python version of this + * algorithm (as implemented in module {@code optimize.py} v0.5 of + * SciPy). + * + * @version $Revision$ $Date$ + * @since 2.2 + */ +public class PowellOptimizer + extends AbstractScalarOptimizer { + /** + * Defautl line search tolerance ({@value}). + */ + public static final double DEFAULT_LINE_SEARCH_TOLERANCE = 1e-7; + /** + * Line search. + */ + private final LineSearch line; + + /** + * Constructor using the default line search tolerance (see the + * {@link #PowellOptimizer(double) other constructor}). + */ + public PowellOptimizer() { + this(DEFAULT_LINE_SEARCH_TOLERANCE); + } + + /** + * @param lineSearchTolerance Relative error tolerance for the line search + * algorithm ({@link BrentOptimizer}). + */ + public PowellOptimizer(double lineSearchTolerance) { + line = new LineSearch(lineSearchTolerance); + } + + /** {@inheritDoc} */ + @Override + protected RealPointValuePair doOptimize() + throws FunctionEvaluationException, + OptimizationException { + final GoalType goal = getGoalType(); + final double[] guess = getStartPoint(); + final int n = guess.length; + + final double[][] direc = new double[n][n]; + for (int i = 0; i < n; i++) { + direc[i][i] = 1; + } + + double[] x = guess; + double fVal = computeObjectiveValue(x); + double[] x1 = x.clone(); + while (true) { + incrementIterationsCounter(); + + double fX = fVal; + double fX2 = 0; + double delta = 0; + int bigInd = 0; + double alphaMin = 0; + + double[] direc1 = new double[n]; + for (int i = 0; i < n; i++) { + direc1 = direc[i]; + + fX2 = fVal; + + line.search(x, direc1); + fVal = line.getValueAtOptimum(); + alphaMin = line.getOptimum(); + setNewPointAndDirection(x, direc1, alphaMin); + + if ((fX2 - fVal) > delta) { + delta = fX2 - fVal; + bigInd = i; + } + } + + final RealPointValuePair previous = new RealPointValuePair(x1, fX); + final RealPointValuePair current = new RealPointValuePair(x, fVal); + if (getConvergenceChecker().converged(getIterations(), previous, current)) { + switch (goal) { + case MINIMIZE: + return (fVal < fX ? current : previous); + case MAXIMIZE: + return (fVal > fX ? current : previous); + } + } + + double[] x2 = new double[n]; + for (int i = 0; i < n; i++) { + direc1[i] = x[i] - x1[i]; + x2[i] = 2 * x[i] - x1[i]; + } + + x1 = x.clone(); + fX2 = computeObjectiveValue(x2); + + if (fX > fX2) { + double t = 2 * (fX + fX2 - 2 * fVal); + double temp = fX - fVal - delta; + t *= temp * temp; + temp = fX - fX2; + t -= delta * temp * temp; + + if (t < 0.0) { + line.search(x, direc1); + fVal = line.getValueAtOptimum(); + alphaMin = line.getOptimum(); + setNewPointAndDirection(x, direc1, alphaMin); + + final int lastInd = n - 1; + direc[bigInd] = direc[lastInd]; + direc[lastInd] = direc1; + } + } + } + } + + /** + * Compute a new point (in the original space) and a new direction + * vector, resulting from the line search. + * The parameters {@code p} and {@code d} will be changed in-place. + * + * @param p Point used in the line search. + * @param d Direction used in the line search. + * @param optimum Optimum found by the line search. + */ + private void setNewPointAndDirection(double[] p, + double[] d, + double optimum) { + final int n = p.length; + for (int i = 0; i < n; i++) { + d[i] *= optimum; + p[i] += d[i]; + } + } + + /** + * Class for finding the minimum of the objective function along a given + * direction. + */ + private class LineSearch { + /** + * Optimizer. + */ + private final AbstractUnivariateRealOptimizer optim = new BrentOptimizer(); + /** + * Automatic bracketing. + */ + private final BracketFinder bracket = new BracketFinder(); + /** + * Value of the optimum. + */ + private double optimum = Double.NaN; + /** + * Value of the objective function at the optimum. + */ + private double valueAtOptimum = Double.NaN; + + /** + * @param tolerance Relative tolerance. + */ + public LineSearch(double tolerance) { + optim.setRelativeAccuracy(tolerance); + optim.setAbsoluteAccuracy(Math.ulp(1d)); + } + + /** + * Find the minimum of the function {@code f(p + alpha * d)}. + * + * @param p Starting point. + * @param d Search direction. + */ + public void search(final double[] p, + final double[] d) + throws OptimizationException { + try { + final int n = p.length; + final UnivariateRealFunction f = new UnivariateRealFunction() { + public double value(double alpha) + throws FunctionEvaluationException { + + final double[] x = new double[n]; + for (int i = 0; i < n; i++) { + x[i] = p[i] + alpha * d[i]; + } + return computeObjectiveValue(x); + } + }; + + final GoalType goal = getGoalType(); + bracket.search(f, goal, 0, 1); + optimum = optim.optimize(f, goal, + bracket.getLo(), + bracket.getHi(), + bracket.getMid()); + valueAtOptimum = f.value(optimum); + } catch (Exception e) { + throw new OptimizationException(e); + } + } + + /** + * @return the optimum. + */ + public double getOptimum() { + return optimum; + } + /** + * @return the value of the function at the optimum. + */ + public double getValueAtOptimum() { + return valueAtOptimum; + } + } +} diff --git a/src/main/java/org/apache/commons/math/optimization/univariate/BracketFinder.java b/src/main/java/org/apache/commons/math/optimization/univariate/BracketFinder.java new file mode 100644 index 000000000..f53f36f10 --- /dev/null +++ b/src/main/java/org/apache/commons/math/optimization/univariate/BracketFinder.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.optimization.univariate; + +import org.apache.commons.math.exception.NotStrictlyPositiveException; +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.MaxIterationsExceededException; +import org.apache.commons.math.analysis.UnivariateRealFunction; +import org.apache.commons.math.optimization.GoalType; + +/** + * Provide an interval that brackets a local optimum of a function. + * This code is based on a Python implementation (from SciPy, + * module {@code optimize.py} v0.5). + */ +public class BracketFinder { + private static final double EPS_MIN = 1e-21; + /** + * Golden section. + */ + private static final double GOLD = 1.618034; + /** + * Factor for expanding the interval. + */ + private final double growLimit; + /** + * Maximum number of iterations. + */ + private final int maxIterations; + /** + * Number of iterations. + */ + private int iterations; + /** + * Number of function evaluations. + */ + private int evaluations; + /** + * Lower bound of the bracket. + */ + private double lo; + /** + * Higher bound of the bracket. + */ + private double hi; + /** + * Point inside the bracket. + */ + private double mid; + /** + * Function value at {@link #lo}. + */ + private double fLo; + /** + * Function value at {@link #hi}. + */ + private double fHi; + /** + * Function value at {@link #mid}. + */ + private double fMid; + + /** + * Constructor with default values {@code 100, 50} (see the + * {@link #BracketFinder(double,int) other constructor}). + */ + public BracketFinder() { + this(100, 50); + } + + /** + * Create a bracketing interval finder. + * + * @param growLimit Expanding factor. + * @param maxIterations Maximum number of iterations allowed for finding + * a bracketing interval. + */ + public BracketFinder(double growLimit, + int maxIterations) { + if (growLimit <= 0) { + throw new NotStrictlyPositiveException(growLimit); + } + if (maxIterations <= 0) { + throw new NotStrictlyPositiveException(maxIterations); + } + + this.growLimit = growLimit; + this.maxIterations = maxIterations; + } + + /** + * Search new points that bracket a local optimum of the function. + * + * @param func Function whose optimum should be bracketted. + * @param goal {@link GoalType Goal type}. + * @param xA Initial point. + * @param xB Initial point. + */ + public void search(UnivariateRealFunction func, + GoalType goal, + double xA, + double xB) + throws MaxIterationsExceededException, + FunctionEvaluationException { + reset(); + final boolean isMinim = (goal == GoalType.MINIMIZE); + + double fA = eval(func, xA); + double fB = eval(func, xB); + if (isMinim ? + fA < fB : + fA > fB) { + double tmp = xA; + xA = xB; + xB = tmp; + + tmp = fA; + fA = fB; + fB = tmp; + } + + double xC = xB + GOLD * (xB - xA); + double fC = eval(func, xC); + + while (isMinim ? fC < fB : fC > fB) { + if (++iterations > maxIterations) { + throw new MaxIterationsExceededException(maxIterations); + } + + double tmp1 = (xB - xA) * (fB - fC); + double tmp2 = (xB - xC) * (fB - fA); + + double val = tmp2 - tmp1; + double denom = Math.abs(val) < EPS_MIN ? 2 * EPS_MIN : 2 * val; + + double w = xB - ((xB - xC) * tmp2 - (xB -xA) * tmp1) / denom; + double wLim = xB + growLimit * (xC - xB); + + double fW; + if ((w - xC) * (xB - w) > 0) { + fW = eval(func, w); + if (isMinim ? + fW < fC : + fW > fC) { + xA = xB; + xB = w; + fA = fB; + fB = fW; + break; + } else if (isMinim ? + fW > fB : + fW < fB) { + xC = w; + fC = fW; + break; + } + w = xC + GOLD * (xC - xB); + fW = eval(func, w); + } else if ((w - wLim) * (wLim - xC) >= 0) { + w = wLim; + fW = eval(func, w); + } else if ((w - wLim) * (xC - w) > 0) { + fW = eval(func, w); + if (isMinim ? + fW < fC : + fW > fC) { + xB = xC; + xC = w; + w = xC + GOLD * (xC -xB); + fB = fC; + fC =fW; + fW = eval(func, w); + } + } else { + w = xC + GOLD * (xC - xB); + fW = eval(func, w); + } + + xA = xB; + xB = xC; + xC = w; + fA = fB; + fB = fC; + fC = fW; + } + + lo = xA; + mid = xB; + hi = xC; + fLo = fA; + fMid = fB; + fHi = fC; + } + + /** + * @return the number of iterations. + */ + public int getIterations() { + return iterations; + } + /** + * @return the number of evalutations. + */ + public int getEvaluations() { + return evaluations; + } + + /** + * @return the lower bound of the bracket. + */ + public double getLo() { + return lo; + } + /** + * @return the higher bound of the bracket. + */ + public double getHi() { + return hi; + } + /** + * @return a point in the middle of the bracket. + */ + public double getMid() { + return mid; + } + + /** + * @param func Function. + * @param x Argument. + * @return {@code f(x)} + */ + private double eval(UnivariateRealFunction f, + double x) + throws FunctionEvaluationException { + + ++evaluations; + return f.value(x); + } + + /** + * Reset internal state. + */ + private void reset() { + iterations = 0; + evaluations = 0; + } +} diff --git a/src/test/java/org/apache/commons/math/analysis/SincFunction.java b/src/test/java/org/apache/commons/math/analysis/SincFunction.java new file mode 100644 index 000000000..5b717652c --- /dev/null +++ b/src/test/java/org/apache/commons/math/analysis/SincFunction.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis; + +import org.apache.commons.math.FunctionEvaluationException; + +/** + * Auxiliary class for testing optimizers. + * + * @version $Revision$ $Date$ + */ +public class SincFunction implements DifferentiableUnivariateRealFunction { + private static final double EPS = 1e-12; + + /** + * @param x Argument. + * @return the value of this function at point {@code x}. + */ + public double value(double x) { + return (Math.abs(x) < EPS ? + 1 : + Math.sin(x) / x); + } + + /** + * {@inheritDoc} + */ + public UnivariateRealFunction derivative() { + return new UnivariateRealFunction() { + public double value(double x) { + return (Math.abs(x) < EPS ? + 0 : + (x * Math.cos(x) - Math.sin(x)) / (x * x)); + } + }; + } +} diff --git a/src/test/java/org/apache/commons/math/analysis/SumSincFunction.java b/src/test/java/org/apache/commons/math/analysis/SumSincFunction.java new file mode 100644 index 000000000..707507c41 --- /dev/null +++ b/src/test/java/org/apache/commons/math/analysis/SumSincFunction.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis; + +import org.apache.commons.math.FunctionEvaluationException; + +/** + * Auxiliary class for testing optimizers. + * + * @version $Revision$ $Date$ + */ +public class SumSincFunction implements DifferentiableMultivariateRealFunction { + private static final DifferentiableUnivariateRealFunction sinc = new SincFunction(); + private static final UnivariateRealFunction sincDeriv = sinc.derivative(); + + /** + * Factor that will multiply each term of the sum. + */ + private final double factor; + + /** + * @param factor Factor that will multiply each term of the sum. + */ + public SumSincFunction(double factor) { + this.factor = factor; + } + + /** + * @param point Argument. + * @return the value of this function at point {@code x}. + */ + public double value(double[] point) throws FunctionEvaluationException { + double sum = 0; + for (int i = 0, max = point.length; i < max; i++) { + final double x = point[i]; + final double v = sinc.value(x); + sum += v; + } + return factor * sum; + } + + /** + * {@inheritDoc} + */ + public MultivariateRealFunction partialDerivative(final int k) { + return new MultivariateRealFunction() { + public double value(double[] point) throws FunctionEvaluationException { + return sincDeriv.value(point[k]); + } + }; + } + + /** + * {@inheritDoc} + */ + public MultivariateVectorialFunction gradient() { + return new MultivariateVectorialFunction() { + public double[] value(double[] point) + throws FunctionEvaluationException { + final int n = point.length; + final double[] r = new double[n]; + for (int i = 0; i < n; i++) { + final double x = point[i]; + r[i] = factor * sincDeriv.value(x); + } + return r; + } + }; + } +} diff --git a/src/test/java/org/apache/commons/math/optimization/general/PowellOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/general/PowellOptimizerTest.java new file mode 100644 index 000000000..c43efd81c --- /dev/null +++ b/src/test/java/org/apache/commons/math/optimization/general/PowellOptimizerTest.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.optimization.general; + +import java.util.Arrays; + +import org.apache.commons.math.MathException; +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.analysis.SumSincFunction; +import org.apache.commons.math.optimization.GoalType; +import org.apache.commons.math.optimization.MultivariateRealOptimizer; +import org.apache.commons.math.optimization.SimpleScalarValueChecker; +import org.apache.commons.math.optimization.RealPointValuePair; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Test for {@link PowellOptimizer}. + */ +public class PowellOptimizerTest { + + @Test + public void testSumSinc() throws MathException { + final MultivariateRealFunction func = new SumSincFunction(-1); + + int dim = 1; + final double[] minPoint = new double[dim]; + for (int i = 0; i < dim; i++) { + minPoint[i] = 0; + } + + double[] init = new double[dim]; + + // Initial is minimum. + for (int i = 0; i < dim; i++) { + init[i] = minPoint[i]; + } + doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8); + + // Initial is far from minimum. + for (int i = 0; i < dim; i++) { + init[i] = minPoint[i] + 4; + } + doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8); + } + + @Test + public void testQuadratic() throws MathException { + final MultivariateRealFunction func = new MultivariateRealFunction() { + public double value(double[] x) + throws FunctionEvaluationException { + final double a = x[0] - 1; + final double b = x[1] - 1; + return a * a + b * b + 1; + } + }; + + int dim = 2; + final double[] minPoint = new double[dim]; + for (int i = 0; i < dim; i++) { + minPoint[i] = 1; + } + + double[] init = new double[dim]; + + // Initial is minimum. + for (int i = 0; i < dim; i++) { + init[i] = minPoint[i]; + } + doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8); + + // Initial is far from minimum. + for (int i = 0; i < dim; i++) { + init[i] = minPoint[i] - 20; + } + doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8); + } + + @Test + public void testMaximizeQuadratic() throws MathException { + final MultivariateRealFunction func = new MultivariateRealFunction() { + public double value(double[] x) + throws FunctionEvaluationException { + final double a = x[0] - 1; + final double b = x[1] - 1; + return -a * a - b * b + 1; + } + }; + + int dim = 2; + final double[] maxPoint = new double[dim]; + for (int i = 0; i < dim; i++) { + maxPoint[i] = 1; + } + + double[] init = new double[dim]; + + // Initial is minimum. + for (int i = 0; i < dim; i++) { + init[i] = maxPoint[i]; + } + doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-15, 1e-8); + + // Initial is far from minimum. + for (int i = 0; i < dim; i++) { + init[i] = maxPoint[i] - 20; + } + doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-15, 1e-8); + } + + /** + * @param func Function to optimize. + * @param optimum Expected optimum. + * @param init Starting point. + * @param goal Minimization or maximization. + * @param objTol Tolerance (relative error on the objective function). + * @param pointTol Tolerance on the position of the optimum. + */ + private void doTest(MultivariateRealFunction func, + double[] optimum, + double[] init, + GoalType goal, + double objTol, + double pointTol) + throws MathException { + final MultivariateRealOptimizer optim = new PowellOptimizer(); + final double relTol = 1e-10; + optim.setConvergenceChecker(new SimpleScalarValueChecker(objTol, -1)); + + final RealPointValuePair result = optim.optimize(func, goal, init); + final double[] found = result.getPoint(); + + for (int i = 0, dim = optimum.length; i < dim; i++) { + Assert.assertEquals(optimum[i], found[i], pointTol); + } + } +} diff --git a/src/test/java/org/apache/commons/math/optimization/univariate/BracketFinderTest.java b/src/test/java/org/apache/commons/math/optimization/univariate/BracketFinderTest.java new file mode 100644 index 000000000..984c60bd6 --- /dev/null +++ b/src/test/java/org/apache/commons/math/optimization/univariate/BracketFinderTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.optimization.univariate; + +import org.apache.commons.math.MathException; +import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.UnivariateRealFunction; +import org.apache.commons.math.optimization.GoalType; + +import org.junit.Assert; +import org.junit.Test; + +public class BracketFinderTest { + + @Test + public void testCubicMin() throws MathException { + final BracketFinder bFind = new BracketFinder(); + final UnivariateRealFunction func = new UnivariateRealFunction() { + public double value(double x) + throws FunctionEvaluationException { + if (x < -2) { + return value(-2); + } + else { + return (x - 1) * (x + 2) * (x + 3); + } + } + }; + + bFind.search(func, GoalType.MINIMIZE, -2 , -1); + final double tol = 1e-15; + // Comparing with results computed in Python. + Assert.assertEquals(-2, bFind.getLo(), tol); + Assert.assertEquals(-1, bFind.getMid(), tol); + Assert.assertEquals(0.61803399999999997, bFind.getHi(), tol); + } + + @Test + public void testCubicMax() throws MathException { + final BracketFinder bFind = new BracketFinder(); + final UnivariateRealFunction func = new UnivariateRealFunction() { + public double value(double x) + throws FunctionEvaluationException { + if (x < -2) { + return value(-2); + } + else { + return -(x - 1) * (x + 2) * (x + 3); + } + } + }; + + bFind.search(func, GoalType.MAXIMIZE, -2 , -1); + final double tol = 1e-15; + Assert.assertEquals(-2, bFind.getLo(), tol); + Assert.assertEquals(-1, bFind.getMid(), tol); + Assert.assertEquals(0.61803399999999997, bFind.getHi(), tol); + } +}