MATH-396.

Introducing "PowellOptimizer" which uses "BracketFinder".
"PowellOptimizerTest" uses "SumSincFunction" which uses "SincFunction".


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@979460 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2010-07-26 21:47:32 +00:00
parent e0723d16df
commit c917b3b3fa
7 changed files with 880 additions and 0 deletions

View File

@ -5,6 +5,14 @@ This product includes software developed by
The Apache Software Foundation (http://www.apache.org/).
===============================================================================
The BracketFinder (package org.apache.commons.math.optimization.univariate)
and PowellOptimizer (package org.apache.commons.math.optimization.general)
classes are based on the Python code in module "optimize.py" (version 0.5)
developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/)
Copyright © 2003-2009 SciPy Developers.
===============================================================================
The LinearConstraint, LinearObjectiveFunction, LinearOptimizer,
RelationShip, SimplexSolver and SimplexTableau classes in package
org.apache.commons.math.optimization.linear include software developed by

View File

@ -0,0 +1,250 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.general;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.MultivariateRealFunction;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.optimization.univariate.AbstractUnivariateRealOptimizer;
import org.apache.commons.math.optimization.univariate.BrentOptimizer;
import org.apache.commons.math.optimization.univariate.BracketFinder;
import org.apache.commons.math.optimization.GoalType;
import org.apache.commons.math.optimization.MultivariateRealOptimizer;
import org.apache.commons.math.optimization.OptimizationException;
import org.apache.commons.math.optimization.RealPointValuePair;
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
/**
* Powell algorithm.
* This code is translated and adapted from the Python version of this
* algorithm (as implemented in module {@code optimize.py} v0.5 of
* <em>SciPy</em>).
*
* @version $Revision$ $Date$
* @since 2.2
*/
public class PowellOptimizer
    extends AbstractScalarOptimizer {
    /**
     * Default line search tolerance ({@value}).
     */
    public static final double DEFAULT_LINE_SEARCH_TOLERANCE = 1e-7;
    /**
     * Line search.
     */
    private final LineSearch line;

    /**
     * Constructor using the default line search tolerance (see the
     * {@link #PowellOptimizer(double) other constructor}).
     */
    public PowellOptimizer() {
        this(DEFAULT_LINE_SEARCH_TOLERANCE);
    }

    /**
     * @param lineSearchTolerance Relative error tolerance for the line search
     * algorithm ({@link BrentOptimizer}).
     */
    public PowellOptimizer(double lineSearchTolerance) {
        line = new LineSearch(lineSearchTolerance);
    }

    /** {@inheritDoc} */
    @Override
    protected RealPointValuePair doOptimize()
        throws FunctionEvaluationException,
               OptimizationException {
        final GoalType goal = getGoalType();
        final double[] guess = getStartPoint();
        final int n = guess.length;

        // The initial direction set is the canonical basis.
        final double[][] direc = new double[n][n];
        for (int i = 0; i < n; i++) {
            direc[i][i] = 1;
        }

        // "x" is updated in place by every line search, hence work on a
        // private copy of the start point.
        final double[] x = guess.clone();
        double fVal = computeObjectiveValue(x);
        double[] x1 = x.clone();
        while (true) {
            incrementIterationsCounter();

            final double fX = fVal;
            double delta = 0;
            int bigInd = 0;

            for (int i = 0; i < n; i++) {
                // The line search scales the direction it is given by the
                // step length. Pass a copy so that the stored direction set
                // is left untouched (a zero step would otherwise turn
                // direc[i] into the null vector).
                final double[] searchDir = direc[i].clone();
                final double fBefore = fVal;

                line.search(x, searchDir);
                fVal = line.getValueAtOptimum();
                setNewPointAndDirection(x, searchDir, line.getOptimum());

                // Remember the direction that produced the largest decrease.
                if ((fBefore - fVal) > delta) {
                    delta = fBefore - fVal;
                    bigInd = i;
                }
            }

            final RealPointValuePair previous = new RealPointValuePair(x1, fX);
            final RealPointValuePair current = new RealPointValuePair(x, fVal);
            if (getConvergenceChecker().converged(getIterations(), previous, current)) {
                switch (goal) {
                case MINIMIZE:
                    return (fVal < fX ? current : previous);
                case MAXIMIZE:
                    return (fVal > fX ? current : previous);
                }
            }

            // Overall displacement of this iteration and extrapolated point.
            // A fresh array is used for the displacement so that no row of
            // "direc" is overwritten.
            final double[] newDir = new double[n];
            final double[] x2 = new double[n];
            for (int i = 0; i < n; i++) {
                newDir[i] = x[i] - x1[i];
                x2[i] = 2 * x[i] - x1[i];
            }

            x1 = x.clone();
            final double fX2 = computeObjectiveValue(x2);

            if (fX > fX2) {
                double t = 2 * (fX + fX2 - 2 * fVal);
                double temp = fX - fVal - delta;
                t *= temp * temp;
                temp = fX - fX2;
                t -= delta * temp * temp;

                if (t < 0.0) {
                    line.search(x, newDir);
                    fVal = line.getValueAtOptimum();
                    setNewPointAndDirection(x, newDir, line.getOptimum());

                    // Drop the direction of largest decrease and append the
                    // new (scaled) displacement direction.
                    final int lastInd = n - 1;
                    direc[bigInd] = direc[lastInd];
                    direc[lastInd] = newDir;
                }
            }
        }
    }

    /**
     * Compute a new point (in the original space) and a new direction
     * vector, resulting from the line search.
     * The parameters {@code p} and {@code d} will be changed in-place:
     * {@code d} is multiplied by {@code optimum} and the result is added
     * to {@code p}.
     *
     * @param p Point used in the line search.
     * @param d Direction used in the line search.
     * @param optimum Optimum found by the line search.
     */
    private void setNewPointAndDirection(double[] p,
                                         double[] d,
                                         double optimum) {
        final int n = p.length;
        for (int i = 0; i < n; i++) {
            d[i] *= optimum;
            p[i] += d[i];
        }
    }

    /**
     * Class for finding the optimum of the objective function along a given
     * direction.
     */
    private class LineSearch {
        /**
         * Optimizer.
         */
        private final AbstractUnivariateRealOptimizer optim = new BrentOptimizer();
        /**
         * Automatic bracketing.
         */
        private final BracketFinder bracket = new BracketFinder();
        /**
         * Value of the optimum.
         */
        private double optimum = Double.NaN;
        /**
         * Value of the objective function at the optimum.
         */
        private double valueAtOptimum = Double.NaN;

        /**
         * @param tolerance Relative tolerance.
         */
        public LineSearch(double tolerance) {
            optim.setRelativeAccuracy(tolerance);
            optim.setAbsoluteAccuracy(Math.ulp(1d));
        }

        /**
         * Find the optimum of the function {@code f(p + alpha * d)}.
         * The result is available through {@link #getOptimum()} and
         * {@link #getValueAtOptimum()}.
         *
         * @param p Starting point.
         * @param d Search direction.
         * @throws OptimizationException if the bracketing, the univariate
         * optimization or a function evaluation fails; the original
         * exception is kept as the cause.
         */
        public void search(final double[] p,
                           final double[] d)
            throws OptimizationException {
            try {
                final int n = p.length;
                final UnivariateRealFunction f = new UnivariateRealFunction() {
                        public double value(double alpha)
                            throws FunctionEvaluationException {
                            final double[] x = new double[n];
                            for (int i = 0; i < n; i++) {
                                x[i] = p[i] + alpha * d[i];
                            }
                            return computeObjectiveValue(x);
                        }
                    };

                final GoalType goal = getGoalType();
                bracket.search(f, goal, 0, 1);
                optimum = optim.optimize(f, goal,
                                         bracket.getLo(),
                                         bracket.getHi(),
                                         bracket.getMid());
                valueAtOptimum = f.value(optimum);
            } catch (Exception e) {
                throw new OptimizationException(e);
            }
        }

        /**
         * @return the optimum.
         */
        public double getOptimum() {
            return optimum;
        }

        /**
         * @return the value of the function at the optimum.
         */
        public double getValueAtOptimum() {
            return valueAtOptimum;
        }
    }
}

View File

@ -0,0 +1,261 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.univariate;
import org.apache.commons.math.exception.NotStrictlyPositiveException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.MaxIterationsExceededException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.optimization.GoalType;
/**
* Provide an interval that brackets a local optimum of a function.
* This code is based on a Python implementation (from <em>SciPy</em>,
* module {@code optimize.py} v0.5).
*/
public class BracketFinder {
    /**
     * Smallest magnitude allowed for the denominator of the parabolic
     * interpolation step (avoids a division by zero).
     */
    private static final double EPS_MIN = 1e-21;
    /**
     * Golden section.
     */
    private static final double GOLD = 1.618034;
    /**
     * Factor for expanding the interval.
     */
    private final double growLimit;
    /**
     * Maximum number of iterations.
     */
    private final int maxIterations;
    /**
     * Number of iterations.
     */
    private int iterations;
    /**
     * Number of function evaluations.
     */
    private int evaluations;
    /**
     * Lower bound of the bracket.
     */
    private double lo;
    /**
     * Higher bound of the bracket.
     */
    private double hi;
    /**
     * Point inside the bracket.
     */
    private double mid;
    /**
     * Function value at {@link #lo}.
     */
    private double fLo;
    /**
     * Function value at {@link #hi}.
     */
    private double fHi;
    /**
     * Function value at {@link #mid}.
     */
    private double fMid;

    /**
     * Constructor with default values {@code 100, 50} (see the
     * {@link #BracketFinder(double,int) other constructor}).
     */
    public BracketFinder() {
        this(100, 50);
    }

    /**
     * Create a bracketing interval finder.
     *
     * @param growLimit Expanding factor.
     * @param maxIterations Maximum number of iterations allowed for finding
     * a bracketing interval.
     * @throws NotStrictlyPositiveException if {@code growLimit} or
     * {@code maxIterations} is not strictly positive.
     */
    public BracketFinder(double growLimit,
                         int maxIterations) {
        if (growLimit <= 0) {
            throw new NotStrictlyPositiveException(growLimit);
        }
        if (maxIterations <= 0) {
            throw new NotStrictlyPositiveException(maxIterations);
        }

        this.growLimit = growLimit;
        this.maxIterations = maxIterations;
    }

    /**
     * Search new points that bracket a local optimum of the function.
     * On return, {@link #getLo()} is not larger than {@link #getHi()}.
     *
     * @param func Function whose optimum should be bracketed.
     * @param goal {@link GoalType Goal type}.
     * @param xA Initial point.
     * @param xB Initial point.
     * @throws MaxIterationsExceededException if the maximum number of
     * iterations is exceeded.
     * @throws FunctionEvaluationException if an error occurs while
     * evaluating the function.
     */
    public void search(UnivariateRealFunction func,
                       GoalType goal,
                       double xA,
                       double xB)
        throws MaxIterationsExceededException,
               FunctionEvaluationException {
        reset();
        final boolean isMinim = (goal == GoalType.MINIMIZE);

        double fA = eval(func, xA);
        double fB = eval(func, xB);
        // Make sure that going from "xA" to "xB" improves the function.
        if (isMinim ?
            fA < fB :
            fA > fB) {
            double tmp = xA;
            xA = xB;
            xB = tmp;

            tmp = fA;
            fA = fB;
            fB = tmp;
        }

        double xC = xB + GOLD * (xB - xA);
        double fC = eval(func, xC);

        // Keep walking as long as the function still improves at "xC".
        while (isMinim ? fC < fB : fC > fB) {
            if (++iterations > maxIterations) {
                throw new MaxIterationsExceededException(maxIterations);
            }

            // Abscissa "w" of the extremum of the parabola through the
            // three current points, and farthest point "wLim" allowed.
            double tmp1 = (xB - xA) * (fB - fC);
            double tmp2 = (xB - xC) * (fB - fA);

            double val = tmp2 - tmp1;
            double denom = Math.abs(val) < EPS_MIN ? 2 * EPS_MIN : 2 * val;

            double w = xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom;
            double wLim = xB + growLimit * (xC - xB);

            double fW;
            if ((w - xC) * (xB - w) > 0) {
                // "w" lies between "xB" and "xC".
                fW = eval(func, w);
                if (isMinim ?
                    fW < fC :
                    fW > fC) {
                    xA = xB;
                    xB = w;
                    fA = fB;
                    fB = fW;
                    break;
                } else if (isMinim ?
                           fW > fB :
                           fW < fB) {
                    xC = w;
                    fC = fW;
                    break;
                }
                w = xC + GOLD * (xC - xB);
                fW = eval(func, w);
            } else if ((w - wLim) * (wLim - xC) >= 0) {
                // "w" is beyond the allowed limit: use the limit.
                w = wLim;
                fW = eval(func, w);
            } else if ((w - wLim) * (xC - w) > 0) {
                // "w" lies between "xC" and the limit.
                fW = eval(func, w);
                if (isMinim ?
                    fW < fC :
                    fW > fC) {
                    xB = xC;
                    xC = w;
                    w = xC + GOLD * (xC - xB);
                    fB = fC;
                    fC = fW;
                    fW = eval(func, w);
                }
            } else {
                w = xC + GOLD * (xC - xB);
                fW = eval(func, w);
            }

            xA = xB;
            xB = xC;
            xC = w;
            fA = fB;
            fB = fC;
            fC = fW;
        }

        lo = xA;
        mid = xB;
        hi = xC;
        fLo = fA;
        fMid = fB;
        fHi = fC;

        // The search may have walked towards decreasing abscissae: order
        // the end points so that "lo" and "hi" match their documentation.
        if (lo > hi) {
            double tmp = lo;
            lo = hi;
            hi = tmp;

            tmp = fLo;
            fLo = fHi;
            fHi = tmp;
        }
    }

    /**
     * @return the number of iterations.
     */
    public int getIterations() {
        return iterations;
    }

    /**
     * @return the number of evaluations.
     */
    public int getEvaluations() {
        return evaluations;
    }

    /**
     * @return the lower bound of the bracket.
     * @see #getFLo()
     */
    public double getLo() {
        return lo;
    }

    /**
     * @return the function value at the lower bound of the bracket.
     */
    public double getFLo() {
        return fLo;
    }

    /**
     * @return the higher bound of the bracket.
     * @see #getFHi()
     */
    public double getHi() {
        return hi;
    }

    /**
     * @return the function value at the higher bound of the bracket.
     */
    public double getFHi() {
        return fHi;
    }

    /**
     * @return a point in the middle of the bracket.
     * @see #getFMid()
     */
    public double getMid() {
        return mid;
    }

    /**
     * @return the function value at the point in the middle of the bracket.
     */
    public double getFMid() {
        return fMid;
    }

    /**
     * Evaluate the function and count the evaluation.
     *
     * @param f Function.
     * @param x Argument.
     * @return {@code f(x)}
     * @throws FunctionEvaluationException if the function cannot be
     * evaluated at {@code x}.
     */
    private double eval(UnivariateRealFunction f,
                        double x)
        throws FunctionEvaluationException {
        ++evaluations;
        return f.value(x);
    }

    /**
     * Reset internal state.
     */
    private void reset() {
        iterations = 0;
        evaluations = 0;
    }
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.analysis;
import org.apache.commons.math.FunctionEvaluationException;
/**
* Auxiliary class for testing optimizers.
*
* @version $Revision$ $Date$
*/
public class SincFunction implements DifferentiableUnivariateRealFunction {
    /**
     * Arguments whose magnitude is below this value are treated as zero.
     */
    private static final double EPS = 1e-12;
    /**
     * Arguments whose magnitude is below this value use the Taylor
     * expansion of the derivative. The closed-form expression subtracts two
     * nearly equal numbers there and loses most of its significant digits.
     */
    private static final double TAYLOR_THRESHOLD = 1e-2;

    /**
     * @param x Argument.
     * @return the value of this function at point {@code x}:
     * {@code sin(x) / x}, or {@code 1} when {@code x} is (almost) zero.
     */
    public double value(double x) {
        return (Math.abs(x) < EPS ?
                1 :
                Math.sin(x) / x);
    }

    /**
     * {@inheritDoc}
     */
    public UnivariateRealFunction derivative() {
        return new UnivariateRealFunction() {
            public double value(double x) {
                final double absX = Math.abs(x);
                if (absX < EPS) {
                    return 0;
                }
                if (absX < TAYLOR_THRESHOLD) {
                    // -x/3 + x^3/30 - x^5/840; the first neglected term
                    // (x^7/45360) is below the rounding error in this range.
                    final double x2 = x * x;
                    return x * (x2 * (1d / 30 - x2 / 840) - 1d / 3);
                }
                return (x * Math.cos(x) - Math.sin(x)) / (x * x);
            }
        };
    }
}

View File

@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.analysis;
import org.apache.commons.math.FunctionEvaluationException;
/**
* Auxiliary class for testing optimizers.
*
* @version $Revision$ $Date$
*/
public class SumSincFunction implements DifferentiableMultivariateRealFunction {
    /**
     * Univariate function applied to each coordinate.
     */
    private static final DifferentiableUnivariateRealFunction SINC = new SincFunction();
    /**
     * Derivative of {@link #SINC}.
     */
    private static final UnivariateRealFunction SINC_DERIV = SINC.derivative();
    /**
     * Factor that will multiply each term of the sum.
     */
    private final double factor;

    /**
     * @param factor Factor that will multiply each term of the sum.
     */
    public SumSincFunction(double factor) {
        this.factor = factor;
    }

    /**
     * @param point Argument.
     * @return the value of this function at {@code point}: the sum of the
     * "sinc" function over all coordinates, multiplied by the factor.
     * @throws FunctionEvaluationException if the underlying function cannot
     * be evaluated.
     */
    public double value(double[] point) throws FunctionEvaluationException {
        double sum = 0;
        for (double x : point) {
            sum += SINC.value(x);
        }
        return factor * sum;
    }

    /**
     * {@inheritDoc}
     */
    public MultivariateRealFunction partialDerivative(final int k) {
        return new MultivariateRealFunction() {
            public double value(double[] point) throws FunctionEvaluationException {
                // Only the k-th term of the sum depends on point[k]; the
                // factor must be applied exactly as in "gradient()".
                return factor * SINC_DERIV.value(point[k]);
            }
        };
    }

    /**
     * {@inheritDoc}
     */
    public MultivariateVectorialFunction gradient() {
        return new MultivariateVectorialFunction() {
            public double[] value(double[] point)
                throws FunctionEvaluationException {
                final int n = point.length;
                final double[] r = new double[n];
                for (int i = 0; i < n; i++) {
                    r[i] = factor * SINC_DERIV.value(point[i]);
                }
                return r;
            }
        };
    }
}

View File

@ -0,0 +1,153 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.general;
import java.util.Arrays;
import org.apache.commons.math.MathException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.MultivariateRealFunction;
import org.apache.commons.math.analysis.SumSincFunction;
import org.apache.commons.math.optimization.GoalType;
import org.apache.commons.math.optimization.MultivariateRealOptimizer;
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
import org.apache.commons.math.optimization.RealPointValuePair;
import org.junit.Assert;
import org.junit.Test;
/**
* Test for {@link PowellOptimizer}.
*/
public class PowellOptimizerTest {

    @Test
    public void testSumSinc() throws MathException {
        final MultivariateRealFunction func = new SumSincFunction(-1);

        // The minimum of "-sinc" is at the origin.
        final double[] minPoint = constantPoint(1, 0);

        // Initial is minimum.
        doTest(func, minPoint, shifted(minPoint, 0), GoalType.MINIMIZE, 1e-15, 1e-8);

        // Initial is far from minimum.
        doTest(func, minPoint, shifted(minPoint, 4), GoalType.MINIMIZE, 1e-15, 1e-8);
    }

    @Test
    public void testQuadratic() throws MathException {
        final MultivariateRealFunction func = new MultivariateRealFunction() {
                public double value(double[] x)
                    throws FunctionEvaluationException {
                    final double a = x[0] - 1;
                    final double b = x[1] - 1;
                    return a * a + b * b + 1;
                }
            };

        final double[] minPoint = constantPoint(2, 1);

        // Initial is minimum.
        doTest(func, minPoint, shifted(minPoint, 0), GoalType.MINIMIZE, 1e-15, 1e-8);

        // Initial is far from minimum.
        doTest(func, minPoint, shifted(minPoint, -20), GoalType.MINIMIZE, 1e-15, 1e-8);
    }

    @Test
    public void testMaximizeQuadratic() throws MathException {
        final MultivariateRealFunction func = new MultivariateRealFunction() {
                public double value(double[] x)
                    throws FunctionEvaluationException {
                    final double a = x[0] - 1;
                    final double b = x[1] - 1;
                    return -a * a - b * b + 1;
                }
            };

        final double[] maxPoint = constantPoint(2, 1);

        // Initial is maximum.
        doTest(func, maxPoint, shifted(maxPoint, 0), GoalType.MAXIMIZE, 1e-15, 1e-8);

        // Initial is far from maximum.
        doTest(func, maxPoint, shifted(maxPoint, -20), GoalType.MAXIMIZE, 1e-15, 1e-8);
    }

    /**
     * @param dim Dimension of the point.
     * @param value Value of every coordinate.
     * @return a new point whose coordinates are all equal to {@code value}.
     */
    private static double[] constantPoint(int dim, double value) {
        final double[] p = new double[dim];
        for (int i = 0; i < dim; i++) {
            p[i] = value;
        }
        return p;
    }

    /**
     * @param base Reference point (not modified).
     * @param offset Value added to every coordinate.
     * @return a new point equal to {@code base} translated by {@code offset}
     * along every axis.
     */
    private static double[] shifted(double[] base, double offset) {
        final double[] p = new double[base.length];
        for (int i = 0; i < base.length; i++) {
            p[i] = base[i] + offset;
        }
        return p;
    }

    /**
     * Run the optimizer and check the location of the optimum it reports.
     *
     * @param func Function to optimize.
     * @param optimum Expected optimum.
     * @param init Starting point.
     * @param goal Minimization or maximization.
     * @param objTol Tolerance (relative error on the objective function),
     * used as convergence criterion.
     * @param pointTol Tolerance on the position of the optimum.
     */
    private void doTest(MultivariateRealFunction func,
                        double[] optimum,
                        double[] init,
                        GoalType goal,
                        double objTol,
                        double pointTol)
        throws MathException {
        final MultivariateRealOptimizer optim = new PowellOptimizer();
        // A negative absolute threshold leaves only the relative criterion.
        optim.setConvergenceChecker(new SimpleScalarValueChecker(objTol, -1));

        final RealPointValuePair result = optim.optimize(func, goal, init);
        final double[] found = result.getPoint();

        Assert.assertEquals(optimum.length, found.length);
        for (int i = 0, dim = optimum.length; i < dim; i++) {
            Assert.assertEquals(optimum[i], found[i], pointTol);
        }
    }
}

View File

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.univariate;
import org.apache.commons.math.MathException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.optimization.GoalType;
import org.junit.Assert;
import org.junit.Test;
public class BracketFinderTest {
    /** Tolerance used when comparing the bracket with the reference values. */
    private static final double TOL = 1e-15;

    @Test
    public void testCubicMin() throws MathException {
        final BracketFinder finder = new BracketFinder();
        finder.search(clampedCubic(false), GoalType.MINIMIZE, -2 , -1);

        // Comparing with results computed in Python.
        assertBracket(finder);
    }

    @Test
    public void testCubicMax() throws MathException {
        final BracketFinder finder = new BracketFinder();
        finder.search(clampedCubic(true), GoalType.MAXIMIZE, -2 , -1);

        // Same bracket as for the minimum of the opposite function.
        assertBracket(finder);
    }

    /**
     * Build the function {@code (x - 1) (x + 2) (x + 3)}, held constant at
     * its value in {@code -2} for every {@code x < -2}.
     *
     * @param negate Whether to return the opposite of that function.
     * @return the test function.
     */
    private static UnivariateRealFunction clampedCubic(final boolean negate) {
        return new UnivariateRealFunction() {
                public double value(double x)
                    throws FunctionEvaluationException {
                    final double t = (x < -2) ? -2 : x;
                    final double cubic = (t - 1) * (t + 2) * (t + 3);
                    return negate ? -cubic : cubic;
                }
            };
    }

    /**
     * Check that the bracket is {@code [-2, -1, 0.618034]}.
     *
     * @param finder Bracket finder on which the search has been performed.
     */
    private static void assertBracket(BracketFinder finder) {
        Assert.assertEquals(-2, finder.getLo(), TOL);
        Assert.assertEquals(-1, finder.getMid(), TOL);
        Assert.assertEquals(0.61803399999999997, finder.getHi(), TOL);
    }
}