MATH-396.

Introducing "PowellOptimizer" which uses "BracketFinder".
"PowellOptimizerTest" uses "SumSincFunction" which uses "SincFunction".


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@979460 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2010-07-26 21:47:32 +00:00
parent e0723d16df
commit c917b3b3fa
7 changed files with 880 additions and 0 deletions

View File

@ -5,6 +5,14 @@ This product includes software developed by
The Apache Software Foundation (http://www.apache.org/). The Apache Software Foundation (http://www.apache.org/).
=============================================================================== ===============================================================================
The BracketFinder (package org.apache.commons.math.optimization.univariate)
and PowellOptimizer (package org.apache.commons.math.optimization.general)
classes are based on the Python code in module "optimize.py" (version 0.5)
developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/)
Copyright © 2003-2009 SciPy Developers.
===============================================================================
The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, The LinearConstraint, LinearObjectiveFunction, LinearOptimizer,
RelationShip, SimplexSolver and SimplexTableau classes in package RelationShip, SimplexSolver and SimplexTableau classes in package
org.apache.commons.math.optimization.linear include software developed by org.apache.commons.math.optimization.linear include software developed by

View File

@ -0,0 +1,250 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.general;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.MultivariateRealFunction;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.optimization.univariate.AbstractUnivariateRealOptimizer;
import org.apache.commons.math.optimization.univariate.BrentOptimizer;
import org.apache.commons.math.optimization.univariate.BracketFinder;
import org.apache.commons.math.optimization.GoalType;
import org.apache.commons.math.optimization.MultivariateRealOptimizer;
import org.apache.commons.math.optimization.OptimizationException;
import org.apache.commons.math.optimization.RealPointValuePair;
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
/**
 * Powell algorithm.
 * This code is translated and adapted from the Python version of this
 * algorithm (as implemented in module {@code optimize.py} v0.5 of
 * <em>SciPy</em>).
 * <p>
 * Starting from the canonical basis, each iteration performs a line search
 * along every direction of the current direction set, then optionally
 * replaces the direction of largest change with the overall displacement
 * of the iteration.
 * </p>
 *
 * @version $Revision$ $Date$
 * @since 2.2
 */
public class PowellOptimizer
    extends AbstractScalarOptimizer {
    /**
     * Default line search tolerance ({@value}).
     */
    public static final double DEFAULT_LINE_SEARCH_TOLERANCE = 1e-7;
    /**
     * Line search.
     */
    private final LineSearch line;

    /**
     * Constructor using the default line search tolerance (see the
     * {@link #PowellOptimizer(double) other constructor}).
     */
    public PowellOptimizer() {
        this(DEFAULT_LINE_SEARCH_TOLERANCE);
    }

    /**
     * @param lineSearchTolerance Relative error tolerance for the line search
     * algorithm ({@link BrentOptimizer}).
     */
    public PowellOptimizer(double lineSearchTolerance) {
        line = new LineSearch(lineSearchTolerance);
    }

    /** {@inheritDoc} */
    @Override
    protected RealPointValuePair doOptimize()
        throws FunctionEvaluationException,
               OptimizationException {
        final GoalType goal = getGoalType();
        final double[] guess = getStartPoint();
        final int n = guess.length;

        // Initial direction set: the canonical basis.
        final double[][] direc = new double[n][n];
        for (int i = 0; i < n; i++) {
            direc[i][i] = 1;
        }

        // "x" is updated in place by the line searches: work on a private
        // copy so that the array returned by "getStartPoint" is never modified.
        double[] x = guess.clone();
        double fVal = computeObjectiveValue(x);
        double[] x1 = x.clone();
        while (true) {
            incrementIterationsCounter();

            final double fX = fVal;
            double fX2 = 0;
            double delta = 0;
            int bigInd = 0;
            double alphaMin = 0;

            for (int i = 0; i < n; i++) {
                // Search along a copy: "setNewPointAndDirection" scales its
                // direction argument in place, which must not alter (and
                // possibly zero out) the stored direction set.
                final double[] searchDir = direc[i].clone();

                fX2 = fVal;

                line.search(x, searchDir);
                fVal = line.getValueAtOptimum();
                alphaMin = line.getOptimum();
                setNewPointAndDirection(x, searchDir, alphaMin);

                // Remember the direction along which the change was largest.
                if ((fX2 - fVal) > delta) {
                    delta = fX2 - fVal;
                    bigInd = i;
                }
            }

            final RealPointValuePair previous = new RealPointValuePair(x1, fX);
            final RealPointValuePair current = new RealPointValuePair(x, fVal);
            if (getConvergenceChecker().converged(getIterations(), previous, current)) {
                // Return whichever of the last two points is best for the goal.
                if (goal == GoalType.MINIMIZE) {
                    return (fVal < fX ? current : previous);
                } else {
                    return (fVal > fX ? current : previous);
                }
            }

            // Overall displacement of this iteration (stored in a fresh array
            // so that no row of the direction set is overwritten) and the
            // point extrapolated along it.
            final double[] newDir = new double[n];
            final double[] x2 = new double[n];
            for (int i = 0; i < n; i++) {
                newDir[i] = x[i] - x1[i];
                x2[i] = 2 * x[i] - x1[i];
            }

            x1 = x.clone();
            fX2 = computeObjectiveValue(x2);

            if (fX > fX2) {
                double t = 2 * (fX + fX2 - 2 * fVal);
                double temp = fX - fVal - delta;
                t *= temp * temp;
                temp = fX - fX2;
                t -= delta * temp * temp;

                if (t < 0.0) {
                    line.search(x, newDir);
                    fVal = line.getValueAtOptimum();
                    alphaMin = line.getOptimum();
                    setNewPointAndDirection(x, newDir, alphaMin);

                    // Drop the direction of largest change and append the
                    // new (scaled) direction at the end of the set.
                    final int lastInd = n - 1;
                    direc[bigInd] = direc[lastInd];
                    direc[lastInd] = newDir;
                }
            }
        }
    }

    /**
     * Compute a new point (in the original space) and a new direction
     * vector, resulting from the line search.
     * The parameters {@code p} and {@code d} will be changed in-place:
     * {@code d} is multiplied by {@code optimum} and the result is added
     * to {@code p}.
     *
     * @param p Point used in the line search.
     * @param d Direction used in the line search.
     * @param optimum Optimum found by the line search.
     */
    private void setNewPointAndDirection(double[] p,
                                         double[] d,
                                         double optimum) {
        final int n = p.length;
        for (int i = 0; i < n; i++) {
            d[i] *= optimum;
            p[i] += d[i];
        }
    }

    /**
     * Class for finding the optimum of the objective function along a given
     * direction.
     * The interval passed to the univariate optimizer is obtained from a
     * {@link BracketFinder} started with the abscissae 0 and 1.
     */
    private class LineSearch {
        /**
         * Optimizer.
         */
        private final AbstractUnivariateRealOptimizer optim = new BrentOptimizer();
        /**
         * Automatic bracketing.
         */
        private final BracketFinder bracket = new BracketFinder();
        /**
         * Value of the optimum.
         */
        private double optimum = Double.NaN;
        /**
         * Value of the objective function at the optimum.
         */
        private double valueAtOptimum = Double.NaN;

        /**
         * @param tolerance Relative tolerance.
         */
        public LineSearch(double tolerance) {
            optim.setRelativeAccuracy(tolerance);
            optim.setAbsoluteAccuracy(Math.ulp(1d));
        }

        /**
         * Find the optimum (according to the goal type of the enclosing
         * optimizer) of the function {@code f(p + alpha * d)}.
         * The result is available through {@link #getOptimum()} and
         * {@link #getValueAtOptimum()}.
         *
         * @param p Starting point.
         * @param d Search direction.
         * @throws OptimizationException wrapping any exception raised while
         * bracketing, optimizing or evaluating the function.
         */
        public void search(final double[] p,
                           final double[] d)
            throws OptimizationException {
            try {
                final int n = p.length;
                // Restriction of the objective function to the search line.
                final UnivariateRealFunction f = new UnivariateRealFunction() {
                        public double value(double alpha)
                            throws FunctionEvaluationException {
                            final double[] x = new double[n];
                            for (int i = 0; i < n; i++) {
                                x[i] = p[i] + alpha * d[i];
                            }
                            return computeObjectiveValue(x);
                        }
                    };

                final GoalType goal = getGoalType();
                bracket.search(f, goal, 0, 1);
                optimum = optim.optimize(f, goal,
                                         bracket.getLo(),
                                         bracket.getHi(),
                                         bracket.getMid());
                valueAtOptimum = f.value(optimum);
            } catch (Exception e) {
                throw new OptimizationException(e);
            }
        }

        /**
         * @return the optimum.
         */
        public double getOptimum() {
            return optimum;
        }

        /**
         * @return the value of the function at the optimum.
         */
        public double getValueAtOptimum() {
            return valueAtOptimum;
        }
    }
}

View File

@ -0,0 +1,261 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.univariate;
import org.apache.commons.math.exception.NotStrictlyPositiveException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.MaxIterationsExceededException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.optimization.GoalType;
/**
 * Provide an interval that brackets a local optimum of a function.
 * This code is based on a Python implementation (from <em>SciPy</em>,
 * module {@code optimize.py} v0.5).
 */
public class BracketFinder {
    /**
     * Smallest magnitude allowed for the denominator of the parabolic
     * interpolation step.
     */
    private static final double EPS_MIN = 1e-21;
    /**
     * Golden section.
     */
    private static final double GOLD = 1.618034;
    /**
     * Factor for expanding the interval.
     */
    private final double growLimit;
    /**
     * Maximum number of iterations.
     */
    private final int maxIterations;
    /**
     * Number of iterations.
     */
    private int iterations;
    /**
     * Number of function evaluations.
     */
    private int evaluations;
    /**
     * Lower bound of the bracket.
     */
    private double lo;
    /**
     * Higher bound of the bracket.
     */
    private double hi;
    /**
     * Point inside the bracket.
     */
    private double mid;
    /**
     * Function value at {@link #lo}.
     */
    private double fLo;
    /**
     * Function value at {@link #hi}.
     */
    private double fHi;
    /**
     * Function value at {@link #mid}.
     */
    private double fMid;

    /**
     * Constructor with default values {@code 100, 50} (see the
     * {@link #BracketFinder(double,int) other constructor}).
     */
    public BracketFinder() {
        this(100, 50);
    }

    /**
     * Create a bracketing interval finder.
     *
     * @param growLimit Expanding factor.
     * @param maxIterations Maximum number of iterations allowed for finding
     * a bracketing interval.
     * @throws NotStrictlyPositiveException if either argument is not
     * strictly positive.
     */
    public BracketFinder(double growLimit,
                         int maxIterations) {
        if (growLimit <= 0) {
            throw new NotStrictlyPositiveException(growLimit);
        }
        if (maxIterations <= 0) {
            throw new NotStrictlyPositiveException(maxIterations);
        }

        this.growLimit = growLimit;
        this.maxIterations = maxIterations;
    }

    /**
     * Search new points that bracket a local optimum of the function.
     * On return, the bracket is available through {@link #getLo()},
     * {@link #getMid()} and {@link #getHi()}, with
     * {@code getLo() <= getHi()}.
     *
     * @param func Function whose optimum should be bracketed.
     * @param goal {@link GoalType Goal type}.
     * @param xA Initial point.
     * @param xB Initial point.
     * @throws MaxIterationsExceededException if the maximum iteration count
     * is exceeded.
     * @throws FunctionEvaluationException if an error occurs evaluating
     * the function.
     */
    public void search(UnivariateRealFunction func,
                       GoalType goal,
                       double xA,
                       double xB)
        throws MaxIterationsExceededException,
               FunctionEvaluationException {
        reset();
        final boolean isMinim = (goal == GoalType.MINIMIZE);

        double fA = eval(func, xA);
        double fB = eval(func, xB);
        // Make sure that the search proceeds "downhill" (for minimization)
        // or "uphill" (for maximization) when going from A to B.
        if (isMinim ?
            fA < fB :
            fA > fB) {
            double tmp = xA;
            xA = xB;
            xB = tmp;

            tmp = fA;
            fA = fB;
            fB = tmp;
        }

        double xC = xB + GOLD * (xB - xA);
        double fC = eval(func, xC);

        while (isMinim ? fC < fB : fC > fB) {
            if (++iterations > maxIterations) {
                throw new MaxIterationsExceededException(maxIterations);
            }

            // Abscissa of the extremum of the parabola through A, B and C.
            double tmp1 = (xB - xA) * (fB - fC);
            double tmp2 = (xB - xC) * (fB - fA);

            double val = tmp2 - tmp1;
            double denom = Math.abs(val) < EPS_MIN ? 2 * EPS_MIN : 2 * val;

            double w = xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom;
            // Farthest point that may be probed in this iteration.
            double wLim = xB + growLimit * (xC - xB);

            double fW;
            if ((w - xC) * (xB - w) > 0) {
                // "w" lies between B and C.
                fW = eval(func, w);
                if (isMinim ?
                    fW < fC :
                    fW > fC) {
                    xA = xB;
                    xB = w;
                    fA = fB;
                    fB = fW;
                    break;
                } else if (isMinim ?
                           fW > fB :
                           fW < fB) {
                    xC = w;
                    fC = fW;
                    break;
                }
                w = xC + GOLD * (xC - xB);
                fW = eval(func, w);
            } else if ((w - wLim) * (wLim - xC) >= 0) {
                // "w" is beyond the limit: clip it.
                w = wLim;
                fW = eval(func, w);
            } else if ((w - wLim) * (xC - w) > 0) {
                // "w" lies between C and the limit.
                fW = eval(func, w);
                if (isMinim ?
                    fW < fC :
                    fW > fC) {
                    xB = xC;
                    xC = w;
                    w = xC + GOLD * (xC - xB);
                    fB = fC;
                    fC = fW;
                    fW = eval(func, w);
                }
            } else {
                w = xC + GOLD * (xC - xB);
                fW = eval(func, w);
            }

            xA = xB;
            xB = xC;
            xC = w;
            fA = fB;
            fB = fC;
            fC = fW;
        }

        lo = xA;
        mid = xB;
        hi = xC;
        fLo = fA;
        fMid = fB;
        fHi = fC;

        // The search may have proceeded towards decreasing abscissae:
        // reorder so that "lo" and "hi" are the lower and higher bounds.
        if (lo > hi) {
            double tmp = lo;
            lo = hi;
            hi = tmp;

            tmp = fLo;
            fLo = fHi;
            fHi = tmp;
        }
    }

    /**
     * @return the number of iterations.
     */
    public int getIterations() {
        return iterations;
    }

    /**
     * @return the number of evaluations.
     */
    public int getEvaluations() {
        return evaluations;
    }

    /**
     * @return the lower bound of the bracket.
     * @see #getFLo()
     */
    public double getLo() {
        return lo;
    }

    /**
     * @return the function value at the lower bound of the bracket.
     * @see #getLo()
     */
    public double getFLo() {
        return fLo;
    }

    /**
     * @return the higher bound of the bracket.
     * @see #getFHi()
     */
    public double getHi() {
        return hi;
    }

    /**
     * @return the function value at the higher bound of the bracket.
     * @see #getHi()
     */
    public double getFHi() {
        return fHi;
    }

    /**
     * @return a point in the middle of the bracket.
     * @see #getFMid()
     */
    public double getMid() {
        return mid;
    }

    /**
     * @return the function value at the point in the middle of the bracket.
     * @see #getMid()
     */
    public double getFMid() {
        return fMid;
    }

    /**
     * Evaluate the function and count the evaluation.
     *
     * @param f Function.
     * @param x Argument.
     * @return {@code f(x)}
     * @throws FunctionEvaluationException if an error occurs evaluating
     * the function.
     */
    private double eval(UnivariateRealFunction f,
                        double x)
        throws FunctionEvaluationException {
        ++evaluations;
        return f.value(x);
    }

    /**
     * Reset internal state (iteration and evaluation counters).
     */
    private void reset() {
        iterations = 0;
        evaluations = 0;
    }
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.analysis;
import org.apache.commons.math.FunctionEvaluationException;
/**
 * Auxiliary class for testing optimizers.
 * Implements the (unnormalized) cardinal sine {@code sin(x) / x}, extended
 * with the value {@code 1} around {@code x = 0}.
 *
 * @version $Revision$ $Date$
 */
public class SincFunction implements DifferentiableUnivariateRealFunction {
    /**
     * Arguments whose absolute value is below this threshold are treated
     * as zero.
     */
    private static final double EPS = 1e-12;
    /**
     * Below this threshold the derivative is computed from its Taylor
     * series: the closed-form expression suffers from cancellation between
     * {@code x cos(x)} and {@code sin(x)} for small arguments.
     */
    private static final double SERIES_THRESHOLD = 6e-3;

    /**
     * @param x Argument.
     * @return the value of this function at point {@code x}.
     */
    public double value(double x) {
        return (Math.abs(x) < EPS ?
                1 :
                Math.sin(x) / x);
    }

    /**
     * {@inheritDoc}
     */
    public UnivariateRealFunction derivative() {
        return new UnivariateRealFunction() {
                public double value(double x) {
                    final double absX = Math.abs(x);
                    if (absX < EPS) {
                        return 0;
                    }
                    if (absX <= SERIES_THRESHOLD) {
                        // -x/3 + x^3/30 - x^5/840; the first neglected term
                        // (x^7/45360) is below double precision here.
                        final double x2 = x * x;
                        return x * (x2 * (1d / 30 - x2 / 840) - 1d / 3);
                    }
                    return (x * Math.cos(x) - Math.sin(x)) / (x * x);
                }
            };
    }
}

View File

@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.analysis;
import org.apache.commons.math.FunctionEvaluationException;
/**
 * Auxiliary class for testing optimizers.
 * Computes {@code factor * sum_i sinc(x_i)} where {@code sinc} is an
 * instance of {@link SincFunction}.
 *
 * @version $Revision$ $Date$
 */
public class SumSincFunction implements DifferentiableMultivariateRealFunction {
    /** Univariate function applied to each coordinate. */
    private static final DifferentiableUnivariateRealFunction sinc = new SincFunction();
    /** Derivative of {@link #sinc}. */
    private static final UnivariateRealFunction sincDeriv = sinc.derivative();

    /**
     * Factor that will multiply each term of the sum.
     */
    private final double factor;

    /**
     * @param factor Factor that will multiply each term of the sum.
     */
    public SumSincFunction(double factor) {
        this.factor = factor;
    }

    /**
     * @param point Argument.
     * @return the value of this function at point {@code point}.
     * @throws FunctionEvaluationException if the underlying univariate
     * function cannot be evaluated.
     */
    public double value(double[] point) throws FunctionEvaluationException {
        double sum = 0;
        for (int i = 0, max = point.length; i < max; i++) {
            final double x = point[i];
            final double v = sinc.value(x);
            sum += v;
        }
        return factor * sum;
    }

    /**
     * {@inheritDoc}
     */
    public MultivariateRealFunction partialDerivative(final int k) {
        return new MultivariateRealFunction() {
                public double value(double[] point) throws FunctionEvaluationException {
                    // Only the k-th term of the sum depends on point[k]; the
                    // multiplying factor applies to the derivative, too (as
                    // in "gradient").
                    return factor * sincDeriv.value(point[k]);
                }
            };
    }

    /**
     * {@inheritDoc}
     */
    public MultivariateVectorialFunction gradient() {
        return new MultivariateVectorialFunction() {
                public double[] value(double[] point)
                    throws FunctionEvaluationException {
                    final int n = point.length;
                    final double[] r = new double[n];
                    for (int i = 0; i < n; i++) {
                        final double x = point[i];
                        r[i] = factor * sincDeriv.value(x);
                    }
                    return r;
                }
            };
    }
}

View File

@ -0,0 +1,153 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.general;
import java.util.Arrays;
import org.apache.commons.math.MathException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.MultivariateRealFunction;
import org.apache.commons.math.analysis.SumSincFunction;
import org.apache.commons.math.optimization.GoalType;
import org.apache.commons.math.optimization.MultivariateRealOptimizer;
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
import org.apache.commons.math.optimization.RealPointValuePair;
import org.junit.Assert;
import org.junit.Test;
/**
 * Test for {@link PowellOptimizer}.
 */
public class PowellOptimizerTest {

    /**
     * Minimizes {@code -sinc(x)} in one dimension, starting at the minimum
     * and at a distance of 4 from it.
     */
    @Test
    public void testSumSinc() throws MathException {
        final MultivariateRealFunction func = new SumSincFunction(-1);

        int dim = 1;
        final double[] minPoint = new double[dim];
        for (int i = 0; i < dim; i++) {
            minPoint[i] = 0;
        }

        double[] init = new double[dim];

        // Initial is minimum.
        for (int i = 0; i < dim; i++) {
            init[i] = minPoint[i];
        }
        doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8);

        // Initial is far from minimum.
        for (int i = 0; i < dim; i++) {
            init[i] = minPoint[i] + 4;
        }
        doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8);
    }

    /**
     * Minimizes a two-dimensional paraboloid whose minimum is at (1, 1).
     */
    @Test
    public void testQuadratic() throws MathException {
        final MultivariateRealFunction func = new MultivariateRealFunction() {
                public double value(double[] x)
                    throws FunctionEvaluationException {
                    final double a = x[0] - 1;
                    final double b = x[1] - 1;
                    return a * a + b * b + 1;
                }
            };

        int dim = 2;
        final double[] minPoint = new double[dim];
        for (int i = 0; i < dim; i++) {
            minPoint[i] = 1;
        }

        double[] init = new double[dim];

        // Initial is minimum.
        for (int i = 0; i < dim; i++) {
            init[i] = minPoint[i];
        }
        doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8);

        // Initial is far from minimum.
        for (int i = 0; i < dim; i++) {
            init[i] = minPoint[i] - 20;
        }
        doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-15, 1e-8);
    }

    /**
     * Maximizes a two-dimensional paraboloid whose maximum is at (1, 1).
     */
    @Test
    public void testMaximizeQuadratic() throws MathException {
        final MultivariateRealFunction func = new MultivariateRealFunction() {
                public double value(double[] x)
                    throws FunctionEvaluationException {
                    final double a = x[0] - 1;
                    final double b = x[1] - 1;
                    return -a * a - b * b + 1;
                }
            };

        int dim = 2;
        final double[] maxPoint = new double[dim];
        for (int i = 0; i < dim; i++) {
            maxPoint[i] = 1;
        }

        double[] init = new double[dim];

        // Initial is maximum.
        for (int i = 0; i < dim; i++) {
            init[i] = maxPoint[i];
        }
        doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-15, 1e-8);

        // Initial is far from maximum.
        for (int i = 0; i < dim; i++) {
            init[i] = maxPoint[i] - 20;
        }
        doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-15, 1e-8);
    }

    /**
     * Runs a {@link PowellOptimizer} with default line search tolerance and
     * checks that each coordinate of the point found is within
     * {@code pointTol} of the expected optimum.
     *
     * @param func Function to optimize.
     * @param optimum Expected optimum.
     * @param init Starting point.
     * @param goal Minimization or maximization.
     * @param objTol Tolerance (relative error on the objective function).
     * @param pointTol Tolerance on the position of the optimum.
     */
    private void doTest(MultivariateRealFunction func,
                        double[] optimum,
                        double[] init,
                        GoalType goal,
                        double objTol,
                        double pointTol)
        throws MathException {
        final MultivariateRealOptimizer optim = new PowellOptimizer();
        // "objTol" is passed as the first argument and -1 as the second one.
        optim.setConvergenceChecker(new SimpleScalarValueChecker(objTol, -1));

        final RealPointValuePair result = optim.optimize(func, goal, init);
        final double[] found = result.getPoint();

        for (int i = 0, dim = optimum.length; i < dim; i++) {
            Assert.assertEquals(optimum[i], found[i], pointTol);
        }
    }
}

View File

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.univariate;
import org.apache.commons.math.MathException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.optimization.GoalType;
import org.junit.Assert;
import org.junit.Test;
public class BracketFinderTest {

    /**
     * Brackets a minimum of {@code (x - 1)(x + 2)(x + 3)} (held constant
     * for {@code x < -2}) starting from the points -2 and -1.
     */
    @Test
    public void testCubicMin() throws MathException {
        final UnivariateRealFunction cubic = new UnivariateRealFunction() {
                public double value(double x)
                    throws FunctionEvaluationException {
                    // Left of -2 the function keeps its value at -2.
                    final double t = (x < -2) ? -2 : x;
                    return (t - 1) * (t + 2) * (t + 3);
                }
            };

        final BracketFinder finder = new BracketFinder();
        finder.search(cubic, GoalType.MINIMIZE, -2, -1);

        // Comparing with results computed in Python.
        assertBracket(finder, -2, -1, 0.61803399999999997);
    }

    /**
     * Brackets a maximum of the opposite of the function used in
     * {@link #testCubicMin()}, starting from the same points.
     */
    @Test
    public void testCubicMax() throws MathException {
        final UnivariateRealFunction negatedCubic = new UnivariateRealFunction() {
                public double value(double x)
                    throws FunctionEvaluationException {
                    // Left of -2 the function keeps its value at -2.
                    final double t = (x < -2) ? -2 : x;
                    return -(t - 1) * (t + 2) * (t + 3);
                }
            };

        final BracketFinder finder = new BracketFinder();
        finder.search(negatedCubic, GoalType.MAXIMIZE, -2, -1);

        assertBracket(finder, -2, -1, 0.61803399999999997);
    }

    /**
     * Checks the three points of the bracket held by {@code finder}.
     *
     * @param finder Bracket finder, after a search.
     * @param expectedLo Expected value of {@code getLo()}.
     * @param expectedMid Expected value of {@code getMid()}.
     * @param expectedHi Expected value of {@code getHi()}.
     */
    private static void assertBracket(BracketFinder finder,
                                      double expectedLo,
                                      double expectedMid,
                                      double expectedHi) {
        final double eps = 1e-15;
        Assert.assertEquals(expectedLo, finder.getLo(), eps);
        Assert.assertEquals(expectedMid, finder.getMid(), eps);
        Assert.assertEquals(expectedHi, finder.getHi(), eps);
    }
}