Refactored the contents of package "o.a.c.m.optimization"
into the new "o.a.c.m.optim" and "o.a.c.m.fitting" packages.
* All deprecated classes/fields/methods have been removed in the
  replacement packages.
* Simplified API: a single "optimize(OptimizationData... data)"
  method for all optimizer types (see the sketch below).
* Simplified class hierarchy, merged interfaces and abstract
  classes, only base classes are generic.
* The new classes do not use the "DerivativeStructure" type.
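
As an illustration of the unified entry point mentioned above, here is a minimal sketch fitting a straight line y = p0 + p1*x with the optimization-data classes introduced in this commit (MaxEval, InitialGuess, Target, Weight, ModelFunction, ModelFunctionJacobian). The package location of LevenbergMarquardtOptimizer is an assumption (o.a.c.m.optim.nonlinear.vector.jacobian); everything else follows the calls visible in CurveFitter below.

import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointVectorValuePair;
import org.apache.commons.math3.optim.nonlinear.vector.ModelFunction;
import org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian;
import org.apache.commons.math3.optim.nonlinear.vector.Target;
import org.apache.commons.math3.optim.nonlinear.vector.Weight;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer; // assumed location

public class UnifiedOptimizeSketch {
    public static void main(String[] args) {
        final double[] x = { 0, 1, 2 };             // sample abscissae
        final double[] obs = { 1.1, 2.9, 5.2 };     // observed ordinates
        // Straight-line model y = p[0] + p[1] * x, evaluated at every sample point.
        final ModelFunction model = new ModelFunction(new MultivariateVectorFunction() {
            public double[] value(double[] p) {
                final double[] v = new double[x.length];
                for (int i = 0; i < x.length; i++) {
                    v[i] = p[0] + p[1] * x[i];
                }
                return v;
            }
        });
        final ModelFunctionJacobian jacobian = new ModelFunctionJacobian(new MultivariateMatrixFunction() {
            public double[][] value(double[] p) {
                final double[][] j = new double[x.length][2];
                for (int i = 0; i < x.length; i++) {
                    j[i][0] = 1;
                    j[i][1] = x[i];
                }
                return j;
            }
        });
        // Single entry point: every setting is passed as OptimizationData.
        final PointVectorValuePair optimum =
            new LevenbergMarquardtOptimizer().optimize(new MaxEval(1000),
                                                       model,
                                                       jacobian,
                                                       new Target(obs),
                                                       new Weight(new double[] { 1, 1, 1 }),
                                                       new InitialGuess(new double[] { 0, 0 }));
        System.out.println(java.util.Arrays.toString(optimum.getPoint()));
    }
}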


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1420684 13f79535-47bb-0310-9956-ffa450edef68
Gilles Sadowski 2012-12-12 14:10:38 +00:00
parent 63623c9236
commit 7ee7843ffe
118 changed files with 25159 additions and 1 deletion

View File

@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.exception;
import org.apache.commons.math3.exception.util.LocalizedFormats;
/**
* Exception to be thrown when the maximal number of iterations is exceeded.
*
* @since 3.1
* @version $Id$
*/
public class TooManyIterationsException extends MaxCountExceededException {
/** Serializable version Id. */
private static final long serialVersionUID = 20121211L;
/**
* Construct the exception.
*
* @param max Maximum number of iterations.
*/
public TooManyIterationsException(Number max) {
super(max);
getContext().addMessage(LocalizedFormats.ITERATIONS);
}
}

View File

@ -148,6 +148,7 @@ public enum LocalizedFormats implements Localizable {
INVALID_REGRESSION_OBSERVATION("length of regressor array = {0} does not match the number of variables = {1} in the model"),
INVALID_ROUNDING_METHOD("invalid rounding method {0}, valid methods: {1} ({2}), {3} ({4}), {5} ({6}), {7} ({8}), {9} ({10}), {11} ({12}), {13} ({14}), {15} ({16})"),
ITERATOR_EXHAUSTED("iterator exhausted"),
ITERATIONS("iterations"), /* keep */
LCM_OVERFLOW_32_BITS("overflow: lcm({0}, {1}) is 2^31"),
LCM_OVERFLOW_64_BITS("overflow: lcm({0}, {1}) is 2^63"),
LIST_OF_CHROMOSOMES_BIGGER_THAN_POPULATION_SIZE("list of chromosomes bigger than maxPopulationSize"),

View File

@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.PointVectorValuePair;
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
import org.apache.commons.math3.optim.nonlinear.vector.ModelFunction;
import org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian;
import org.apache.commons.math3.optim.nonlinear.vector.Target;
import org.apache.commons.math3.optim.nonlinear.vector.Weight;
/**
* Fitter for parametric univariate real functions y = f(x).
* <br/>
* When a univariate real function y = f(x) depends on some
* unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
* this class can be used to find these parameters. It does this
* by <em>fitting</em> the curve so it remains very close to a set of
* observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
* y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
* is done by finding the parameter values that minimize the objective
* function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
* really a least squares problem.
*
* @param <T> Function to use for the fit.
*
* @version $Id: CurveFitter.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class CurveFitter<T extends ParametricUnivariateFunction> {
/** Optimizer to use for the fitting. */
private final MultivariateVectorOptimizer optimizer;
/** Observed points. */
private final List<WeightedObservedPoint> observations;
/**
* Simple constructor.
*
* @param optimizer Optimizer to use for the fitting.
* @since 3.1
*/
public CurveFitter(final MultivariateVectorOptimizer optimizer) {
this.optimizer = optimizer;
observations = new ArrayList<WeightedObservedPoint>();
}
/** Add an observed (x,y) point to the sample with unit weight.
* <p>Calling this method is equivalent to calling
* {@code addObservedPoint(1.0, x, y)}.</p>
* @param x abscissa of the point
* @param y observed value of the point at x; after fitting, f(x) should
* be as close as possible to this value
* @see #addObservedPoint(double, double, double)
* @see #addObservedPoint(WeightedObservedPoint)
* @see #getObservations()
*/
public void addObservedPoint(double x, double y) {
addObservedPoint(1.0, x, y);
}
/** Add an observed weighted (x,y) point to the sample.
* @param weight weight of the observed point in the fit
* @param x abscissa of the point
* @param y observed value of the point at x; after fitting, f(x) should
* be as close as possible to this value
* @see #addObservedPoint(double, double)
* @see #addObservedPoint(WeightedObservedPoint)
* @see #getObservations()
*/
public void addObservedPoint(double weight, double x, double y) {
observations.add(new WeightedObservedPoint(weight, x, y));
}
/** Add an observed weighted (x,y) point to the sample.
* @param observed observed point to add
* @see #addObservedPoint(double, double)
* @see #addObservedPoint(double, double, double)
* @see #getObservations()
*/
public void addObservedPoint(WeightedObservedPoint observed) {
observations.add(observed);
}
/** Get the observed points.
* @return observed points
* @see #addObservedPoint(double, double)
* @see #addObservedPoint(double, double, double)
* @see #addObservedPoint(WeightedObservedPoint)
*/
public WeightedObservedPoint[] getObservations() {
return observations.toArray(new WeightedObservedPoint[observations.size()]);
}
/**
* Remove all observations.
*/
public void clearObservations() {
observations.clear();
}
/**
* Fit a curve.
* This method computes the coefficients of the curve that best
* fits the sample of observed points previously given through calls
* to the {@link #addObservedPoint(WeightedObservedPoint)
* addObservedPoint} method.
*
* @param f parametric function to fit.
* @param initialGuess first guess of the function parameters.
* @return the fitted parameters.
* @throws org.apache.commons.math3.exception.DimensionMismatchException
* if the start point dimension is wrong.
*/
public double[] fit(T f, final double[] initialGuess) {
return fit(Integer.MAX_VALUE, f, initialGuess);
}
/**
* Fit a curve.
* This method computes the coefficients of the curve that best
* fits the sample of observed points previously given through calls
* to the {@link #addObservedPoint(WeightedObservedPoint)
* addObservedPoint} method.
*
* @param f parametric function to fit.
* @param initialGuess first guess of the function parameters.
* @param maxEval Maximum number of function evaluations.
* @return the fitted parameters.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the number of allowed evaluations is exceeded.
* @throws org.apache.commons.math3.exception.DimensionMismatchException
* if the start point dimension is wrong.
* @since 3.0
*/
public double[] fit(int maxEval, T f,
final double[] initialGuess) {
// Prepare least squares problem.
double[] target = new double[observations.size()];
double[] weights = new double[observations.size()];
int i = 0;
for (WeightedObservedPoint point : observations) {
target[i] = point.getY();
weights[i] = point.getWeight();
++i;
}
// Input to the optimizer: the model and its Jacobian.
final TheoreticalValuesFunction model = new TheoreticalValuesFunction(f);
// Perform the fit.
final PointVectorValuePair optimum
= optimizer.optimize(new MaxEval(maxEval),
model.getModelFunction(),
model.getModelFunctionJacobian(),
new Target(target),
new Weight(weights),
new InitialGuess(initialGuess));
// Extract the coefficients.
return optimum.getPointRef();
}
/** Vectorial function computing function theoretical values. */
private class TheoreticalValuesFunction {
/** Function to fit. */
private final ParametricUnivariateFunction f;
/**
* @param f function to fit.
*/
public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
this.f = f;
}
/**
* @return the model function values.
*/
public ModelFunction getModelFunction() {
return new ModelFunction(new MultivariateVectorFunction() {
/** {@inheritDoc} */
public double[] value(double[] point) {
// compute the theoretical values of the model at the observed abscissae
final double[] values = new double[observations.size()];
int i = 0;
for (WeightedObservedPoint observed : observations) {
values[i++] = f.value(observed.getX(), point);
}
return values;
}
});
}
/**
* @return the model function Jacobian.
*/
public ModelFunctionJacobian getModelFunctionJacobian() {
return new ModelFunctionJacobian(new MultivariateMatrixFunction() {
public double[][] value(double[] point) {
final double[][] jacobian = new double[observations.size()][];
int i = 0;
for (WeightedObservedPoint observed : observations) {
jacobian[i++] = f.gradient(observed.getX(), point);
}
return jacobian;
}
});
}
}
}
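
For reference, a hedged usage sketch of the class above with a user-supplied ParametricUnivariateFunction (an exponential decay); the package location of LevenbergMarquardtOptimizer is assumed, as in the commit-message sketch.

import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.fitting.CurveFitter;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer; // assumed location

public class ExpDecayFitSketch {
    public static void main(String[] args) {
        // Model y = a * exp(b * x) with parameters p = {a, b}.
        ParametricUnivariateFunction expModel = new ParametricUnivariateFunction() {
            public double value(double x, double... p) {
                return p[0] * Math.exp(p[1] * x);
            }
            public double[] gradient(double x, double... p) {
                final double e = Math.exp(p[1] * x);
                return new double[] { e, p[0] * x * e }; // d/da, d/db
            }
        };
        CurveFitter<ParametricUnivariateFunction> fitter =
            new CurveFitter<ParametricUnivariateFunction>(new LevenbergMarquardtOptimizer());
        fitter.addObservedPoint(0.0, 2.0);
        fitter.addObservedPoint(1.0, 1.2);
        fitter.addObservedPoint(2.0, 0.7);
        fitter.addObservedPoint(3.0, 0.45);
        double[] best = fitter.fit(expModel, new double[] { 2, -0.5 });
        System.out.println("a = " + best[0] + ", b = " + best[1]);
    }
}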

View File

@ -0,0 +1,362 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.commons.math3.analysis.function.Gaussian;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.ZeroException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
import org.apache.commons.math3.util.FastMath;
/**
* Fits points to a {@link
* org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} function.
* <p>
* Usage example:
* <pre>
* GaussianFitter fitter = new GaussianFitter(
* new LevenbergMarquardtOptimizer());
* fitter.addObservedPoint(4.0254623, 531026.0);
* fitter.addObservedPoint(4.03128248, 984167.0);
* fitter.addObservedPoint(4.03839603, 1887233.0);
* fitter.addObservedPoint(4.04421621, 2687152.0);
* fitter.addObservedPoint(4.05132976, 3461228.0);
* fitter.addObservedPoint(4.05326982, 3580526.0);
* fitter.addObservedPoint(4.05779662, 3439750.0);
* fitter.addObservedPoint(4.0636168, 2877648.0);
* fitter.addObservedPoint(4.06943698, 2175960.0);
* fitter.addObservedPoint(4.07525716, 1447024.0);
* fitter.addObservedPoint(4.08237071, 717104.0);
* fitter.addObservedPoint(4.08366408, 620014.0);
* double[] parameters = fitter.fit();
* </pre>
*
* @since 2.2
* @version $Id: GaussianFitter.java 1416643 2012-12-03 19:37:14Z tn $
*/
public class GaussianFitter extends CurveFitter<Gaussian.Parametric> {
/**
* Constructs an instance using the specified optimizer.
*
* @param optimizer Optimizer to use for the fitting.
*/
public GaussianFitter(MultivariateVectorOptimizer optimizer) {
super(optimizer);
}
/**
* Fits a Gaussian function to the observed points.
*
* @param initialGuess First guess values in the following order:
* <ul>
* <li>Norm</li>
* <li>Mean</li>
* <li>Sigma</li>
* </ul>
* @return the parameters of the Gaussian function that best fits the
* observed points (in the same order as above).
* @since 3.0
*/
public double[] fit(double[] initialGuess) {
final Gaussian.Parametric f = new Gaussian.Parametric() {
@Override
public double value(double x, double ... p) {
double v = Double.POSITIVE_INFINITY;
try {
v = super.value(x, p);
} catch (NotStrictlyPositiveException e) {
// Do nothing.
}
return v;
}
@Override
public double[] gradient(double x, double ... p) {
double[] v = { Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY };
try {
v = super.gradient(x, p);
} catch (NotStrictlyPositiveException e) {
// Do nothing.
}
return v;
}
};
return fit(f, initialGuess);
}
/**
* Fits a Gaussian function to the observed points.
*
* @return the parameters of the Gaussian function that best fits the
* observed points (in the same order as above).
*/
public double[] fit() {
final double[] guess = (new ParameterGuesser(getObservations())).guess();
return fit(guess);
}
/**
* Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
* of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
* based on the specified observed points.
*/
public static class ParameterGuesser {
/** Normalization factor. */
private final double norm;
/** Mean. */
private final double mean;
/** Standard deviation. */
private final double sigma;
/**
* Constructs an instance with the specified observed points.
*
* @param observations Observed points from which to guess the
* parameters of the Gaussian.
* @throws NullArgumentException if {@code observations} is
* {@code null}.
* @throws NumberIsTooSmallException if there are less than 3
* observations.
*/
public ParameterGuesser(WeightedObservedPoint[] observations) {
if (observations == null) {
throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
}
if (observations.length < 3) {
throw new NumberIsTooSmallException(observations.length, 3, true);
}
final WeightedObservedPoint[] sorted = sortObservations(observations);
final double[] params = basicGuess(sorted);
norm = params[0];
mean = params[1];
sigma = params[2];
}
/**
* Gets an estimation of the parameters.
*
* @return the guessed parameters, in the following order:
* <ul>
* <li>Normalization factor</li>
* <li>Mean</li>
* <li>Standard deviation</li>
* </ul>
*/
public double[] guess() {
return new double[] { norm, mean, sigma };
}
/**
* Sort the observations.
*
* @param unsorted Input observations.
* @return the input observations, sorted.
*/
private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
final WeightedObservedPoint[] observations = unsorted.clone();
final Comparator<WeightedObservedPoint> cmp
= new Comparator<WeightedObservedPoint>() {
public int compare(WeightedObservedPoint p1,
WeightedObservedPoint p2) {
if (p1 == null && p2 == null) {
return 0;
}
if (p1 == null) {
return -1;
}
if (p2 == null) {
return 1;
}
if (p1.getX() < p2.getX()) {
return -1;
}
if (p1.getX() > p2.getX()) {
return 1;
}
if (p1.getY() < p2.getY()) {
return -1;
}
if (p1.getY() > p2.getY()) {
return 1;
}
if (p1.getWeight() < p2.getWeight()) {
return -1;
}
if (p1.getWeight() > p2.getWeight()) {
return 1;
}
return 0;
}
};
Arrays.sort(observations, cmp);
return observations;
}
/**
* Guesses the parameters based on the specified observed points.
*
* @param points Observed points, sorted.
* @return the guessed parameters (normalization factor, mean and
* sigma).
*/
private double[] basicGuess(WeightedObservedPoint[] points) {
final int maxYIdx = findMaxY(points);
final double n = points[maxYIdx].getY();
final double m = points[maxYIdx].getX();
double fwhmApprox;
try {
final double halfY = n + ((m - n) / 2);
final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
fwhmApprox = fwhmX2 - fwhmX1;
} catch (OutOfRangeException e) {
// TODO: Exceptions should not be used for flow control.
fwhmApprox = points[points.length - 1].getX() - points[0].getX();
}
final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
return new double[] { n, m, s };
}
/**
* Finds the index of the point with the largest Y value in the specified points.
*
* @param points Points to search.
* @return the index within the specified points array.
*/
private int findMaxY(WeightedObservedPoint[] points) {
int maxYIdx = 0;
for (int i = 1; i < points.length; i++) {
if (points[i].getY() > points[maxYIdx].getY()) {
maxYIdx = i;
}
}
return maxYIdx;
}
/**
* Interpolates using the specified points to determine X at the
* specified Y.
*
* @param points Points to use for interpolation.
* @param startIdx Index within points from which to start the search for
* interpolation bounds points.
* @param idxStep Index step for searching interpolation bounds points.
* @param y Y value for which X should be determined.
* @return the value of X for the specified Y.
* @throws ZeroException if {@code idxStep} is 0.
* @throws OutOfRangeException if specified {@code y} is not within the
* range of the specified {@code points}.
*/
private double interpolateXAtY(WeightedObservedPoint[] points,
int startIdx,
int idxStep,
double y)
throws OutOfRangeException {
if (idxStep == 0) {
throw new ZeroException();
}
final WeightedObservedPoint[] twoPoints
= getInterpolationPointsForY(points, startIdx, idxStep, y);
final WeightedObservedPoint p1 = twoPoints[0];
final WeightedObservedPoint p2 = twoPoints[1];
if (p1.getY() == y) {
return p1.getX();
}
if (p2.getY() == y) {
return p2.getX();
}
return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
(p2.getY() - p1.getY()));
}
/**
* Gets the two bounding interpolation points from the specified points
* suitable for determining X at the specified Y.
*
* @param points Points to use for interpolation.
* @param startIdx Index within points from which to start search for
* interpolation bounds points.
* @param idxStep Index step for search for interpolation bounds points.
* @param y Y value for which X should be determined.
* @return the array containing two points suitable for determining X at
* the specified Y.
* @throws ZeroException if {@code idxStep} is 0.
* @throws OutOfRangeException if specified {@code y} is not within the
* range of the specified {@code points}.
*/
private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
int startIdx,
int idxStep,
double y)
throws OutOfRangeException {
if (idxStep == 0) {
throw new ZeroException();
}
for (int i = startIdx;
idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
i += idxStep) {
final WeightedObservedPoint p1 = points[i];
final WeightedObservedPoint p2 = points[i + idxStep];
if (isBetween(y, p1.getY(), p2.getY())) {
if (idxStep < 0) {
return new WeightedObservedPoint[] { p2, p1 };
} else {
return new WeightedObservedPoint[] { p1, p2 };
}
}
}
// Boundaries are replaced by dummy values because the raised
// exception is caught and the message never displayed.
// TODO: Exceptions should not be used for flow control.
throw new OutOfRangeException(y,
Double.NEGATIVE_INFINITY,
Double.POSITIVE_INFINITY);
}
/**
* Determines whether a value is between two other values.
*
* @param value Value to test whether it is between {@code boundary1}
* and {@code boundary2}.
* @param boundary1 One end of the range.
* @param boundary2 Other end of the range.
* @return {@code true} if {@code value} is between {@code boundary1} and
* {@code boundary2} (inclusive), {@code false} otherwise.
*/
private boolean isBetween(double value,
double boundary1,
double boundary2) {
return (value >= boundary1 && value <= boundary2) ||
(value >= boundary2 && value <= boundary1);
}
}
}
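
A short hedged sketch of the guess-then-fit flow that the no-argument fit() performs internally; the package location of LevenbergMarquardtOptimizer is an assumption, as in the class-level example above.

import org.apache.commons.math3.fitting.GaussianFitter;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer; // assumed location

public class GaussianGuessSketch {
    public static void main(String[] args) {
        GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
        fitter.addObservedPoint(1.0, 0.2);
        fitter.addObservedPoint(2.0, 1.8);
        fitter.addObservedPoint(3.0, 5.0);
        fitter.addObservedPoint(4.0, 1.9);
        fitter.addObservedPoint(5.0, 0.3);
        // Explicit use of the guesser; the no-argument fit() does exactly this internally.
        double[] seed = new GaussianFitter.ParameterGuesser(fitter.getObservations()).guess();
        double[] best = fitter.fit(seed); // {norm, mean, sigma}
        System.out.println(java.util.Arrays.toString(best));
    }
}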

View File

@ -0,0 +1,382 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
import org.apache.commons.math3.analysis.function.HarmonicOscillator;
import org.apache.commons.math3.exception.ZeroException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.util.FastMath;
/**
* Class that implements a curve fitting specialized for sinusoids.
*
* Harmonic fitting is a very simple case of curve fitting. The
* estimated coefficients are the amplitude a, the pulsation &omega; and
* the phase &phi;: <code>f (t) = a cos (&omega; t + &phi;)</code>. They are
* searched for by a least squares estimator initialized with a rough
* guess based on integrals.
*
* @version $Id: HarmonicFitter.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class HarmonicFitter extends CurveFitter<HarmonicOscillator.Parametric> {
/**
* Simple constructor.
* @param optimizer Optimizer to use for the fitting.
*/
public HarmonicFitter(final MultivariateVectorOptimizer optimizer) {
super(optimizer);
}
/**
* Fit a harmonic function to the observed points.
*
* @param initialGuess First guess values in the following order:
* <ul>
* <li>Amplitude</li>
* <li>Angular frequency</li>
* <li>Phase</li>
* </ul>
* @return the parameters of the harmonic function that best fits the
* observed points (in the same order as above).
*/
public double[] fit(double[] initialGuess) {
return fit(new HarmonicOscillator.Parametric(), initialGuess);
}
/**
* Fit a harmonic function to the observed points.
* An initial guess will be automatically computed.
*
* @return the parameters of the harmonic function that best fits the
* observed points (see the other {@link #fit(double[]) fit} method).
* @throws NumberIsTooSmallException if the sample is too short for the
* first guess to be computed.
* @throws ZeroException if the first guess cannot be computed because
* the abscissa range is zero.
*/
public double[] fit() {
return fit((new ParameterGuesser(getObservations())).guess());
}
/**
* This class guesses harmonic coefficients from a sample.
* <p>The algorithm used to guess the coefficients is as follows:</p>
*
* <p>We know f (t) at some sampling points t<sub>i</sub> and want to find a,
* &omega; and &phi; such that f (t) = a cos (&omega; t + &phi;).
* </p>
*
* <p>From the analytical expression, we can compute two primitives :
* <pre>
* If2 (t) = &int; f<sup>2</sup> = a<sup>2</sup> &times; [t + S (t)] / 2
* If'2 (t) = &int; f'<sup>2</sup> = a<sup>2</sup> &omega;<sup>2</sup> &times; [t - S (t)] / 2
* where S (t) = sin (2 (&omega; t + &phi;)) / (2 &omega;)
* </pre>
* </p>
*
* <p>We can remove S between these expressions :
* <pre>
* If'2 (t) = a<sup>2</sup> &omega;<sup>2</sup> t - &omega;<sup>2</sup> If2 (t)
* </pre>
* </p>
*
* <p>The preceding expression shows that If'2 (t) is a linear
* combination of both t and If2 (t): If'2 (t) = A &times; t + B &times; If2 (t)
* </p>
*
* <p>From the primitive, we can deduce the same form for definite
* integrals between t<sub>1</sub> and t<sub>i</sub> for each t<sub>i</sub> :
* <pre>
* If2 (t<sub>i</sub>) - If2 (t<sub>1</sub>) = A &times; (t<sub>i</sub> - t<sub>1</sub>) + B &times; (If2 (t<sub>i</sub>) - If2 (t<sub>1</sub>))
* </pre>
* </p>
*
* <p>We can find the coefficients A and B that best fit the sample
* to this linear expression by computing the definite integrals for
* each sample points.
* </p>
*
* <p>For a bilinear expression z (x<sub>i</sub>, y<sub>i</sub>) = A &times; x<sub>i</sub> + B &times; y<sub>i</sub>, the
* coefficients A and B that minimize a least square criterion
* &sum; (z<sub>i</sub> - z (x<sub>i</sub>, y<sub>i</sub>))<sup>2</sup> are given by these expressions:</p>
* <pre>
*
* &sum;y<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>z<sub>i</sub> - &sum;x<sub>i</sub>y<sub>i</sub> &sum;y<sub>i</sub>z<sub>i</sub>
* A = ------------------------
* &sum;x<sub>i</sub>x<sub>i</sub> &sum;y<sub>i</sub>y<sub>i</sub> - &sum;x<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>y<sub>i</sub>
*
* &sum;x<sub>i</sub>x<sub>i</sub> &sum;y<sub>i</sub>z<sub>i</sub> - &sum;x<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>z<sub>i</sub>
* B = ------------------------
* &sum;x<sub>i</sub>x<sub>i</sub> &sum;y<sub>i</sub>y<sub>i</sub> - &sum;x<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>y<sub>i</sub>
* </pre>
* </p>
*
*
* <p>In fact, we can assume both a and &omega; are positive and
* compute them directly, knowing that A = a<sup>2</sup> &omega;<sup>2</sup> and that
* B = - &omega;<sup>2</sup>. The complete algorithm is therefore:</p>
* <pre>
*
* for each t<sub>i</sub> from t<sub>1</sub> to t<sub>n-1</sub>, compute:
* f (t<sub>i</sub>)
* f' (t<sub>i</sub>) = (f (t<sub>i+1</sub>) - f(t<sub>i-1</sub>)) / (t<sub>i+1</sub> - t<sub>i-1</sub>)
* x<sub>i</sub> = t<sub>i</sub> - t<sub>1</sub>
* y<sub>i</sub> = &int; f<sup>2</sup> from t<sub>1</sub> to t<sub>i</sub>
* z<sub>i</sub> = &int; f'<sup>2</sup> from t<sub>1</sub> to t<sub>i</sub>
* update the sums &sum;x<sub>i</sub>x<sub>i</sub>, &sum;y<sub>i</sub>y<sub>i</sub>, &sum;x<sub>i</sub>y<sub>i</sub>, &sum;x<sub>i</sub>z<sub>i</sub> and &sum;y<sub>i</sub>z<sub>i</sub>
* end for
*
* |--------------------------
* \ | &sum;y<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>z<sub>i</sub> - &sum;x<sub>i</sub>y<sub>i</sub> &sum;y<sub>i</sub>z<sub>i</sub>
* a = \ | ------------------------
* \| &sum;x<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>z<sub>i</sub> - &sum;x<sub>i</sub>x<sub>i</sub> &sum;y<sub>i</sub>z<sub>i</sub>
*
*
* |--------------------------
* \ | &sum;x<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>z<sub>i</sub> - &sum;x<sub>i</sub>x<sub>i</sub> &sum;y<sub>i</sub>z<sub>i</sub>
* &omega; = \ | ------------------------
* \| &sum;x<sub>i</sub>x<sub>i</sub> &sum;y<sub>i</sub>y<sub>i</sub> - &sum;x<sub>i</sub>y<sub>i</sub> &sum;x<sub>i</sub>y<sub>i</sub>
*
* </pre>
* </p>
*
* <p>Once we know &omega;, we can compute:
* <pre>
* fc = &omega; f (t) cos (&omega; t) - f' (t) sin (&omega; t)
* fs = &omega; f (t) sin (&omega; t) + f' (t) cos (&omega; t)
* </pre>
* </p>
*
* <p>It appears that <code>fc = a &omega; cos (&phi;)</code> and
* <code>fs = -a &omega; sin (&phi;)</code>, so we can use these
* expressions to compute &phi;. The best estimate over the sample is
* given by averaging these expressions.
* </p>
*
* <p>Since integrals and means are involved in the preceding
* estimations, these operations run in O(n) time, where n is the
* number of measurements.</p>
*/
public static class ParameterGuesser {
/** Amplitude. */
private final double a;
/** Angular frequency. */
private final double omega;
/** Phase. */
private final double phi;
/**
* Simple constructor.
*
* @param observations Sampled observations.
* @throws NumberIsTooSmallException if the sample is too short.
* @throws ZeroException if the abscissa range is zero.
* @throws MathIllegalStateException when the guessing procedure cannot
* produce sensible results.
*/
public ParameterGuesser(WeightedObservedPoint[] observations) {
if (observations.length < 4) {
throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
observations.length, 4, true);
}
final WeightedObservedPoint[] sorted = sortObservations(observations);
final double aOmega[] = guessAOmega(sorted);
a = aOmega[0];
omega = aOmega[1];
phi = guessPhi(sorted);
}
/**
* Gets an estimation of the parameters.
*
* @return the guessed parameters, in the following order:
* <ul>
* <li>Amplitude</li>
* <li>Angular frequency</li>
* <li>Phase</li>
* </ul>
*/
public double[] guess() {
return new double[] { a, omega, phi };
}
/**
* Sort the observations with respect to the abscissa.
*
* @param unsorted Input observations.
* @return the input observations, sorted.
*/
private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
final WeightedObservedPoint[] observations = unsorted.clone();
// Since the samples are almost always already sorted, this
// method is implemented as an insertion sort that reorders the
// elements in place. Insertion sort is very efficient in this case.
WeightedObservedPoint curr = observations[0];
for (int j = 1; j < observations.length; ++j) {
WeightedObservedPoint prec = curr;
curr = observations[j];
if (curr.getX() < prec.getX()) {
// the current element should be inserted closer to the beginning
int i = j - 1;
WeightedObservedPoint mI = observations[i];
while ((i >= 0) && (curr.getX() < mI.getX())) {
observations[i + 1] = mI;
if (i-- != 0) {
mI = observations[i];
}
}
observations[i + 1] = curr;
curr = observations[j];
}
}
return observations;
}
/**
* Estimate a first guess of the amplitude and angular frequency.
* This method assumes that the {@link #sortObservations(WeightedObservedPoint[])} method
* has been called previously.
*
* @param observations Observations, sorted w.r.t. abscissa.
* @throws ZeroException if the abscissa range is zero.
* @throws MathIllegalStateException when the guessing procedure cannot
* produce sensible results.
* @return the guessed amplitude (at index 0) and circular frequency
* (at index 1).
*/
private double[] guessAOmega(WeightedObservedPoint[] observations) {
final double[] aOmega = new double[2];
// initialize the sums for the linear model between the two integrals
double sx2 = 0;
double sy2 = 0;
double sxy = 0;
double sxz = 0;
double syz = 0;
double currentX = observations[0].getX();
double currentY = observations[0].getY();
double f2Integral = 0;
double fPrime2Integral = 0;
final double startX = currentX;
for (int i = 1; i < observations.length; ++i) {
// one step forward
final double previousX = currentX;
final double previousY = currentY;
currentX = observations[i].getX();
currentY = observations[i].getY();
// update the integrals of f<sup>2</sup> and f'<sup>2</sup>
// considering a linear model for f (and therefore constant f')
final double dx = currentX - previousX;
final double dy = currentY - previousY;
final double f2StepIntegral =
dx * (previousY * previousY + previousY * currentY + currentY * currentY) / 3;
final double fPrime2StepIntegral = dy * dy / dx;
final double x = currentX - startX;
f2Integral += f2StepIntegral;
fPrime2Integral += fPrime2StepIntegral;
sx2 += x * x;
sy2 += f2Integral * f2Integral;
sxy += x * f2Integral;
sxz += x * fPrime2Integral;
syz += f2Integral * fPrime2Integral;
}
// compute the amplitude and pulsation coefficients
double c1 = sy2 * sxz - sxy * syz;
double c2 = sxy * sxz - sx2 * syz;
double c3 = sx2 * sy2 - sxy * sxy;
if ((c1 / c2 < 0) || (c2 / c3 < 0)) {
final int last = observations.length - 1;
// Range of the observations, assuming that the
// observations are sorted.
final double xRange = observations[last].getX() - observations[0].getX();
if (xRange == 0) {
throw new ZeroException();
}
aOmega[1] = 2 * Math.PI / xRange;
double yMin = Double.POSITIVE_INFINITY;
double yMax = Double.NEGATIVE_INFINITY;
for (int i = 1; i < observations.length; ++i) {
final double y = observations[i].getY();
if (y < yMin) {
yMin = y;
}
if (y > yMax) {
yMax = y;
}
}
aOmega[0] = 0.5 * (yMax - yMin);
} else {
if (c2 == 0) {
// In some ill-conditioned cases (cf. MATH-844), the guesser
// procedure cannot produce sensible results.
throw new MathIllegalStateException(LocalizedFormats.ZERO_DENOMINATOR);
}
aOmega[0] = FastMath.sqrt(c1 / c2);
aOmega[1] = FastMath.sqrt(c2 / c3);
}
return aOmega;
}
/**
* Estimate a first guess of the phase.
*
* @param observations Observations, sorted w.r.t. abscissa.
* @return the guessed phase.
*/
private double guessPhi(WeightedObservedPoint[] observations) {
// initialize the means
double fcMean = 0;
double fsMean = 0;
double currentX = observations[0].getX();
double currentY = observations[0].getY();
for (int i = 1; i < observations.length; ++i) {
// one step forward
final double previousX = currentX;
final double previousY = currentY;
currentX = observations[i].getX();
currentY = observations[i].getY();
final double currentYPrime = (currentY - previousY) / (currentX - previousX);
double omegaX = omega * currentX;
double cosine = FastMath.cos(omegaX);
double sine = FastMath.sin(omegaX);
fcMean += omega * currentY * cosine - currentYPrime * sine;
fsMean += omega * currentY * sine + currentYPrime * cosine;
}
return FastMath.atan2(-fsMean, fcMean);
}
}
}
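
A hedged usage sketch of the fitter above on noise-free sinusoidal samples, letting the ParameterGuesser seed the optimization; the optimizer's package location is assumed, as in the earlier sketches.

import org.apache.commons.math3.fitting.HarmonicFitter;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer; // assumed location

public class HarmonicFitSketch {
    public static void main(String[] args) {
        HarmonicFitter fitter = new HarmonicFitter(new LevenbergMarquardtOptimizer());
        final double a = 1.5;
        final double omega = 3.0;
        final double phi = 0.7;
        for (int i = 0; i < 50; i++) {
            final double t = 0.1 * i;
            fitter.addObservedPoint(t, a * Math.cos(omega * t + phi));
        }
        // No explicit guess: the ParameterGuesser described above seeds the optimizer.
        double[] fitted = fitter.fit(); // {amplitude, angular frequency, phase}
        System.out.println(java.util.Arrays.toString(fitted));
    }
}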

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
/**
* Polynomial fitting is a very simple case of {@link CurveFitter curve fitting}.
* The estimated coefficients are the polynomial coefficients (see the
* {@link #fit(double[]) fit} method).
*
* @version $Id: PolynomialFitter.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class PolynomialFitter extends CurveFitter<PolynomialFunction.Parametric> {
/**
* Simple constructor.
*
* @param optimizer Optimizer to use for the fitting.
*/
public PolynomialFitter(MultivariateVectorOptimizer optimizer) {
super(optimizer);
}
/**
* Get the coefficients of the polynomial fitting the weighted data points.
* The degree of the fitting polynomial is {@code guess.length - 1}.
*
* @param guess First guess for the coefficients. They must be sorted in
* increasing order of the polynomial's degree.
* @param maxEval Maximum number of evaluations of the polynomial.
* @return the coefficients of the polynomial that best fits the observed points.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException if
* the number of evaluations exceeds {@code maxEval}.
* @throws org.apache.commons.math3.exception.ConvergenceException
* if the algorithm failed to converge.
*/
public double[] fit(int maxEval, double[] guess) {
return fit(maxEval, new PolynomialFunction.Parametric(), guess);
}
/**
* Get the coefficients of the polynomial fitting the weighted data points.
* The degree of the fitting polynomial is {@code guess.length - 1}.
*
* @param guess First guess for the coefficients. They must be sorted in
* increasing order of the polynomial's degree.
* @return the coefficients of the polynomial that best fits the observed points.
* @throws org.apache.commons.math3.exception.ConvergenceException
* if the algorithm failed to converge.
*/
public double[] fit(double[] guess) {
return fit(new PolynomialFunction.Parametric(), guess);
}
}
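
A hedged sketch of a degree-2 fit with the class above; the optimizer's package location is assumed, as in the earlier sketches.

import org.apache.commons.math3.fitting.PolynomialFitter;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer; // assumed location

public class PolynomialFitSketch {
    public static void main(String[] args) {
        PolynomialFitter fitter = new PolynomialFitter(new LevenbergMarquardtOptimizer());
        fitter.addObservedPoint(-1.0, 3.0);
        fitter.addObservedPoint(0.0, 1.1);
        fitter.addObservedPoint(1.0, 2.9);
        fitter.addObservedPoint(2.0, 9.2);
        // The guess length fixes the degree: three coefficients -> degree 2.
        double[] coeffs = fitter.fit(new double[] { 1, 0, 1 }); // {c0, c1, c2}, lowest degree first
        System.out.println(java.util.Arrays.toString(coeffs));
    }
}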

View File

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import java.io.Serializable;
/**
* This class is a simple container for a weighted observed point in
* {@link CurveFitter curve fitting}.
* <p>Instances of this class are guaranteed to be immutable.</p>
* @version $Id: WeightedObservedPoint.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class WeightedObservedPoint implements Serializable {
/** Serializable version id. */
private static final long serialVersionUID = 5306874947404636157L;
/** Weight of the measurement in the fitting process. */
private final double weight;
/** Abscissa of the point. */
private final double x;
/** Observed value of the function at x. */
private final double y;
/**
* Simple constructor.
*
* @param weight Weight of the measurement in the fitting process.
* @param x Abscissa of the measurement.
* @param y Ordinate of the measurement.
*/
public WeightedObservedPoint(final double weight, final double x, final double y) {
this.weight = weight;
this.x = x;
this.y = y;
}
/**
* Gets the weight of the measurement in the fitting process.
*
* @return the weight of the measurement in the fitting process.
*/
public double getWeight() {
return weight;
}
/**
* Gets the abscissa of the point.
*
* @return the abscissa of the point.
*/
public double getX() {
return x;
}
/**
* Gets the observed value of the function at x.
*
* @return the observed value of the function at x.
*/
public double getY() {
return y;
}
}

View File

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Classes to perform curve fitting.
*
* Curve fitting is a special case of a least squares problem
* where the parameters are the coefficients of a function {@code f}
* whose graph {@code y = f(x)} should pass through sample points, and
* where the objective function is the sum of the squared residuals
* <code>f(x<sub>i</sub>) - y<sub>i</sub></code> for observed points
* <code>(x<sub>i</sub>, y<sub>i</sub>)</code>.
*/
package org.apache.commons.math3.fitting;

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
/**
* Base class for all convergence checker implementations.
*
* @param <PAIR> Type of (point, value) pair.
*
* @version $Id: AbstractConvergenceChecker.java 1370215 2012-08-07 12:38:59Z sebb $
* @since 3.0
*/
public abstract class AbstractConvergenceChecker<PAIR>
implements ConvergenceChecker<PAIR> {
/**
* Relative tolerance threshold.
*/
private final double relativeThreshold;
/**
* Absolute tolerance threshold.
*/
private final double absoluteThreshold;
/**
* Build an instance with the specified thresholds.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
*/
public AbstractConvergenceChecker(final double relativeThreshold,
final double absoluteThreshold) {
this.relativeThreshold = relativeThreshold;
this.absoluteThreshold = absoluteThreshold;
}
/**
* @return the relative threshold.
*/
public double getRelativeThreshold() {
return relativeThreshold;
}
/**
* @return the absolute threshold.
*/
public double getAbsoluteThreshold() {
return absoluteThreshold;
}
/**
* {@inheritDoc}
*/
public abstract boolean converged(int iteration,
PAIR previous,
PAIR current);
}
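
A hedged sketch of a concrete checker in the style of the library's value-based checkers: convergence is declared when successive objective values are close in relative or absolute terms. PointValuePair is assumed to be the scalar counterpart of PointVectorValuePair in the same optim package.

import org.apache.commons.math3.optim.AbstractConvergenceChecker;
import org.apache.commons.math3.optim.PointValuePair; // assumed scalar pair type
import org.apache.commons.math3.util.FastMath;

public class ValueConvergenceChecker extends AbstractConvergenceChecker<PointValuePair> {
    public ValueConvergenceChecker(double relativeThreshold, double absoluteThreshold) {
        super(relativeThreshold, absoluteThreshold);
    }
    @Override
    public boolean converged(int iteration, PointValuePair previous, PointValuePair current) {
        final double p = previous.getValue();
        final double c = current.getValue();
        final double difference = FastMath.abs(p - c);
        final double size = FastMath.max(FastMath.abs(p), FastMath.abs(c));
        // Converged if the change is small either relative to the values or in absolute terms.
        return difference <= size * getRelativeThreshold() ||
               difference <= getAbsoluteThreshold();
    }
}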

View File

@ -0,0 +1,214 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.random.RandomVectorGenerator;
import org.apache.commons.math3.optim.InitialGuess;
/**
* Base class for a multi-start optimizer of a multivariate function.
* <br/>
* This class wraps an optimizer in order to use it several times in
* turn with different starting points (trying to avoid being trapped
* in a local extremum when looking for a global one).
* <em>It is not a "user" class.</em>
*
* @param <PAIR> Type of the point/value pair returned by the optimization
* algorithm.
*
* @version $Id$
* @since 3.0
*/
public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
extends BaseMultivariateOptimizer<PAIR> {
/** Underlying classical optimizer. */
private final BaseMultivariateOptimizer<PAIR> optimizer;
/** Number of evaluations already performed for all starts. */
private int totalEvaluations;
/** Number of starts to go. */
private int starts;
/** Random generator for multi-start. */
private RandomVectorGenerator generator;
/** Optimization data. */
private OptimizationData[] optimData;
/**
* Location in {@link #optimData} where the updated maximum
* number of evaluations will be stored.
*/
private int maxEvalIndex = -1;
/**
* Location in {@link #optimData} where the updated start value
* will be stored.
*/
private int initialGuessIndex = -1;
/**
* Create a multi-start optimizer from a single-start optimizer.
*
* @param optimizer Single-start optimizer to wrap.
* @param starts Number of starts to perform. If {@code starts == 1},
* the {@link #optimize(OptimizationData[]) optimize} method will return
* the same solution as the given {@code optimizer} would return.
* @param generator Random vector generator to use for restarts.
* @throws NullArgumentException if {@code optimizer} or {@code generator}
* is {@code null}.
* @throws NotStrictlyPositiveException if {@code starts < 1}.
*/
public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer,
final int starts,
final RandomVectorGenerator generator) {
super(optimizer.getConvergenceChecker());
if (optimizer == null ||
generator == null) {
throw new NullArgumentException();
}
if (starts < 1) {
throw new NotStrictlyPositiveException(starts);
}
this.optimizer = optimizer;
this.starts = starts;
this.generator = generator;
}
/** {@inheritDoc} */
@Override
public int getEvaluations() {
return totalEvaluations;
}
/**
* Gets all the optima found during the last call to {@code optimize}.
* The optimizer stores all the optima found during a set of
* restarts. The {@code optimize} method returns the best point only.
* This method returns all the points found at the end of each start,
* including the best one already returned by the {@code optimize} method.
* <br/>
* The returned array has one element for each start as specified
* in the constructor. It is ordered with the results from the
* runs that did converge first, sorted from best to worst
* objective value (i.e. in ascending order if minimizing and in
* descending order if maximizing), followed by {@code null} elements
* corresponding to the runs that did not converge. This means that all
* elements will be {@code null} if the {@code optimize} method threw
* an exception.
* This also means that if the first element is not {@code null}, it is
* the best point found across all starts.
* <br/>
* The behaviour is undefined if this method is called before
* {@code optimize}; it will likely throw {@code NullPointerException}.
*
* @return an array containing the optima sorted from best to worst.
*/
public abstract PAIR[] getOptima();
/**
* {@inheritDoc}
*
* @throws MathIllegalStateException if {@code optData} does not contain an
* instance of {@link MaxEval} or {@link InitialGuess}.
*/
@Override
public PAIR optimize(OptimizationData... optData) {
// Store arguments in order to pass them to the internal optimizer.
optimData = optData;
// Set up base class and perform computations.
return super.optimize(optData);
}
/** {@inheritDoc} */
@Override
protected PAIR doOptimize() {
// Remove all instances of "MaxEval" and "InitialGuess" from the
// array that will be passed to the internal optimizer.
// The former is to enforce smaller numbers of allowed evaluations
// (according to how many have been used up already), and the latter
// to impose a different start value for each start.
for (int i = 0; i < optimData.length; i++) {
if (optimData[i] instanceof MaxEval) {
optimData[i] = null;
maxEvalIndex = i;
}
if (optimData[i] instanceof InitialGuess) {
optimData[i] = null;
initialGuessIndex = i;
continue;
}
}
if (maxEvalIndex == -1) {
throw new MathIllegalStateException();
}
if (initialGuessIndex == -1) {
throw new MathIllegalStateException();
}
RuntimeException lastException = null;
totalEvaluations = 0;
clear();
final int maxEval = getMaxEvaluations();
final double[] min = getLowerBound();
final double[] max = getUpperBound();
final double[] startPoint = getStartPoint();
// Multi-start loop.
for (int i = 0; i < starts; i++) {
// CHECKSTYLE: stop IllegalCatch
try {
// Decrease number of allowed evaluations.
optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
// New start value.
final double[] s = (i == 0) ?
startPoint :
generator.nextVector(); // XXX This does not enforce bounds!
optimData[initialGuessIndex] = new InitialGuess(s);
// Optimize.
final PAIR result = optimizer.optimize(optimData);
store(result);
} catch (RuntimeException mue) {
lastException = mue;
}
// CHECKSTYLE: resume IllegalCatch
totalEvaluations += optimizer.getEvaluations();
}
final PAIR[] optima = getOptima();
if (optima.length == 0) {
// All runs failed.
throw lastException; // Cannot be null if starts >= 1.
}
// Return the best optimum.
return optima[0];
}
/**
* Method that will be called in order to store each found optimum.
*
* @param optimum Result of an optimization run.
*/
protected abstract void store(PAIR optimum);
/**
* Method that will be called in order to clear all stored optima.
*/
protected abstract void clear();
}
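
A hedged usage sketch of the multi-start mechanism described above, wrapping a single-start simplex optimizer. MultiStartMultivariateOptimizer, SimplexOptimizer, NelderMeadSimplex, ObjectiveFunction and GoalType are assumed names from the new scalar optimization packages (not shown in this excerpt); the random-generator classes are from o.a.c.m.random.

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;                        // assumed
import org.apache.commons.math3.optim.nonlinear.scalar.MultiStartMultivariateOptimizer; // assumed
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;               // assumed
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;       // assumed
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;        // assumed
import org.apache.commons.math3.random.GaussianRandomGenerator;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomVectorGenerator;
import org.apache.commons.math3.random.UncorrelatedRandomVectorGenerator;

public class MultiStartSketch {
    public static void main(String[] args) {
        MultivariateFunction rosenbrock = new MultivariateFunction() {
            public double value(double[] p) {
                final double a = p[1] - p[0] * p[0];
                final double b = 1 - p[0];
                return 100 * a * a + b * b;
            }
        };
        // Random restart points (note: as remarked above, these are not bound-constrained).
        RandomVectorGenerator rng =
            new UncorrelatedRandomVectorGenerator(2,
                new GaussianRandomGenerator(new JDKRandomGenerator()));
        MultiStartMultivariateOptimizer multi =
            new MultiStartMultivariateOptimizer(new SimplexOptimizer(1e-10, 1e-12), 10, rng);
        // MaxEval and InitialGuess are mandatory: doOptimize() rewrites both
        // (remaining budget, new start value) before each restart.
        PointValuePair best =
            multi.optimize(new MaxEval(20000),
                           new InitialGuess(new double[] { -1, 1 }),
                           new ObjectiveFunction(rosenbrock),
                           GoalType.MINIMIZE,
                           new NelderMeadSimplex(2));
        PointValuePair[] allOptima = multi.getOptima(); // best first, nulls for failed runs
        System.out.println(allOptima.length + " starts, best: "
                           + java.util.Arrays.toString(best.getPoint()));
    }
}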

View File

@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
/**
* Base class for implementing optimizers for multivariate functions.
* It contains the boiler-plate code for initial guess and bounds
* specifications.
* <em>It is not a "user" class.</em>
*
* @param <PAIR> Type of the point/value pair returned by the optimization
* algorithm.
*
* @version $Id$
* @since 3.1
*/
public abstract class BaseMultivariateOptimizer<PAIR>
extends BaseOptimizer<PAIR> {
/** Initial guess. */
private double[] start;
/** Lower bounds. */
private double[] lowerBound;
/** Upper bounds. */
private double[] upperBound;
/**
* @param checker Convergence checker.
*/
protected BaseMultivariateOptimizer(ConvergenceChecker<PAIR> checker) {
super(checker);
}
/**
* {@inheritDoc}
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link MaxEval}</li>
* <li>{@link InitialGuess}</li>
* <li>{@link SimpleBounds}</li>
* </ul>
* @return {@inheritDoc}
*/
@Override
public PAIR optimize(OptimizationData... optData) {
// Retrieve settings.
parseOptimizationData(optData);
// Check input consistency.
checkParameters();
// Perform optimization.
return super.optimize(optData);
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link InitialGuess}</li>
* <li>{@link SimpleBounds}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof InitialGuess) {
start = ((InitialGuess) data).getInitialGuess();
continue;
}
if (data instanceof SimpleBounds) {
final SimpleBounds bounds = (SimpleBounds) data;
lowerBound = bounds.getLower();
upperBound = bounds.getUpper();
continue;
}
}
}
/**
* Gets the initial guess.
*
* @return the initial guess, or {@code null} if not set.
*/
public double[] getStartPoint() {
return start == null ? null : start.clone();
}
/**
* @return the lower bounds, or {@code null} if not set.
*/
public double[] getLowerBound() {
return lowerBound == null ? null : lowerBound.clone();
}
/**
* @return the upper bounds, or {@code null} if not set.
*/
public double[] getUpperBound() {
return upperBound == null ? null : upperBound.clone();
}
/**
* Check parameters consistency.
*/
private void checkParameters() {
if (start != null) {
final int dim = start.length;
if (lowerBound != null) {
if (lowerBound.length != dim) {
throw new DimensionMismatchException(lowerBound.length, dim);
}
for (int i = 0; i < dim; i++) {
final double v = start[i];
final double lo = lowerBound[i];
if (v < lo) {
throw new NumberIsTooSmallException(v, lo, true);
}
}
}
if (upperBound != null) {
if (upperBound.length != dim) {
throw new DimensionMismatchException(upperBound.length, dim);
}
for (int i = 0; i < dim; i++) {
final double v = start[i];
final double hi = upperBound[i];
if (v > hi) {
throw new NumberIsTooLargeException(v, hi, true);
}
}
}
}
}
}
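
A hedged sketch of how InitialGuess and SimpleBounds reach this base class: checkParameters() rejects a start point lying outside the bounds. BOBYQAOptimizer, ObjectiveFunction and GoalType are assumed names from the new scalar optimization packages (not shown in this excerpt).

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;                  // assumed
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;         // assumed
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;   // assumed

public class BoundsSketch {
    public static void main(String[] args) {
        MultivariateFunction sphere = new MultivariateFunction() {
            public double value(double[] p) {
                return p[0] * p[0] + p[1] * p[1];
            }
        };
        // The start point must satisfy lower[i] <= start[i] <= upper[i],
        // otherwise checkParameters() throws.
        PointValuePair best = new BOBYQAOptimizer(5)
            .optimize(new MaxEval(500),
                      new ObjectiveFunction(sphere),
                      GoalType.MINIMIZE,
                      new InitialGuess(new double[] { 1, 1 }),
                      new SimpleBounds(new double[] { -10, -10 },
                                       new double[] { 10, 10 }));
        System.out.println(java.util.Arrays.toString(best.getPoint()));
    }
}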

View File

@ -0,0 +1,218 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.util.Incrementor;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.TooManyIterationsException;
/**
* Base class for implementing optimizers.
* It contains the boiler-plate code for counting the number of evaluations
* of the objective function and the number of iterations of the algorithm,
* and storing the convergence checker.
* <em>It is not a "user" class.</em>
*
* @param <PAIR> Type of the point/value pair returned by the optimization
* algorithm.
*
* @version $Id$
* @since 3.1
*/
public abstract class BaseOptimizer<PAIR> {
/** Evaluations counter. */
protected final Incrementor evaluations;
/** Iterations counter. */
protected final Incrementor iterations;
/** Convergence checker. */
private ConvergenceChecker<PAIR> checker;
/**
* @param checker Convergence checker.
*/
protected BaseOptimizer(ConvergenceChecker<PAIR> checker) {
this.checker = checker;
evaluations = new Incrementor(0, new MaxEvalCallback());
iterations = new Incrementor(0, new MaxIterCallback());
}
/**
* Gets the maximal number of function evaluations.
*
* @return the maximal number of function evaluations.
*/
public int getMaxEvaluations() {
return evaluations.getMaximalCount();
}
/**
* Gets the number of evaluations of the objective function.
* The number of evaluations corresponds to the last call to the
* {@code optimize} method. It is 0 if the method has not been
* called yet.
*
* @return the number of evaluations of the objective function.
*/
public int getEvaluations() {
return evaluations.getCount();
}
/**
* Gets the maximal number of iterations.
*
* @return the maximal number of iterations.
*/
public int getMaxIterations() {
return iterations.getMaximalCount();
}
/**
* Gets the number of iterations performed by the algorithm.
* The number of iterations corresponds to the last call to the
* {@code optimize} method. It is 0 if the method has not been
* called yet.
*
* @return the number of iterations performed by the algorithm.
*/
public int getIterations() {
return iterations.getCount();
}
/**
* Gets the convergence checker.
*
* @return the object used to check for convergence.
*/
public ConvergenceChecker<PAIR> getConvergenceChecker() {
return checker;
}
/**
* Stores data and performs the optimization.
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link MaxEval}</li>
* <li>{@link MaxIter}</li>
* </ul>
* @return a point/value pair that satisfies the convergence criteria.
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
* @throws TooManyIterationsException if the maximal number of
* iterations is exceeded.
*/
public PAIR optimize(OptimizationData... optData)
throws TooManyEvaluationsException,
TooManyIterationsException {
// Retrieve settings.
parseOptimizationData(optData);
// Reset counters.
evaluations.resetCount();
iterations.resetCount();
// Perform optimization.
return doOptimize();
}
/**
* Performs the bulk of the optimization algorithm.
*
* @return the point/value pair giving the optimal value of the
* objective function.
*/
protected abstract PAIR doOptimize();
/**
* Increment the evaluation count.
*
* @throws TooManyEvaluationsException if the allowed evaluations
* have been exhausted.
*/
protected void incrementEvaluationCount()
throws TooManyEvaluationsException {
evaluations.incrementCount();
}
/**
* Increment the iteration count.
*
* @throws TooManyIterationsException if the allowed iterations
* have been exhausted.
*/
protected void incrementIterationCount()
throws TooManyIterationsException {
iterations.incrementCount();
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link MaxEval}</li>
* <li>{@link MaxIter}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof MaxEval) {
evaluations.setMaximalCount(((MaxEval) data).getMaxEval());
continue;
}
if (data instanceof MaxIter) {
iterations.setMaximalCount(((MaxIter) data).getMaxIter());
continue;
}
}
}
/**
* Defines the action to perform when reaching the maximum number
* of evaluations.
*/
private static class MaxEvalCallback
implements Incrementor.MaxCountExceededCallback {
/**
* {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of evaluations has been exceeded.
*/
public void trigger(int max) {
throw new TooManyEvaluationsException(max);
}
}
/**
* Defines the action to perform when reaching the maximum number
* of iterations.
*/
private static class MaxIterCallback
implements Incrementor.MaxCountExceededCallback {
/**
* {@inheritDoc}
* @throws TooManyIterationsException if the maximal number of iterations has been exceeded.
*/
public void trigger(int max) {
throw new TooManyIterationsException(max);
}
}
}
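A minimal usage sketch of the single optimize(OptimizationData...) entry point. The ConstantOptimizer subclass below is purely illustrative (it is not part of this commit); it only shows how MaxEval and MaxIter drive the shared counters.

import org.apache.commons.math3.optim.BaseOptimizer;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.PointValuePair;

public class BaseOptimizerSketch {
    /** Toy subclass: "optimizes" by returning a fixed point after one iteration. */
    private static class ConstantOptimizer extends BaseOptimizer<PointValuePair> {
        ConstantOptimizer() {
            super(null); // No convergence checker is needed for this sketch.
        }
        @Override
        protected PointValuePair doOptimize() {
            incrementIterationCount();  // Throws TooManyIterationsException past MaxIter.
            incrementEvaluationCount(); // Throws TooManyEvaluationsException past MaxEval.
            return new PointValuePair(new double[] { 0 }, 0);
        }
    }

    public static void main(String[] args) {
        ConstantOptimizer opt = new ConstantOptimizer();
        // MaxEval and MaxIter are picked up by parseOptimizationData; both counters
        // are reset at the beginning of every call to optimize.
        PointValuePair result = opt.optimize(new MaxEval(100), new MaxIter(50));
        System.out.println(opt.getEvaluations() + " evaluation(s), "
                           + opt.getIterations() + " iteration(s), value = "
                           + result.getValue());
    }
}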

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
/**
* This interface specifies how to check if an optimization algorithm has
* converged.
* <br/>
* Deciding if convergence has been reached is a problem-dependent issue. The
* user should provide a class implementing this interface to allow the
* optimization algorithm to stop its search according to the problem at hand.
* <br/>
* For convenience, three implementations that fit simple needs are already
* provided: {@link SimpleValueChecker}, {@link SimpleVectorValueChecker} and
* {@link SimplePointChecker}. The first two consider that convergence is
* reached when the objective function value does not change much anymore;
* they do not use the point set at all.
* The third one considers that convergence is reached when the input point
* set does not change much anymore; it does not use the objective function
* value at all.
*
* @param <PAIR> Type of the (point, objective value) pair.
*
* @see org.apache.commons.math3.optim.SimplePointChecker
* @see org.apache.commons.math3.optim.SimpleValueChecker
* @see org.apache.commons.math3.optim.SimpleVectorValueChecker
*
* @version $Id: ConvergenceChecker.java 1364392 2012-07-22 18:27:12Z tn $
* @since 3.0
*/
public interface ConvergenceChecker<PAIR> {
/**
* Check if the optimization algorithm has converged.
*
* @param iteration Current iteration.
* @param previous Best point in the previous iteration.
* @param current Best point in the current iteration.
* @return {@code true} if the algorithm is considered to have converged.
*/
boolean converged(int iteration, PAIR previous, PAIR current);
}
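A hedged sketch of a problem-specific checker (the class name and criterion are illustrative, not part of this commit): convergence is declared once the objective value stops improving by more than a caller-chosen absolute amount.

import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.util.FastMath;

public class AbsoluteValueChecker implements ConvergenceChecker<PointValuePair> {
    /** Absolute tolerance on the change of the objective value. */
    private final double tolerance;

    public AbsoluteValueChecker(double tolerance) {
        this.tolerance = tolerance;
    }

    /** {@inheritDoc} */
    public boolean converged(int iteration, PointValuePair previous, PointValuePair current) {
        // Purely value-based criterion; the point coordinates are ignored.
        return FastMath.abs(previous.getValue() - current.getValue()) <= tolerance;
    }
}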

View File

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
/**
* Goal type for an optimization problem (minimization or maximization of
* a scalar function).
*
* @version $Id: GoalType.java 1364392 2012-07-22 18:27:12Z tn $
* @since 2.0
*/
public enum GoalType implements OptimizationData {
/** Maximization. */
MAXIMIZE,
/** Minimization. */
MINIMIZE
}

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
/**
* Starting point (first guess) of the optimization procedure.
* <br/>
* Immutable class.
*
* @version $Id: InitialGuess.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.1
*/
public class InitialGuess implements OptimizationData {
/** Initial guess. */
private final double[] init;
/**
* @param startPoint Initial guess.
*/
public InitialGuess(double[] startPoint) {
init = startPoint.clone();
}
/**
* Gets the initial guess.
*
* @return the initial guess.
*/
public double[] getInitialGuess() {
return init.clone();
}
}

View File

@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/**
* Maximum number of evaluations of the function to be optimized.
*
* @version $Id$
* @since 3.1
*/
public class MaxEval implements OptimizationData {
/** Allowed number of evaluations. */
private final int maxEval;
/**
* @param max Allowed number of evaluations.
* @throws NotStrictlyPositiveException if {@code max <= 0}.
*/
public MaxEval(int max) {
if (max <= 0) {
throw new NotStrictlyPositiveException(max);
}
maxEval = max;
}
/**
* Gets the maximum number of evaluations.
*
* @return the allowed number of evaluations.
*/
public int getMaxEval() {
return maxEval;
}
/**
* Factory method that creates an instance of this class representing
* a virtually unlimited number of evaluations.
*
* @return a new instance suitable for allowing {@link Integer#MAX_VALUE}
* evaluations.
*/
public static MaxEval unlimited() {
return new MaxEval(Integer.MAX_VALUE);
}
}

View File

@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/**
* Maximum number of iterations performed by an (iterative) algorithm.
*
* @version $Id$
* @since 3.1
*/
public class MaxIter implements OptimizationData {
/** Allowed number of iterations. */
private final int maxIter;
/**
* @param max Allowed number of iterations.
* @throws NotStrictlyPositiveException if {@code max <= 0}.
*/
public MaxIter(int max) {
if (max <= 0) {
throw new NotStrictlyPositiveException(max);
}
maxIter = max;
}
/**
* Gets the maximum number of iterations.
*
* @return the allowed number of iterations.
*/
public int getMaxIter() {
return maxIter;
}
/**
* Factory method that creates an instance of this class representing
* a virtually unlimited number of iterations.
*
* @return a new instance suitable for allowing {@link Integer#MAX_VALUE}
* iterations.
*/
public static MaxIter unlimited() {
return new MaxIter(Integer.MAX_VALUE);
}
}

View File

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.analysis.MultivariateFunction;
/**
* Scalar function to be optimized.
*
* @version $Id$
* @since 3.1
*/
public class ObjectiveFunction implements OptimizationData {
/** Function to be optimized. */
private final MultivariateFunction function;
/**
* @param f Function to be optimized.
*/
public ObjectiveFunction(MultivariateFunction f) {
function = f;
}
/**
* Gets the function to be optimized.
*
* @return the objective function.
*/
public MultivariateFunction getObjectiveFunction() {
return function;
}
}
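An illustrative wrapping of a simple quadratic (the function itself is made up for this sketch) so that it can be passed to an optimizer through the OptimizationData varargs.

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.ObjectiveFunction;

public class ObjectiveFunctionSketch {
    public static void main(String[] args) {
        ObjectiveFunction f = new ObjectiveFunction(new MultivariateFunction() {
            /** f(x, y) = (x - 1)^2 + (y + 2)^2 */
            public double value(double[] p) {
                final double dx = p[0] - 1;
                final double dy = p[1] + 2;
                return dx * dx + dy * dy;
            }
        });
        // The wrapped function is recovered by an optimizer via getObjectiveFunction().
        System.out.println(f.getObjectiveFunction().value(new double[] { 1, -2 })); // 0.0
    }
}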

View File

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
/**
* Marker interface.
* Implementations will provide functionality (optional or required) needed
* by the optimizers; the optimizers will need to check the actual type of
* the arguments and perform the appropriate cast in order to access the
* data they need.
*
* @version $Id: OptimizationData.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.1
*/
public interface OptimizationData {}
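A sketch of how client code could define its own option on top of the marker interface (the Verbosity class is hypothetical and not part of this commit); an optimizer aware of it would detect it with instanceof and a cast, exactly as parseOptimizationData does for MaxEval and MaxIter.

import org.apache.commons.math3.optim.OptimizationData;

/** Hypothetical option carried through the OptimizationData varargs. */
public class Verbosity implements OptimizationData {
    /** Requested verbosity level. */
    private final int level;

    public Verbosity(int level) {
        this.level = level;
    }

    public int getLevel() {
        return level;
    }
    // An optimizer that understands this option would scan its arguments like:
    //   for (OptimizationData data : optData) {
    //       if (data instanceof Verbosity) {
    //           level = ((Verbosity) data).getLevel();
    //       }
    //   }
}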

View File

@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import java.io.Serializable;
import org.apache.commons.math3.util.Pair;
/**
* This class holds a point and the value of an objective function at
* that point.
*
* @see PointVectorValuePair
* @see org.apache.commons.math3.analysis.MultivariateFunction
* @version $Id: PointValuePair.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.0
*/
public class PointValuePair extends Pair<double[], Double> implements Serializable {
/** Serializable UID. */
private static final long serialVersionUID = 20120513L;
/**
* Builds a point/objective function value pair.
*
* @param point Point coordinates. This instance will store
* a copy of the array, not the array passed as argument.
* @param value Value of the objective function at the point.
*/
public PointValuePair(final double[] point,
final double value) {
this(point, value, true);
}
/**
* Builds a point/objective function value pair.
*
* @param point Point coordinates.
* @param value Value of the objective function at the point.
* @param copyArray if {@code true}, the input array will be copied,
* otherwise it will be referenced.
*/
public PointValuePair(final double[] point,
final double value,
final boolean copyArray) {
super(copyArray ? ((point == null) ? null :
point.clone()) :
point,
value);
}
/**
* Gets the point.
*
* @return a copy of the stored point.
*/
public double[] getPoint() {
final double[] p = getKey();
return p == null ? null : p.clone();
}
/**
* Gets a reference to the point.
*
* @return a reference to the internal array storing the point.
*/
public double[] getPointRef() {
return getKey();
}
/**
* Replace the instance with a data transfer object for serialization.
* @return data transfer object that will be serialized
*/
private Object writeReplace() {
return new DataTransferObject(getKey(), getValue());
}
/** Internal class used only for serialization. */
private static class DataTransferObject implements Serializable {
/** Serializable UID. */
private static final long serialVersionUID = 20120513L;
/**
* Point coordinates.
* @Serial
*/
private final double[] point;
/**
* Value of the objective function at the point.
* @Serial
*/
private final double value;
/** Simple constructor.
* @param point Point coordinates.
* @param value Value of the objective function at the point.
*/
public DataTransferObject(final double[] point, final double value) {
this.point = point.clone();
this.value = value;
}
/** Replace the deserialized data transfer object with a {@link PointValuePair}.
* @return replacement {@link PointValuePair}
*/
private Object readResolve() {
return new PointValuePair(point, value, false);
}
}
}
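A small sketch of the copy semantics documented above: the constructor and getPoint() work on defensive copies, while getPointRef() exposes the stored array.

import java.util.Arrays;
import org.apache.commons.math3.optim.PointValuePair;

public class PointValuePairSketch {
    public static void main(String[] args) {
        final double[] p = { 1, 2, 3 };
        final PointValuePair pair = new PointValuePair(p, 6);
        p[0] = 99;                  // No effect: the constructor stored a copy.
        pair.getPoint()[0] = 99;    // No effect: getPoint() returns a copy.
        pair.getPointRef()[0] = 99; // Modifies the internally stored array.
        System.out.println(Arrays.toString(pair.getPoint())); // [99.0, 2.0, 3.0]
    }
}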

View File

@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import java.io.Serializable;
import org.apache.commons.math3.util.Pair;
/**
* This class holds a point and the vectorial value of an objective function at
* that point.
*
* @see PointValuePair
* @see org.apache.commons.math3.analysis.MultivariateVectorFunction
* @version $Id: PointVectorValuePair.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.0
*/
public class PointVectorValuePair extends Pair<double[], double[]> implements Serializable {
/** Serializable UID. */
private static final long serialVersionUID = 20120513L;
/**
* Builds a point/objective function value pair.
*
* @param point Point coordinates. This instance will store
* a copy of the array, not the array passed as argument.
* @param value Value of the objective function at the point.
*/
public PointVectorValuePair(final double[] point,
final double[] value) {
this(point, value, true);
}
/**
* Build a point/objective function value pair.
*
* @param point Point coordinates.
* @param value Value of the objective function at the point.
* @param copyArray if {@code true}, the input arrays will be copied,
* otherwise they will be referenced.
*/
public PointVectorValuePair(final double[] point,
final double[] value,
final boolean copyArray) {
super(copyArray ?
((point == null) ? null :
point.clone()) :
point,
copyArray ?
((value == null) ? null :
value.clone()) :
value);
}
/**
* Gets the point.
*
* @return a copy of the stored point.
*/
public double[] getPoint() {
final double[] p = getKey();
return p == null ? null : p.clone();
}
/**
* Gets a reference to the point.
*
* @return a reference to the internal array storing the point.
*/
public double[] getPointRef() {
return getKey();
}
/**
* Gets the value of the objective function.
*
* @return a copy of the stored value of the objective function.
*/
@Override
public double[] getValue() {
final double[] v = super.getValue();
return v == null ? null : v.clone();
}
/**
* Gets a reference to the value of the objective function.
*
* @return a reference to the internal array storing the value of
* the objective function.
*/
public double[] getValueRef() {
return super.getValue();
}
/**
* Replace the instance with a data transfer object for serialization.
* @return data transfer object that will be serialized
*/
private Object writeReplace() {
return new DataTransferObject(getKey(), getValue());
}
/** Internal class used only for serialization. */
private static class DataTransferObject implements Serializable {
/** Serializable UID. */
private static final long serialVersionUID = 20120513L;
/**
* Point coordinates.
* @Serial
*/
private final double[] point;
/**
* Value of the objective function at the point.
* @Serial
*/
private final double[] value;
/** Simple constructor.
* @param point Point coordinates.
* @param value Value of the objective function at the point.
*/
public DataTransferObject(final double[] point, final double[] value) {
this.point = point.clone();
this.value = value.clone();
}
/** Replace the deserialized data transfer object with a {@link PointVectorValuePair}.
* @return replacement {@link PointVectorValuePair}
*/
private Object readResolve() {
return new PointVectorValuePair(point, value, false);
}
}
}

View File

@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import java.util.Arrays;
/**
* Simple optimization constraints: lower and upper bounds.
* The valid range of the parameters is an interval that can be infinite
* (in one or both directions).
* <br/>
* Immutable class.
*
* @version $Id: SimpleBounds.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.1
*/
public class SimpleBounds implements OptimizationData {
/** Lower bounds. */
private final double[] lower;
/** Upper bounds. */
private final double[] upper;
/**
* @param lB Lower bounds.
* @param uB Upper bounds.
*/
public SimpleBounds(double[] lB,
double[] uB) {
lower = lB.clone();
upper = uB.clone();
}
/**
* Gets the lower bounds.
*
* @return the lower bounds.
*/
public double[] getLower() {
return lower.clone();
}
/**
* Gets the upper bounds.
*
* @return the upper bounds.
*/
public double[] getUpper() {
return upper.clone();
}
/**
* Factory method that creates an instance of this class representing
* unbounded ranges.
*
* @param dim Number of parameters.
* @return a new instance suitable for passing to an optimizer that
* requires bounds specification.
*/
public static SimpleBounds unbounded(int dim) {
final double[] lB = new double[dim];
Arrays.fill(lB, Double.NEGATIVE_INFINITY);
final double[] uB = new double[dim];
Arrays.fill(uB, Double.POSITIVE_INFINITY);
return new SimpleBounds(lB, uB);
}
}
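An illustrative construction of bound constraints: finite bounds on the first variable, the second one left free, plus the fully unbounded factory for a 2-parameter problem.

import org.apache.commons.math3.optim.SimpleBounds;

public class SimpleBoundsSketch {
    public static void main(String[] args) {
        final SimpleBounds bounds = new SimpleBounds(
                new double[] { 0.0, Double.NEGATIVE_INFINITY },   // lower bounds
                new double[] { 10.0, Double.POSITIVE_INFINITY }); // upper bounds
        // Fully unbounded ranges for a 2-parameter problem.
        final SimpleBounds free = SimpleBounds.unbounded(2);
        System.out.println(free.getLower()[0]);   // -Infinity
        System.out.println(bounds.getUpper()[0]); // 10.0
    }
}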

View File

@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/**
* Simple implementation of the {@link ConvergenceChecker} interface using
* only point coordinates.
*
* Convergence is considered to have been reached if either the relative
* difference between each point coordinate is smaller than a threshold
* or the absolute difference between the point coordinates is smaller
* than another threshold.
* <br/>
* The {@link #converged(int,Pair,Pair) converged} method will also return
* {@code true} once the iteration count reaches the maximal value set via
* {@link #SimplePointChecker(double,double,int) this constructor}.
*
* @param <PAIR> Type of the (point, value) pair.
* The type of the "value" part of the pair (not used by this class).
*
* @version $Id: SimplePointChecker.java 1413127 2012-11-24 04:37:30Z psteitz $
* @since 3.0
*/
public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
extends AbstractConvergenceChecker<PAIR> {
/**
* If {@link #maxIterationCount} is set to this value, the number of
* iterations will never cause {@link #converged(int, Pair, Pair)}
* to return {@code true}.
*/
private static final int ITERATION_CHECK_DISABLED = -1;
/**
* Number of iterations after which the
* {@link #converged(int, Pair, Pair)} method
* will return true (unless the check is disabled).
*/
private final int maxIterationCount;
/**
* Build an instance with specified thresholds.
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
*/
public SimplePointChecker(final double relativeThreshold,
final double absoluteThreshold) {
super(relativeThreshold, absoluteThreshold);
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/**
* Builds an instance with specified thresholds.
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold Relative tolerance threshold.
* @param absoluteThreshold Absolute tolerance threshold.
* @param maxIter Maximum iteration count.
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
*
* @since 3.1
*/
public SimplePointChecker(final double relativeThreshold,
final double absoluteThreshold,
final int maxIter) {
super(relativeThreshold, absoluteThreshold);
if (maxIter <= 0) {
throw new NotStrictlyPositiveException(maxIter);
}
maxIterationCount = maxIter;
}
/**
* Check if the optimization algorithm has converged considering the
* last two points.
* This method may be called several times from the same algorithm
* iteration with different points. This can be detected by checking the
* iteration number at each call if needed. Each time this method is
* called, the previous and current point correspond to points with the
* same role at each iteration, so they can be compared. As an example,
* simplex-based algorithms call this method for all points of the simplex,
* not only for the best or worst ones.
*
* @param iteration Index of current iteration
* @param previous Best point in the previous iteration.
* @param current Best point in the current iteration.
* @return {@code true} if the arguments satisfy the convergence criterion.
*/
@Override
public boolean converged(final int iteration,
final PAIR previous,
final PAIR current) {
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
if (iteration >= maxIterationCount) {
return true;
}
}
final double[] p = previous.getKey();
final double[] c = current.getKey();
for (int i = 0; i < p.length; ++i) {
final double pi = p[i];
final double ci = c[i];
final double difference = FastMath.abs(pi - ci);
final double size = FastMath.max(FastMath.abs(pi), FastMath.abs(ci));
if (difference > size * getRelativeThreshold() &&
difference > getAbsoluteThreshold()) {
return false;
}
}
return true;
}
}
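A quick, illustrative check of the criterion documented above (either the relative or the absolute tolerance must be met for every coordinate before convergence is declared); the thresholds and points are arbitrary.

import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimplePointChecker;

public class SimplePointCheckerSketch {
    public static void main(String[] args) {
        final SimplePointChecker<PointValuePair> checker =
                new SimplePointChecker<PointValuePair>(1e-6, 1e-10);
        final PointValuePair previous = new PointValuePair(new double[] { 1.0, 2.0 }, 0);
        final PointValuePair close    = new PointValuePair(new double[] { 1.0 + 1e-9, 2.0 }, 0);
        final PointValuePair far      = new PointValuePair(new double[] { 1.5, 2.0 }, 0);
        // Only the point coordinates are inspected, never the objective values.
        System.out.println(checker.converged(1, previous, close)); // true
        System.out.println(checker.converged(1, previous, far));   // false
    }
}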

View File

@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/**
* Simple implementation of the {@link ConvergenceChecker} interface using
* only objective function values.
*
* Convergence is considered to have been reached if either the relative
* difference between the objective function values is smaller than a
* threshold or the absolute difference between the objective function
* values is smaller than another threshold.
* <br/>
* The {@link #converged(int,PointValuePair,PointValuePair) converged}
* method will also return {@code true} once the iteration count reaches the
* maximal value set via {@link #SimpleValueChecker(double,double,int) this constructor}.
*
* @version $Id: SimpleValueChecker.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.0
*/
public class SimpleValueChecker
extends AbstractConvergenceChecker<PointValuePair> {
/**
* If {@link #maxIterationCount} is set to this value, the number of
* iterations will never cause
* {@link #converged(int,PointValuePair,PointValuePair)}
* to return {@code true}.
*/
private static final int ITERATION_CHECK_DISABLED = -1;
/**
* Number of iterations after which the
* {@link #converged(int,PointValuePair,PointValuePair)} method
* will return true (unless the check is disabled).
*/
private final int maxIterationCount;
/** Build an instance with specified thresholds.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
*/
public SimpleValueChecker(final double relativeThreshold,
final double absoluteThreshold) {
super(relativeThreshold, absoluteThreshold);
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/**
* Builds an instance with specified thresholds.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
* @param maxIter Maximum iteration count.
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
*
* @since 3.1
*/
public SimpleValueChecker(final double relativeThreshold,
final double absoluteThreshold,
final int maxIter) {
super(relativeThreshold, absoluteThreshold);
if (maxIter <= 0) {
throw new NotStrictlyPositiveException(maxIter);
}
maxIterationCount = maxIter;
}
/**
* Check if the optimization algorithm has converged considering the
* last two points.
* This method may be called several times from the same algorithm
* iteration with different points. This can be detected by checking the
* iteration number at each call if needed. Each time this method is
* called, the previous and current point correspond to points with the
* same role at each iteration, so they can be compared. As an example,
* simplex-based algorithms call this method for all points of the simplex,
* not only for the best or worst ones.
*
* @param iteration Index of current iteration
* @param previous Best point in the previous iteration.
* @param current Best point in the current iteration.
* @return {@code true} if the algorithm has converged.
*/
@Override
public boolean converged(final int iteration,
final PointValuePair previous,
final PointValuePair current) {
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
if (iteration >= maxIterationCount) {
return true;
}
}
final double p = previous.getValue();
final double c = current.getValue();
final double difference = FastMath.abs(p - c);
final double size = FastMath.max(FastMath.abs(p), FastMath.abs(c));
return difference <= size * getRelativeThreshold() ||
difference <= getAbsoluteThreshold();
}
}

View File

@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/**
* Simple implementation of the {@link ConvergenceChecker} interface using
* only objective function values.
*
* Convergence is considered to have been reached if, for all vector elements,
* either the relative difference between the objective function values is
* smaller than a threshold or the absolute difference between them is
* smaller than another threshold.
* <br/>
* The {@link #converged(int,PointVectorValuePair,PointVectorValuePair) converged}
* method will also return {@code true} once the iteration count reaches the
* maximal value set via {@link #SimpleVectorValueChecker(double,double,int) this constructor}.
*
* @version $Id: SimpleVectorValueChecker.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.0
*/
public class SimpleVectorValueChecker
extends AbstractConvergenceChecker<PointVectorValuePair> {
/**
* If {@link #maxIterationCount} is set to this value, the number of
* iterations will never cause
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)}
* to return {@code true}.
*/
private static final int ITERATION_CHECK_DISABLED = -1;
/**
* Number of iterations after which the
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)} method
* will return true (unless the check is disabled).
*/
private final int maxIterationCount;
/**
* Build an instance with specified thresholds.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
*/
public SimpleVectorValueChecker(final double relativeThreshold,
final double absoluteThreshold) {
super(relativeThreshold, absoluteThreshold);
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/**
* Builds an instance with specified tolerance thresholds and
* iteration count.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold Relative tolerance threshold.
* @param absoluteThreshold Absolute tolerance threshold.
* @param maxIter Maximum iteration count.
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
*
* @since 3.1
*/
public SimpleVectorValueChecker(final double relativeThreshold,
final double absoluteThreshold,
final int maxIter) {
super(relativeThreshold, absoluteThreshold);
if (maxIter <= 0) {
throw new NotStrictlyPositiveException(maxIter);
}
maxIterationCount = maxIter;
}
/**
* Check if the optimization algorithm has converged considering the
* last two points.
* This method may be called several times from the same algorithm
* iteration with different points. This can be detected by checking the
* iteration number at each call if needed. Each time this method is
* called, the previous and current point correspond to points with the
* same role at each iteration, so they can be compared. As an example,
* simplex-based algorithms call this method for all points of the simplex,
* not only for the best or worst ones.
*
* @param iteration Index of current iteration
* @param previous Best point in the previous iteration.
* @param current Best point in the current iteration.
* @return {@code true} if the arguments satisfy the convergence criterion.
*/
@Override
public boolean converged(final int iteration,
final PointVectorValuePair previous,
final PointVectorValuePair current) {
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
if (iteration >= maxIterationCount) {
return true;
}
}
final double[] p = previous.getValueRef();
final double[] c = current.getValueRef();
for (int i = 0; i < p.length; ++i) {
final double pi = p[i];
final double ci = c[i];
final double difference = FastMath.abs(pi - ci);
final double size = FastMath.max(FastMath.abs(pi), FastMath.abs(ci));
if (difference > size * getRelativeThreshold() &&
difference > getAbsoluteThreshold()) {
return false;
}
}
return true;
}
}

View File

@ -0,0 +1,229 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.ArrayRealVector;
/**
* A linear constraint for a linear optimization problem.
* <p>
* A linear constraint has one of the forms:
* <ul>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> = v</li>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> &lt;= v</li>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> >= v</li>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> =
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> &lt;=
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> >=
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* </ul>
* The c<sub>i</sub>, l<sub>i</sub> or r<sub>i</sub> are the coefficients of the constraints, the x<sub>i</sub>
* are the coordinates of the current point and v is the value of the constraint.
* </p>
*
* @version $Id: LinearConstraint.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class LinearConstraint implements Serializable {
/** Serializable version identifier. */
private static final long serialVersionUID = -764632794033034092L;
/** Coefficients of the constraint (left hand side). */
private final transient RealVector coefficients;
/** Relationship between left and right hand sides (=, &lt;=, >=). */
private final Relationship relationship;
/** Value of the constraint (right hand side). */
private final double value;
/**
* Build a constraint involving a single linear equation.
* <p>
* A linear constraint with a single linear equation has one of the forms:
* <ul>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> = v</li>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> &lt;= v</li>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> >= v</li>
* </ul>
* </p>
* @param coefficients The coefficients of the constraint (left hand side)
* @param relationship The type of (in)equality used in the constraint
* @param value The value of the constraint (right hand side)
*/
public LinearConstraint(final double[] coefficients,
final Relationship relationship,
final double value) {
this(new ArrayRealVector(coefficients), relationship, value);
}
/**
* Build a constraint involving a single linear equation.
* <p>
* A linear constraint with a single linear equation has one of the forms:
* <ul>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> = v</li>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> &lt;= v</li>
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> >= v</li>
* </ul>
* </p>
* @param coefficients The coefficients of the constraint (left hand side)
* @param relationship The type of (in)equality used in the constraint
* @param value The value of the constraint (right hand side)
*/
public LinearConstraint(final RealVector coefficients,
final Relationship relationship,
final double value) {
this.coefficients = coefficients;
this.relationship = relationship;
this.value = value;
}
/**
* Build a constraint involving two linear equations.
* <p>
* A linear constraint with two linear equations has one of the forms:
* <ul>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> =
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> &lt;=
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> >=
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* </ul>
* </p>
* @param lhsCoefficients The coefficients of the linear expression on the left hand side of the constraint
* @param lhsConstant The constant term of the linear expression on the left hand side of the constraint
* @param relationship The type of (in)equality used in the constraint
* @param rhsCoefficients The coefficients of the linear expression on the right hand side of the constraint
* @param rhsConstant The constant term of the linear expression on the right hand side of the constraint
*/
public LinearConstraint(final double[] lhsCoefficients, final double lhsConstant,
final Relationship relationship,
final double[] rhsCoefficients, final double rhsConstant) {
double[] sub = new double[lhsCoefficients.length];
for (int i = 0; i < sub.length; ++i) {
sub[i] = lhsCoefficients[i] - rhsCoefficients[i];
}
this.coefficients = new ArrayRealVector(sub, false);
this.relationship = relationship;
this.value = rhsConstant - lhsConstant;
}
/**
* Build a constraint involving two linear equations.
* <p>
* A linear constraint with two linear equations has one of the forms:
* <ul>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> =
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> &lt;=
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> >=
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
* </ul>
* </p>
* @param lhsCoefficients The coefficients of the linear expression on the left hand side of the constraint
* @param lhsConstant The constant term of the linear expression on the left hand side of the constraint
* @param relationship The type of (in)equality used in the constraint
* @param rhsCoefficients The coefficients of the linear expression on the right hand side of the constraint
* @param rhsConstant The constant term of the linear expression on the right hand side of the constraint
*/
public LinearConstraint(final RealVector lhsCoefficients, final double lhsConstant,
final Relationship relationship,
final RealVector rhsCoefficients, final double rhsConstant) {
this.coefficients = lhsCoefficients.subtract(rhsCoefficients);
this.relationship = relationship;
this.value = rhsConstant - lhsConstant;
}
/**
* Gets the coefficients of the constraint (left hand side).
*
* @return the coefficients of the constraint (left hand side).
*/
public RealVector getCoefficients() {
return coefficients;
}
/**
* Gets the relationship between left and right hand sides.
*
* @return the relationship between left and right hand sides.
*/
public Relationship getRelationship() {
return relationship;
}
/**
* Gets the value of the constraint (right hand side).
*
* @return the value of the constraint (right hand side).
*/
public double getValue() {
return value;
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other instanceof LinearConstraint) {
LinearConstraint rhs = (LinearConstraint) other;
return relationship == rhs.relationship &&
value == rhs.value &&
coefficients.equals(rhs.coefficients);
}
return false;
}
@Override
public int hashCode() {
return relationship.hashCode() ^
Double.valueOf(value).hashCode() ^
coefficients.hashCode();
}
/**
* Serialize the instance.
* @param oos stream where object should be written
* @throws IOException if object cannot be written to stream
*/
private void writeObject(ObjectOutputStream oos)
throws IOException {
oos.defaultWriteObject();
MatrixUtils.serializeRealVector(coefficients, oos);
}
/**
* Deserialize the instance.
* @param ois stream from which the object should be read
* @throws ClassNotFoundException if a class in the stream cannot be found
* @throws IOException if object cannot be read from the stream
*/
private void readObject(ObjectInputStream ois)
throws ClassNotFoundException, IOException {
ois.defaultReadObject();
MatrixUtils.deserializeRealVector(this, "coefficients", ois);
}
}
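An illustrative construction of the documented constraint forms for a two-variable problem; as noted above, a two-sided constraint is rewritten internally by subtracting the right-hand coefficients from the left-hand ones and the left constant from the right one.

import org.apache.commons.math3.optim.linear.LinearConstraint;
import org.apache.commons.math3.optim.linear.Relationship;

public class LinearConstraintSketch {
    public static void main(String[] args) {
        // 2 x + 3 y <= 12
        final LinearConstraint c1 =
            new LinearConstraint(new double[] { 2, 3 }, Relationship.LEQ, 12);
        // x - y >= 0
        final LinearConstraint c2 =
            new LinearConstraint(new double[] { 1, -1 }, Relationship.GEQ, 0);
        // x + y + 1 = 2 x - y + 3, stored internally as -x + 2 y = 2
        final LinearConstraint c3 =
            new LinearConstraint(new double[] { 1, 1 }, 1,
                                 Relationship.EQ,
                                 new double[] { 2, -1 }, 3);
        System.out.println(c1.getRelationship() + " " + c1.getValue());
        System.out.println(c2.getCoefficients());
        System.out.println(c3.getCoefficients() + " = " + c3.getValue()); // coefficients {-1; 2}, value 2.0
    }
}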

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.util.Set;
import java.util.HashSet;
import java.util.Collection;
import java.util.Collections;
import org.apache.commons.math3.optim.OptimizationData;
/**
* Class that represents a set of {@link LinearConstraint linear constraints}.
*
* @version $Id$
* @since 3.1
*/
public class LinearConstraintSet implements OptimizationData {
/** Set of constraints. */
private final Set<LinearConstraint> linearConstraints
= new HashSet<LinearConstraint>();
/**
* Creates a set containing the given constraints.
*
* @param constraints Constraints.
*/
public LinearConstraintSet(LinearConstraint... constraints) {
for (LinearConstraint c : constraints) {
linearConstraints.add(c);
}
}
/**
* Creates a set containing the given constraints.
*
* @param constraints Constraints.
*/
public LinearConstraintSet(Collection<LinearConstraint> constraints) {
linearConstraints.addAll(constraints);
}
/**
* Gets the set of linear constraints.
*
* @return the constraints.
*/
public Collection<LinearConstraint> getConstraints() {
return Collections.unmodifiableSet(linearConstraints);
}
}

View File

@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.optim.OptimizationData;
/**
* An objective function for a linear optimization problem.
* <p>
* A linear objective function has the form:
* <pre>
* c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> + d
* </pre>
* The c<sub>i</sub> and d are the coefficients of the equation;
* the x<sub>i</sub> are the coordinates of the current point.
* </p>
*
* @version $Id: LinearObjectiveFunction.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class LinearObjectiveFunction
implements MultivariateFunction,
OptimizationData,
Serializable {
/** Serializable version identifier. */
private static final long serialVersionUID = -4531815507568396090L;
/** Coefficients of the linear equation (c<sub>i</sub>). */
private final transient RealVector coefficients;
/** Constant term of the linear equation. */
private final double constantTerm;
/**
* @param coefficients Coefficients for the linear equation being optimized.
* @param constantTerm Constant term of the linear equation.
*/
public LinearObjectiveFunction(double[] coefficients, double constantTerm) {
this(new ArrayRealVector(coefficients), constantTerm);
}
/**
* @param coefficients Coefficients for the linear equation being optimized.
* @param constantTerm Constant term of the linear equation.
*/
public LinearObjectiveFunction(RealVector coefficients, double constantTerm) {
this.coefficients = coefficients;
this.constantTerm = constantTerm;
}
/**
* Gets the coefficients of the linear equation being optimized.
*
* @return coefficients of the linear equation being optimized.
*/
public RealVector getCoefficients() {
return coefficients;
}
/**
* Gets the constant of the linear equation being optimized.
*
* @return constant of the linear equation being optimized.
*/
public double getConstantTerm() {
return constantTerm;
}
/**
* Computes the value of the linear equation at the current point.
*
* @param point Point at which linear equation must be evaluated.
* @return the value of the linear equation at the current point.
*/
public double value(final double[] point) {
return value(new ArrayRealVector(point, false));
}
/**
* Computes the value of the linear equation at the current point.
*
* @param point Point at which linear equation must be evaluated.
* @return the value of the linear equation at the current point.
*/
public double value(final RealVector point) {
return coefficients.dotProduct(point) + constantTerm;
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other instanceof LinearObjectiveFunction) {
LinearObjectiveFunction rhs = (LinearObjectiveFunction) other;
return (constantTerm == rhs.constantTerm) && coefficients.equals(rhs.coefficients);
}
return false;
}
@Override
public int hashCode() {
return Double.valueOf(constantTerm).hashCode() ^ coefficients.hashCode();
}
/**
* Serialize the instance.
* @param oos stream where object should be written
* @throws IOException if object cannot be written to stream
*/
private void writeObject(ObjectOutputStream oos)
throws IOException {
oos.defaultWriteObject();
MatrixUtils.serializeRealVector(coefficients, oos);
}
/**
* Deserialize the instance.
* @param ois stream from which the object should be read
* @throws ClassNotFoundException if a class in the stream cannot be found
* @throws IOException if object cannot be read from the stream
*/
private void readObject(ObjectInputStream ois)
throws ClassNotFoundException, IOException {
ois.defaultReadObject();
MatrixUtils.deserializeRealVector(this, "coefficients", ois);
}
}
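An illustrative evaluation of a small linear objective, f(x, y) = 2x + 3y + 5.

import org.apache.commons.math3.optim.linear.LinearObjectiveFunction;

public class LinearObjectiveFunctionSketch {
    public static void main(String[] args) {
        final LinearObjectiveFunction f =
                new LinearObjectiveFunction(new double[] { 2, 3 }, 5);
        // value = 2*1 + 3*2 + 5 = 13
        System.out.println(f.value(new double[] { 1, 2 }));
    }
}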

View File

@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.util.Collection;
import java.util.Collections;
import org.apache.commons.math3.exception.TooManyIterationsException;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
/**
* Base class for implementing linear optimizers.
*
* @version $Id: AbstractLinearOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.1
*/
public abstract class LinearOptimizer
extends MultivariateOptimizer {
/**
* Linear objective function.
*/
private LinearObjectiveFunction function;
/**
* Linear constraints.
*/
private Collection<LinearConstraint> linearConstraints;
/**
* Whether to restrict the variables to non-negative values.
*/
private boolean nonNegative;
/**
* Simple constructor with default settings.
*
*/
protected LinearOptimizer() {
super(null); // No convergence checker.
}
/**
* @return {@code true} if the variables are restricted to non-negative values.
*/
protected boolean isRestrictedToNonNegative() {
return nonNegative;
}
/**
* @return the linear objective function.
*/
protected LinearObjectiveFunction getFunction() {
return function;
}
/**
* @return the linear constraints.
*/
protected Collection<LinearConstraint> getConstraints() {
return Collections.unmodifiableCollection(linearConstraints);
}
/**
* {@inheritDoc}
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxIter}</li>
* <li>{@link LinearObjectiveFunction}</li>
* <li>{@link LinearConstraintSet}</li>
* <li>{@link NonNegativeConstraint}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyIterationsException if the maximal number of
* iterations is exceeded.
*/
@Override
public PointValuePair optimize(OptimizationData... optData)
throws TooManyIterationsException {
// Retrieve settings.
parseOptimizationData(optData);
// Set up base class and perform computation.
return super.optimize(optData);
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link LinearObjectiveFunction}</li>
* <li>{@link LinearConstraintSet}</li>
* <li>{@link NonNegativeConstraint}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof LinearObjectiveFunction) {
function = (LinearObjectiveFunction) data;
continue;
}
if (data instanceof LinearConstraintSet) {
linearConstraints = ((LinearConstraintSet) data).getConstraints();
continue;
}
if (data instanceof NonNegativeConstraint) {
nonNegative = ((NonNegativeConstraint) data).isRestrictedToNonNegative();
continue;
}
}
}
}
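
A minimal usage sketch (not part of this commit) of the single varargs "optimize" entry point, using the SimplexSolver subclass that appears further down in this change; the LinearConstraint/LinearConstraintSet constructors and the MaxIter/GoalType data classes are assumed to have the signatures used below, and the numbers are illustrative only:

// Maximize 15 x1 + 10 x2 subject to x1 <= 2, x2 <= 3, x1 + x2 <= 4, x >= 0.
LinearObjectiveFunction f =
    new LinearObjectiveFunction(new double[] { 15, 10 }, 0);
LinearConstraintSet constraints = new LinearConstraintSet(
    new LinearConstraint(new double[] { 1, 0 }, Relationship.LEQ, 2),
    new LinearConstraint(new double[] { 0, 1 }, Relationship.LEQ, 3),
    new LinearConstraint(new double[] { 1, 1 }, Relationship.LEQ, 4));
LinearOptimizer optimizer = new SimplexSolver();
PointValuePair solution = optimizer.optimize(new MaxIter(100),
                                             f,
                                             constraints,
                                             GoalType.MAXIMIZE,
                                             new NonNegativeConstraint(true));
// Expected optimum: solution.getPoint() close to { 2, 2 }, solution.getValue() close to 50.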

View File

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
/**
* This class represents exceptions thrown by optimizers when no solution fulfills the constraints.
*
* @version $Id: NoFeasibleSolutionException.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class NoFeasibleSolutionException extends MathIllegalStateException {
/** Serializable version identifier. */
private static final long serialVersionUID = -3044253632189082760L;
/**
* Simple constructor using a default message.
*/
public NoFeasibleSolutionException() {
super(LocalizedFormats.NO_FEASIBLE_SOLUTION);
}
}

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import org.apache.commons.math3.optim.OptimizationData;
/**
* A constraint for a linear optimization problem indicating whether all
* variables must be restricted to non-negative values.
*
* @version $Id$
* @since 3.1
*/
public class NonNegativeConstraint implements OptimizationData {
/** Whether the variables are all positive. */
private final boolean isRestricted;
/**
* @param restricted If {@code true}, all the variables must be positive.
*/
public NonNegativeConstraint(boolean restricted) {
isRestricted = restricted;
}
/**
* Indicates whether all the variables must be restricted to non-negative
* values.
*
* @return {@code true} if all the variables must be positive.
*/
public boolean isRestrictedToNonNegative() {
return isRestricted;
}
}

View File

@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
/**
* Types of relationship between the two sides of a {@link LinearConstraint}.
*
* @version $Id: Relationship.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public enum Relationship {
/** Equality relationship. */
EQ("="),
/** Lesser than or equal relationship. */
LEQ("<="),
/** Greater than or equal relationship. */
GEQ(">=");
/** Display string for the relationship. */
private final String stringValue;
/**
* Simple constructor.
*
* @param stringValue Display string for the relationship.
*/
private Relationship(String stringValue) {
this.stringValue = stringValue;
}
@Override
public String toString() {
return stringValue;
}
/**
* Gets the relationship obtained when multiplying all coefficients by -1.
*
* @return the opposite relationship.
*/
public Relationship oppositeRelationship() {
switch (this) {
case LEQ :
return GEQ;
case GEQ :
return LEQ;
default :
return EQ;
}
}
}
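
A small sketch (not part of this commit) of how oppositeRelationship is used when a constraint has a negative right-hand side, mirroring what SimplexTableau.normalize does further down in this change; the LinearConstraint constructor signature (coefficients, relationship, value) is assumed:

// x1 - x2 >= -3 is rewritten with a positive right-hand side:
LinearConstraint original =
    new LinearConstraint(new double[] { 1, -1 }, Relationship.GEQ, -3);
LinearConstraint normalized =
    new LinearConstraint(original.getCoefficients().mapMultiply(-1),
                         original.getRelationship().oppositeRelationship(),
                         -1 * original.getValue());
// normalized now represents -x1 + x2 <= 3.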

View File

@ -0,0 +1,245 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.exception.TooManyIterationsException;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.util.Precision;
/**
* Solves a linear problem using the "Two-Phase Simplex" method.
*
* @version $Id: SimplexSolver.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class SimplexSolver extends LinearOptimizer {
/** Default amount of error to accept for algorithm convergence. */
private static final double DEFAULT_EPSILON = 1.0e-6;
/** Default amount of error to accept in floating point comparisons (as ulps). */
private static final int DEFAULT_ULPS = 10;
/** Amount of error to accept for algorithm convergence. */
private final double epsilon;
/** Amount of error to accept in floating point comparisons (as ulps). */
private final int maxUlps;
/**
* Builds a simplex solver with default settings.
*/
public SimplexSolver() {
this(DEFAULT_EPSILON, DEFAULT_ULPS);
}
/**
* Builds a simplex solver with a specified accepted amount of error.
*
* @param epsilon Amount of error to accept for algorithm convergence.
* @param maxUlps Amount of error to accept in floating point comparisons.
*/
public SimplexSolver(final double epsilon,
final int maxUlps) {
this.epsilon = epsilon;
this.maxUlps = maxUlps;
}
/**
* Returns the column with the most negative coefficient in the objective function row.
*
* @param tableau Simple tableau for the problem.
* @return the column with the most negative coefficient.
*/
private Integer getPivotColumn(SimplexTableau tableau) {
double minValue = 0;
Integer minPos = null;
for (int i = tableau.getNumObjectiveFunctions(); i < tableau.getWidth() - 1; i++) {
final double entry = tableau.getEntry(0, i);
// check if the entry is strictly smaller than the current minimum
// do not use a ulp/epsilon check
if (entry < minValue) {
minValue = entry;
minPos = i;
}
}
return minPos;
}
/**
* Returns the row with the minimum ratio as given by the minimum ratio test (MRT).
*
* @param tableau Simple tableau for the problem.
* @param col Column to test the ratio of (see {@link #getPivotColumn(SimplexTableau)}).
* @return the row with the minimum ratio.
*/
private Integer getPivotRow(SimplexTableau tableau, final int col) {
// create a list of all the rows that tie for the lowest score in the minimum ratio test
List<Integer> minRatioPositions = new ArrayList<Integer>();
double minRatio = Double.MAX_VALUE;
for (int i = tableau.getNumObjectiveFunctions(); i < tableau.getHeight(); i++) {
final double rhs = tableau.getEntry(i, tableau.getWidth() - 1);
final double entry = tableau.getEntry(i, col);
if (Precision.compareTo(entry, 0d, maxUlps) > 0) {
final double ratio = rhs / entry;
// check if the entry is strictly equal to the current min ratio
// do not use a ulp/epsilon check
final int cmp = Double.compare(ratio, minRatio);
if (cmp == 0) {
minRatioPositions.add(i);
} else if (cmp < 0) {
minRatio = ratio;
minRatioPositions = new ArrayList<Integer>();
minRatioPositions.add(i);
}
}
}
if (minRatioPositions.size() == 0) {
return null;
} else if (minRatioPositions.size() > 1) {
// there's a degeneracy as indicated by a tie in the minimum ratio test
// 1. check if there's an artificial variable that can be forced out of the basis
if (tableau.getNumArtificialVariables() > 0) {
for (Integer row : minRatioPositions) {
for (int i = 0; i < tableau.getNumArtificialVariables(); i++) {
int column = i + tableau.getArtificialVariableOffset();
final double entry = tableau.getEntry(row, column);
if (Precision.equals(entry, 1d, maxUlps) && row.equals(tableau.getBasicRow(column))) {
return row;
}
}
}
}
// 2. apply Bland's rule to prevent cycling:
// take the row for which the corresponding basic variable has the smallest index
//
// see http://www.stanford.edu/class/msande310/blandrule.pdf
// see http://en.wikipedia.org/wiki/Bland%27s_rule (not equivalent to the above paper)
//
// Additional heuristic: if we did not get a solution after half of the allowed
// number of evaluations, revert to the simple case of just returning the top-most row.
// This heuristic is based on empirical data gathered while investigating MATH-828.
if (getEvaluations() < getMaxEvaluations() / 2) {
Integer minRow = null;
int minIndex = tableau.getWidth();
final int varStart = tableau.getNumObjectiveFunctions();
final int varEnd = tableau.getWidth() - 1;
for (Integer row : minRatioPositions) {
for (int i = varStart; i < varEnd && !row.equals(minRow); i++) {
final Integer basicRow = tableau.getBasicRow(i);
if (basicRow != null && basicRow.equals(row)) {
if (i < minIndex) {
minIndex = i;
minRow = row;
}
}
}
}
return minRow;
}
}
return minRatioPositions.get(0);
}
/**
* Runs one iteration of the Simplex method on the given model.
*
* @param tableau Simple tableau for the problem.
* @throws TooManyIterationsException if the allowed number of iterations has been exhausted.
* @throws UnboundedSolutionException if the model is found not to have a bounded solution.
*/
protected void doIteration(final SimplexTableau tableau)
throws TooManyIterationsException,
UnboundedSolutionException {
incrementIterationCount();
Integer pivotCol = getPivotColumn(tableau);
Integer pivotRow = getPivotRow(tableau, pivotCol);
if (pivotRow == null) {
throw new UnboundedSolutionException();
}
// set the pivot element to 1
double pivotVal = tableau.getEntry(pivotRow, pivotCol);
tableau.divideRow(pivotRow, pivotVal);
// set the rest of the pivot column to 0
for (int i = 0; i < tableau.getHeight(); i++) {
if (i != pivotRow) {
final double multiplier = tableau.getEntry(i, pivotCol);
tableau.subtractRow(i, pivotRow, multiplier);
}
}
}
/**
* Solves Phase 1 of the Simplex method.
*
* @param tableau Simple tableau for the problem.
* @throws TooManyIterationsException if the allowed number of iterations has been exhausted.
* @throws UnboundedSolutionException if the model is found not to have a bounded solution.
* @throws NoFeasibleSolutionException if there is no feasible solution.
*/
protected void solvePhase1(final SimplexTableau tableau)
throws TooManyIterationsException,
UnboundedSolutionException,
NoFeasibleSolutionException {
// make sure we're in Phase 1
if (tableau.getNumArtificialVariables() == 0) {
return;
}
while (!tableau.isOptimal()) {
doIteration(tableau);
}
// if W is not zero then we have no feasible solution
if (!Precision.equals(tableau.getEntry(0, tableau.getRhsOffset()), 0d, epsilon)) {
throw new NoFeasibleSolutionException();
}
}
/** {@inheritDoc} */
@Override
public PointValuePair doOptimize()
throws TooManyIterationsException,
UnboundedSolutionException,
NoFeasibleSolutionException {
final SimplexTableau tableau =
new SimplexTableau(getFunction(),
getConstraints(),
getGoalType(),
isRestrictedToNonNegative(),
epsilon,
maxUlps);
solvePhase1(tableau);
tableau.dropPhase1Objective();
while (!tableau.isOptimal()) {
doIteration(tableau);
}
return tableau.getSolution();
}
}
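
A minimal sketch (not part of this commit) of a minimization with a GEQ constraint; the constraint introduces an artificial variable, so this exercises the Phase 1 path in solvePhase1. Data-class signatures are assumed as in the sketch after LinearOptimizer above:

// Minimize 2 x1 + 3 x2 subject to x1 + x2 >= 4, x >= 0.
LinearObjectiveFunction cost =
    new LinearObjectiveFunction(new double[] { 2, 3 }, 0);
LinearConstraintSet constraints = new LinearConstraintSet(
    new LinearConstraint(new double[] { 1, 1 }, Relationship.GEQ, 4));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(new MaxIter(100),
                                          cost,
                                          constraints,
                                          GoalType.MINIMIZE,
                                          new NonNegativeConstraint(true));
// Expected optimum: solution.getPoint() close to { 4, 0 }, solution.getValue() close to 8.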

View File

@ -0,0 +1,637 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
/**
* A tableau for use in the Simplex method.
*
* <p>
* Example:
* <pre>
* W | Z | x1 | x2 | x- | s1 | s2 | a1 | RHS
* ---------------------------------------------------
* -1 0 0 0 0 0 0 1 0 &lt;= phase 1 objective
* 0 1 -15 -10 0 0 0 0 0 &lt;= phase 2 objective
* 0 0 1 0 0 1 0 0 2 &lt;= constraint 1
* 0 0 0 1 0 0 1 0 3 &lt;= constraint 2
* 0 0 1 1 0 0 0 1 4 &lt;= constraint 3
* </pre>
* W: Phase 1 objective function<br/>
* Z: Phase 2 objective function<br/>
* x1 &amp; x2: Decision variables<br/>
* x-: Extra decision variable to allow for negative values<br/>
* s1 &amp; s2: Slack/Surplus variables<br/>
* a1: Artificial variable<br/>
* RHS: Right hand side<br/>
* </p>
* @version $Id: SimplexTableau.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
class SimplexTableau implements Serializable {
/** Column label for negative vars. */
private static final String NEGATIVE_VAR_COLUMN_LABEL = "x-";
/** Default amount of error to accept in floating point comparisons (as ulps). */
private static final int DEFAULT_ULPS = 10;
/** The cut-off threshold to zero-out entries. */
private static final double CUTOFF_THRESHOLD = 1e-12;
/** Serializable version identifier. */
private static final long serialVersionUID = -1369660067587938365L;
/** Linear objective function. */
private final LinearObjectiveFunction f;
/** Linear constraints. */
private final List<LinearConstraint> constraints;
/** Whether to restrict the variables to non-negative values. */
private final boolean restrictToNonNegative;
/** The variables each column represents */
private final List<String> columnLabels = new ArrayList<String>();
/** Simple tableau. */
private transient RealMatrix tableau;
/** Number of decision variables. */
private final int numDecisionVariables;
/** Number of slack variables. */
private final int numSlackVariables;
/** Number of artificial variables. */
private int numArtificialVariables;
/** Amount of error to accept when checking for optimality. */
private final double epsilon;
/** Amount of error to accept in floating point comparisons. */
private final int maxUlps;
/**
* Builds a tableau for a linear problem.
*
* @param f Linear objective function.
* @param constraints Linear constraints.
* @param goalType Optimization goal: either {@link GoalType#MAXIMIZE}
* or {@link GoalType#MINIMIZE}.
* @param restrictToNonNegative Whether to restrict the variables to non-negative values.
* @param epsilon Amount of error to accept when checking for optimality.
*/
SimplexTableau(final LinearObjectiveFunction f,
final Collection<LinearConstraint> constraints,
final GoalType goalType,
final boolean restrictToNonNegative,
final double epsilon) {
this(f, constraints, goalType, restrictToNonNegative, epsilon, DEFAULT_ULPS);
}
/**
* Build a tableau for a linear problem.
* @param f linear objective function
* @param constraints linear constraints
* @param goalType type of optimization goal: either {@link GoalType#MAXIMIZE} or {@link GoalType#MINIMIZE}
* @param restrictToNonNegative whether to restrict the variables to non-negative values
* @param epsilon amount of error to accept when checking for optimality
* @param maxUlps amount of error to accept in floating point comparisons
*/
SimplexTableau(final LinearObjectiveFunction f,
final Collection<LinearConstraint> constraints,
final GoalType goalType,
final boolean restrictToNonNegative,
final double epsilon,
final int maxUlps) {
this.f = f;
this.constraints = normalizeConstraints(constraints);
this.restrictToNonNegative = restrictToNonNegative;
this.epsilon = epsilon;
this.maxUlps = maxUlps;
this.numDecisionVariables = f.getCoefficients().getDimension() +
(restrictToNonNegative ? 0 : 1);
this.numSlackVariables = getConstraintTypeCounts(Relationship.LEQ) +
getConstraintTypeCounts(Relationship.GEQ);
this.numArtificialVariables = getConstraintTypeCounts(Relationship.EQ) +
getConstraintTypeCounts(Relationship.GEQ);
this.tableau = createTableau(goalType == GoalType.MAXIMIZE);
initializeColumnLabels();
}
/**
* Initialize the labels for the columns.
*/
protected void initializeColumnLabels() {
if (getNumObjectiveFunctions() == 2) {
columnLabels.add("W");
}
columnLabels.add("Z");
for (int i = 0; i < getOriginalNumDecisionVariables(); i++) {
columnLabels.add("x" + i);
}
if (!restrictToNonNegative) {
columnLabels.add(NEGATIVE_VAR_COLUMN_LABEL);
}
for (int i = 0; i < getNumSlackVariables(); i++) {
columnLabels.add("s" + i);
}
for (int i = 0; i < getNumArtificialVariables(); i++) {
columnLabels.add("a" + i);
}
columnLabels.add("RHS");
}
/**
* Create the tableau by itself.
* @param maximize if true, goal is to maximize the objective function
* @return created tableau
*/
protected RealMatrix createTableau(final boolean maximize) {
// create a matrix of the correct size
int width = numDecisionVariables + numSlackVariables +
numArtificialVariables + getNumObjectiveFunctions() + 1; // + 1 is for RHS
int height = constraints.size() + getNumObjectiveFunctions();
Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(height, width);
// initialize the objective function rows
if (getNumObjectiveFunctions() == 2) {
matrix.setEntry(0, 0, -1);
}
int zIndex = (getNumObjectiveFunctions() == 1) ? 0 : 1;
matrix.setEntry(zIndex, zIndex, maximize ? 1 : -1);
RealVector objectiveCoefficients =
maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
copyArray(objectiveCoefficients.toArray(), matrix.getDataRef()[zIndex]);
matrix.setEntry(zIndex, width - 1,
maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
if (!restrictToNonNegative) {
matrix.setEntry(zIndex, getSlackVariableOffset() - 1,
getInvertedCoefficientSum(objectiveCoefficients));
}
// initialize the constraint rows
int slackVar = 0;
int artificialVar = 0;
for (int i = 0; i < constraints.size(); i++) {
LinearConstraint constraint = constraints.get(i);
int row = getNumObjectiveFunctions() + i;
// decision variable coefficients
copyArray(constraint.getCoefficients().toArray(), matrix.getDataRef()[row]);
// x-
if (!restrictToNonNegative) {
matrix.setEntry(row, getSlackVariableOffset() - 1,
getInvertedCoefficientSum(constraint.getCoefficients()));
}
// RHS
matrix.setEntry(row, width - 1, constraint.getValue());
// slack variables
if (constraint.getRelationship() == Relationship.LEQ) {
matrix.setEntry(row, getSlackVariableOffset() + slackVar++, 1); // slack
} else if (constraint.getRelationship() == Relationship.GEQ) {
matrix.setEntry(row, getSlackVariableOffset() + slackVar++, -1); // excess
}
// artificial variables
if ((constraint.getRelationship() == Relationship.EQ) ||
(constraint.getRelationship() == Relationship.GEQ)) {
matrix.setEntry(0, getArtificialVariableOffset() + artificialVar, 1);
matrix.setEntry(row, getArtificialVariableOffset() + artificialVar++, 1);
matrix.setRowVector(0, matrix.getRowVector(0).subtract(matrix.getRowVector(row)));
}
}
return matrix;
}
/**
* Get new versions of the constraints which have positive right hand sides.
* @param originalConstraints original (not normalized) constraints
* @return new versions of the constraints
*/
public List<LinearConstraint> normalizeConstraints(Collection<LinearConstraint> originalConstraints) {
List<LinearConstraint> normalized = new ArrayList<LinearConstraint>();
for (LinearConstraint constraint : originalConstraints) {
normalized.add(normalize(constraint));
}
return normalized;
}
/**
* Get a new equation equivalent to this one with a positive right hand side.
* @param constraint reference constraint
* @return new equation
*/
private LinearConstraint normalize(final LinearConstraint constraint) {
if (constraint.getValue() < 0) {
return new LinearConstraint(constraint.getCoefficients().mapMultiply(-1),
constraint.getRelationship().oppositeRelationship(),
-1 * constraint.getValue());
}
return new LinearConstraint(constraint.getCoefficients(),
constraint.getRelationship(), constraint.getValue());
}
/**
* Get the number of objective functions in this tableau.
* @return 2 for Phase 1. 1 for Phase 2.
*/
protected final int getNumObjectiveFunctions() {
return this.numArtificialVariables > 0 ? 2 : 1;
}
/**
* Get a count of constraints corresponding to a specified relationship.
* @param relationship relationship to count
* @return number of constraint with the specified relationship
*/
private int getConstraintTypeCounts(final Relationship relationship) {
int count = 0;
for (final LinearConstraint constraint : constraints) {
if (constraint.getRelationship() == relationship) {
++count;
}
}
return count;
}
/**
* Get -1 times the sum of all coefficients in the given array.
* @param coefficients coefficients to sum
* @return -1 times the sum of all coefficients in the given array.
*/
protected static double getInvertedCoefficientSum(final RealVector coefficients) {
double sum = 0;
for (double coefficient : coefficients.toArray()) {
sum -= coefficient;
}
return sum;
}
/**
* Checks whether the given column is basic.
* @param col index of the column to check
* @return the row that the variable is basic in. null if the column is not basic
*/
protected Integer getBasicRow(final int col) {
Integer row = null;
for (int i = 0; i < getHeight(); i++) {
final double entry = getEntry(i, col);
if (Precision.equals(entry, 1d, maxUlps) && (row == null)) {
row = i;
} else if (!Precision.equals(entry, 0d, maxUlps)) {
return null;
}
}
return row;
}
/**
* Removes the phase 1 objective function, positive cost non-artificial variables,
* and the non-basic artificial variables from this tableau.
*/
protected void dropPhase1Objective() {
if (getNumObjectiveFunctions() == 1) {
return;
}
Set<Integer> columnsToDrop = new TreeSet<Integer>();
columnsToDrop.add(0);
// positive cost non-artificial variables
for (int i = getNumObjectiveFunctions(); i < getArtificialVariableOffset(); i++) {
final double entry = tableau.getEntry(0, i);
if (Precision.compareTo(entry, 0d, epsilon) > 0) {
columnsToDrop.add(i);
}
}
// non-basic artificial variables
for (int i = 0; i < getNumArtificialVariables(); i++) {
int col = i + getArtificialVariableOffset();
if (getBasicRow(col) == null) {
columnsToDrop.add(col);
}
}
double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
for (int i = 1; i < getHeight(); i++) {
int col = 0;
for (int j = 0; j < getWidth(); j++) {
if (!columnsToDrop.contains(j)) {
matrix[i - 1][col++] = tableau.getEntry(i, j);
}
}
}
// remove the columns in reverse order so the indices are correct
Integer[] drop = columnsToDrop.toArray(new Integer[columnsToDrop.size()]);
for (int i = drop.length - 1; i >= 0; i--) {
columnLabels.remove((int) drop[i]);
}
this.tableau = new Array2DRowRealMatrix(matrix);
this.numArtificialVariables = 0;
}
/**
* Copies the source array into the destination row, starting after the
* objective function columns.
* @param src the source array
* @param dest the destination array
*/
private void copyArray(final double[] src, final double[] dest) {
System.arraycopy(src, 0, dest, getNumObjectiveFunctions(), src.length);
}
/**
* Returns whether the problem is at an optimal state.
* @return whether the model has been solved
*/
boolean isOptimal() {
for (int i = getNumObjectiveFunctions(); i < getWidth() - 1; i++) {
final double entry = tableau.getEntry(0, i);
if (Precision.compareTo(entry, 0d, epsilon) < 0) {
return false;
}
}
return true;
}
/**
* Get the current solution.
* @return current solution
*/
protected PointValuePair getSolution() {
int negativeVarColumn = columnLabels.indexOf(NEGATIVE_VAR_COLUMN_LABEL);
Integer negativeVarBasicRow = negativeVarColumn > 0 ? getBasicRow(negativeVarColumn) : null;
double mostNegative = negativeVarBasicRow == null ? 0 : getEntry(negativeVarBasicRow, getRhsOffset());
Set<Integer> basicRows = new HashSet<Integer>();
double[] coefficients = new double[getOriginalNumDecisionVariables()];
for (int i = 0; i < coefficients.length; i++) {
int colIndex = columnLabels.indexOf("x" + i);
if (colIndex < 0) {
coefficients[i] = 0;
continue;
}
Integer basicRow = getBasicRow(colIndex);
if (basicRow != null && basicRow == 0) {
// if the basic row is found to be the objective function row
// set the coefficient to 0 -> this case handles unconstrained
// variables that are still part of the objective function
coefficients[i] = 0;
} else if (basicRows.contains(basicRow)) {
// if multiple variables can take a given value
// then we choose the first and set the rest equal to 0
coefficients[i] = 0 - (restrictToNonNegative ? 0 : mostNegative);
} else {
basicRows.add(basicRow);
coefficients[i] =
(basicRow == null ? 0 : getEntry(basicRow, getRhsOffset())) -
(restrictToNonNegative ? 0 : mostNegative);
}
}
return new PointValuePair(coefficients, f.value(coefficients));
}
/**
* Divides one row by a given divisor.
* <p>
* After application of this operation, the following will hold:
* <pre>dividendRow = dividendRow / divisor</pre>
*
* @param dividendRow index of the row to divide
* @param divisor value of the divisor
*/
protected void divideRow(final int dividendRow, final double divisor) {
for (int j = 0; j < getWidth(); j++) {
tableau.setEntry(dividendRow, j, tableau.getEntry(dividendRow, j) / divisor);
}
}
/**
* Subtracts a multiple of one row from another.
* <p>
* After application of this operation, the following will hold:
* <pre>minuendRow = minuendRow - multiple * subtrahendRow</pre>
*
* @param minuendRow row index
* @param subtrahendRow row index
* @param multiple multiplication factor
*/
protected void subtractRow(final int minuendRow, final int subtrahendRow,
final double multiple) {
for (int i = 0; i < getWidth(); i++) {
double result = tableau.getEntry(minuendRow, i) - tableau.getEntry(subtrahendRow, i) * multiple;
// cut-off values smaller than the CUTOFF_THRESHOLD, otherwise may lead to numerical instabilities
if (FastMath.abs(result) < CUTOFF_THRESHOLD) {
result = 0.0;
}
tableau.setEntry(minuendRow, i, result);
}
}
/**
* Get the width of the tableau.
* @return width of the tableau
*/
protected final int getWidth() {
return tableau.getColumnDimension();
}
/**
* Get the height of the tableau.
* @return height of the tableau
*/
protected final int getHeight() {
return tableau.getRowDimension();
}
/**
* Get an entry of the tableau.
* @param row row index
* @param column column index
* @return entry at (row, column)
*/
protected final double getEntry(final int row, final int column) {
return tableau.getEntry(row, column);
}
/**
* Set an entry of the tableau.
* @param row row index
* @param column column index
* @param value for the entry
*/
protected final void setEntry(final int row, final int column,
final double value) {
tableau.setEntry(row, column, value);
}
/**
* Get the offset of the first slack variable.
* @return offset of the first slack variable
*/
protected final int getSlackVariableOffset() {
return getNumObjectiveFunctions() + numDecisionVariables;
}
/**
* Get the offset of the first artificial variable.
* @return offset of the first artificial variable
*/
protected final int getArtificialVariableOffset() {
return getNumObjectiveFunctions() + numDecisionVariables + numSlackVariables;
}
/**
* Get the offset of the right hand side.
* @return offset of the right hand side
*/
protected final int getRhsOffset() {
return getWidth() - 1;
}
/**
* Get the number of decision variables.
* <p>
* If variables are not restricted to positive values, this will include 1 extra decision variable to represent
* the absolute value of the most negative variable.
*
* @return number of decision variables
* @see #getOriginalNumDecisionVariables()
*/
protected final int getNumDecisionVariables() {
return numDecisionVariables;
}
/**
* Get the original number of decision variables.
* @return original number of decision variables
* @see #getNumDecisionVariables()
*/
protected final int getOriginalNumDecisionVariables() {
return f.getCoefficients().getDimension();
}
/**
* Get the number of slack variables.
* @return number of slack variables
*/
protected final int getNumSlackVariables() {
return numSlackVariables;
}
/**
* Get the number of artificial variables.
* @return number of artificial variables
*/
protected final int getNumArtificialVariables() {
return numArtificialVariables;
}
/**
* Get the tableau data.
* @return tableau data
*/
protected final double[][] getData() {
return tableau.getData();
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other instanceof SimplexTableau) {
SimplexTableau rhs = (SimplexTableau) other;
return (restrictToNonNegative == rhs.restrictToNonNegative) &&
(numDecisionVariables == rhs.numDecisionVariables) &&
(numSlackVariables == rhs.numSlackVariables) &&
(numArtificialVariables == rhs.numArtificialVariables) &&
(epsilon == rhs.epsilon) &&
(maxUlps == rhs.maxUlps) &&
f.equals(rhs.f) &&
constraints.equals(rhs.constraints) &&
tableau.equals(rhs.tableau);
}
return false;
}
@Override
public int hashCode() {
return Boolean.valueOf(restrictToNonNegative).hashCode() ^
numDecisionVariables ^
numSlackVariables ^
numArtificialVariables ^
Double.valueOf(epsilon).hashCode() ^
maxUlps ^
f.hashCode() ^
constraints.hashCode() ^
tableau.hashCode();
}
/**
* Serialize the instance.
* @param oos stream where object should be written
* @throws IOException if object cannot be written to stream
*/
private void writeObject(ObjectOutputStream oos)
throws IOException {
oos.defaultWriteObject();
MatrixUtils.serializeRealMatrix(tableau, oos);
}
/**
* Deserialize the instance.
* @param ois stream from which the object should be read
* @throws ClassNotFoundException if a class in the stream cannot be found
* @throws IOException if object cannot be read from the stream
*/
private void readObject(ObjectInputStream ois)
throws ClassNotFoundException, IOException {
ois.defaultReadObject();
MatrixUtils.deserializeRealMatrix(this, "tableau", ois);
}
}
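
A plain-array sketch (not part of this commit) of the row operations that SimplexSolver.doIteration delegates to divideRow and subtractRow: the pivot row is scaled so the pivot element becomes 1, then the pivot column is eliminated from every other row. The numbers are arbitrary and only illustrate the arithmetic.

double[][] t = {
    { 1, -15, -10, 0, 0, 0 },   // objective row
    { 0,   1,   0, 1, 0, 2 },   // constraint row 1
    { 0,   1,   1, 0, 1, 4 }    // constraint row 2
};
int pivotRow = 1;
int pivotCol = 1;
// divideRow: scale the pivot row so the pivot element becomes 1.
double pivotVal = t[pivotRow][pivotCol];
for (int j = 0; j < t[pivotRow].length; j++) {
    t[pivotRow][j] /= pivotVal;
}
// subtractRow: zero out the pivot column in every other row.
for (int i = 0; i < t.length; i++) {
    if (i != pivotRow) {
        double multiplier = t[i][pivotCol];
        for (int j = 0; j < t[i].length; j++) {
            t[i][j] -= multiplier * t[pivotRow][j];
        }
    }
}
// Column 1 is now { 0, 1, 0 }; the objective row has absorbed 15 times constraint row 1.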

View File

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
/**
* This class represents exceptions thrown by optimizers when a solution escapes to infinity.
*
* @version $Id: UnboundedSolutionException.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class UnboundedSolutionException extends MathIllegalStateException {
/** Serializable version identifier. */
private static final long serialVersionUID = 940539497277290619L;
/**
* Simple constructor using a default message.
*/
public UnboundedSolutionException() {
super(LocalizedFormats.UNBOUNDED_SOLUTION);
}
}

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
/**
* Optimization algorithms for linearly constrained problems.
*/

View File

@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
/**
* Base class for implementing optimizers for multivariate scalar
* differentiable functions.
* It contains boiler-plate code for dealing with gradient evaluation.
*
* @version $Id$
* @since 3.1
*/
public abstract class GradientMultivariateOptimizer
extends MultivariateOptimizer {
/**
* Gradient of the objective function.
*/
private MultivariateVectorFunction gradient;
/**
* @param checker Convergence checker.
*/
protected GradientMultivariateOptimizer(ConvergenceChecker<PointValuePair> checker) {
super(checker);
}
/**
* Compute the gradient vector.
*
* @param params Point at which the gradient must be evaluated.
* @return the gradient at the specified point.
*/
protected double[] computeObjectiveGradient(final double[] params) {
return gradient.value(params);
}
/**
* {@inheritDoc}
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
* <li>{@link org.apache.commons.math3.optim.GoalType}</li>
* <li>{@link org.apache.commons.math3.optim.ObjectiveFunction}</li>
* <li>{@link ObjectiveFunctionGradient}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of
* evaluations (of the objective function) is exceeded.
*/
@Override
public PointValuePair optimize(OptimizationData... optData)
throws TooManyEvaluationsException {
// Retrieve settings.
parseOptimizationData(optData);
// Set up base class and perform computation.
return super.optimize(optData);
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link ObjectiveFunctionGradient}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof ObjectiveFunctionGradient) {
gradient = ((ObjectiveFunctionGradient) data).getObjectiveFunctionGradient();
// If more data must be parsed, this statement _must_ be
// changed to "continue".
break;
}
}
}
}
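
A minimal sketch (not part of this commit) of supplying an ObjectiveFunctionGradient to a gradient-based optimizer; NonLinearConjugateGradientOptimizer, its Formula enum and SimpleValueChecker are assumed to be the companion classes introduced elsewhere in this refactoring, with the constructors shown:

final MultivariateFunction rosenbrock = new MultivariateFunction() {
    public double value(double[] p) {
        final double a = p[0] - 1;
        final double b = p[1] - p[0] * p[0];
        return a * a + 100 * b * b;
    }
};
final MultivariateVectorFunction rosenbrockGradient = new MultivariateVectorFunction() {
    public double[] value(double[] p) {
        final double a = p[0] - 1;
        final double b = p[1] - p[0] * p[0];
        return new double[] { 2 * a - 400 * p[0] * b, 200 * b };
    }
};
GradientMultivariateOptimizer optimizer =
    new NonLinearConjugateGradientOptimizer(
        NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
        new SimpleValueChecker(1e-10, 1e-10));
PointValuePair optimum =
    optimizer.optimize(new MaxEval(10000),
                       new ObjectiveFunction(rosenbrock),
                       new ObjectiveFunctionGradient(rosenbrockGradient),
                       GoalType.MINIMIZE,
                       new InitialGuess(new double[] { -1, 1 }));
// optimum.getPoint() should be close to { 1, 1 }.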

View File

@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.RealMatrix;
/**
* This class converts
* {@link MultivariateVectorFunction vectorial objective functions} to
* {@link MultivariateFunction scalar objective functions}
* when the goal is to minimize them.
* <br/>
* This class is mostly used when the vectorial objective function represents
* a theoretical result computed by applying a model to a set of points, and
* the model parameters must be adjusted to fit the theoretical result to some
* reference observations. The observations may be obtained, for example, from
* physical measurements, whereas the model is built from theoretical
* considerations.
* <br/>
* This class computes a possibly weighted squared sum of the residuals, which is
* a scalar value. The residuals are the differences between the theoretical model
* (i.e. the output of the vectorial objective function) and the observations. The
* class implements the {@link MultivariateFunction} interface and can therefore be
* minimized by any optimizer supporting scalar objective functions. This is one way
* to perform a least-squares estimation. There are other ways to do this without using
* this converter, as some optimization algorithms directly support vectorial objective
* functions.
* <br/>
* This class supports combining residuals with or without weights and correlations.
*
* @see MultivariateFunction
* @see MultivariateVectorFunction
* @version $Id: LeastSquaresConverter.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class LeastSquaresConverter implements MultivariateFunction {
/** Underlying vectorial function. */
private final MultivariateVectorFunction function;
/** Observations to be compared to objective function to compute residuals. */
private final double[] observations;
/** Optional weights for the residuals. */
private final double[] weights;
/** Optional scaling matrix (weight and correlations) for the residuals. */
private final RealMatrix scale;
/**
* Builds a simple converter for uncorrelated residuals with identical
* weights.
*
* @param function vectorial residuals function to wrap
* @param observations observations to be compared to objective function to compute residuals
*/
public LeastSquaresConverter(final MultivariateVectorFunction function,
final double[] observations) {
this.function = function;
this.observations = observations.clone();
this.weights = null;
this.scale = null;
}
/**
* Builds a simple converter for uncorrelated residuals with the
* specified weights.
* <p>
* The scalar objective function value is computed as:
* <pre>
* objective = &sum;weight<sub>i</sub>(observation<sub>i</sub>-objective<sub>i</sub>)<sup>2</sup>
* </pre>
* </p>
* <p>
* Weights can be used for example to combine residuals with different standard
* deviations. As an example, consider a residuals array in which even elements
* are angular measurements in degrees with a 0.01&deg; standard deviation and
* odd elements are distance measurements in meters with a 15m standard deviation.
* In this case, the weights array should be initialized with value
* 1.0/(0.01<sup>2</sup>) in the even elements and 1.0/(15.0<sup>2</sup>) in the
* odd elements (i.e. reciprocals of variances).
* </p>
* <p>
* The array computed by the objective function, the observations array and the
* weights array must have consistent sizes or a {@link DimensionMismatchException}
* will be triggered while computing the scalar objective.
* </p>
*
* @param function vectorial residuals function to wrap
* @param observations observations to be compared to objective function to compute residuals
* @param weights weights to apply to the residuals
* @throws DimensionMismatchException if the observations vector and the weights
* vector dimensions do not match (objective function dimension is checked only when
* the {@link #value(double[])} method is called)
*/
public LeastSquaresConverter(final MultivariateVectorFunction function,
final double[] observations,
final double[] weights) {
if (observations.length != weights.length) {
throw new DimensionMismatchException(observations.length, weights.length);
}
this.function = function;
this.observations = observations.clone();
this.weights = weights.clone();
this.scale = null;
}
/**
* Builds a simple converter for correlated residuals with the
* specified weights.
* <p>
* The scalar objective function value is computed as:
* <pre>
* objective = y<sup>T</sup>y with y = scale&times;(observation-objective)
* </pre>
* </p>
* <p>
* The array computed by the objective function, the observations array and the
* scaling matrix must have consistent sizes or a {@link DimensionMismatchException}
* will be triggered while computing the scalar objective.
* </p>
*
* @param function vectorial residuals function to wrap
* @param observations observations to be compared to objective function to compute residuals
* @param scale scaling matrix
* @throws DimensionMismatchException if the observations vector and the scale
* matrix dimensions do not match (objective function dimension is checked only when
* the {@link #value(double[])} method is called)
*/
public LeastSquaresConverter(final MultivariateVectorFunction function,
final double[] observations,
final RealMatrix scale) {
if (observations.length != scale.getColumnDimension()) {
throw new DimensionMismatchException(observations.length, scale.getColumnDimension());
}
this.function = function;
this.observations = observations.clone();
this.weights = null;
this.scale = scale.copy();
}
/** {@inheritDoc} */
public double value(final double[] point) {
// compute residuals
final double[] residuals = function.value(point);
if (residuals.length != observations.length) {
throw new DimensionMismatchException(residuals.length, observations.length);
}
for (int i = 0; i < residuals.length; ++i) {
residuals[i] -= observations[i];
}
// compute sum of squares
double sumSquares = 0;
if (weights != null) {
for (int i = 0; i < residuals.length; ++i) {
final double ri = residuals[i];
sumSquares += weights[i] * ri * ri;
}
} else if (scale != null) {
for (final double yi : scale.operate(residuals)) {
sumSquares += yi * yi;
}
} else {
for (final double ri : residuals) {
sumSquares += ri * ri;
}
}
return sumSquares;
}
}
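
A minimal sketch (not part of this commit), using a toy two-parameter model; the weights are the reciprocals of the observation variances, as described above:

MultivariateVectorFunction model = new MultivariateVectorFunction() {
    public double[] value(double[] p) {
        return new double[] { p[0] + p[1], p[0] - p[1] };
    }
};
double[] observations = { 2.0, 0.5 };
double[] weights = { 1.0, 4.0 };
LeastSquaresConverter sumOfSquares =
    new LeastSquaresConverter(model, observations, weights);
// Residuals at (1, 1) are { 0, -0.5 }, so the weighted cost is 1.0 * 0 + 4.0 * 0.25 = 1.0.
double cost = sumOfSquares.value(new double[] { 1.0, 1.0 });
// The converter can now be handed to any scalar optimizer as an objective function.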

View File

@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import java.util.Collections;
import java.util.List;
import java.util.ArrayList;
import java.util.Comparator;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.random.RandomVectorGenerator;
import org.apache.commons.math3.optim.BaseMultiStartMultivariateOptimizer;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.GoalType;
/**
* Multi-start optimizer.
*
* This class wraps an optimizer in order to use it several times in
* turn with different starting points (trying to avoid being trapped
* in a local extremum when looking for a global one).
*
* @version $Id$
* @since 3.0
*/
public class MultiStartMultivariateOptimizer
extends BaseMultiStartMultivariateOptimizer<PointValuePair> {
/** Underlying optimizer. */
private final MultivariateOptimizer optimizer;
/** Found optima. */
private final List<PointValuePair> optima = new ArrayList<PointValuePair>();
/**
* Create a multi-start optimizer from a single-start optimizer.
*
* @param optimizer Single-start optimizer to wrap.
* @param starts Number of starts to perform.
* If {@code starts == 1}, the result will be the same as if {@code optimizer}
* were called directly.
* @param generator Random vector generator to use for restarts.
* @throws NullArgumentException if {@code optimizer} or {@code generator}
* is {@code null}.
* @throws NotStrictlyPositiveException if {@code starts < 1}.
*/
public MultiStartMultivariateOptimizer(final MultivariateOptimizer optimizer,
final int starts,
final RandomVectorGenerator generator)
throws NullArgumentException,
NotStrictlyPositiveException {
super(optimizer, starts, generator);
this.optimizer = optimizer;
}
/**
* {@inheritDoc}
*/
@Override
public PointValuePair[] getOptima() {
Collections.sort(optima, getPairComparator());
return optima.toArray(new PointValuePair[0]);
}
/**
* {@inheritDoc}
*/
@Override
protected void store(PointValuePair optimum) {
optima.add(optimum);
}
/**
* {@inheritDoc}
*/
@Override
protected void clear() {
optima.clear();
}
/**
* @return a comparator for sorting the optima.
*/
private Comparator<PointValuePair> getPairComparator() {
return new Comparator<PointValuePair>() {
public int compare(final PointValuePair o1,
final PointValuePair o2) {
if (o1 == null) {
return (o2 == null) ? 0 : 1;
} else if (o2 == null) {
return -1;
}
final double v1 = o1.getValue();
final double v2 = o2.getValue();
return (optimizer.getGoalType() == GoalType.MINIMIZE) ?
Double.compare(v1, v2) : Double.compare(v2, v1);
}
};
}
}
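
A minimal sketch (not part of this commit) of wrapping a single-start optimizer; SimplexOptimizer, NelderMeadSimplex, ObjectiveFunction, MaxEval and InitialGuess are assumed to be the companion classes introduced elsewhere in this refactoring, with the constructors shown:

MultivariateFunction f = new MultivariateFunction() {
    public double value(double[] p) {
        return (p[0] - 1) * (p[0] - 1) + (p[1] + 2) * (p[1] + 2);
    }
};
RandomVectorGenerator generator = new RandomVectorGenerator() {
    private final java.util.Random rng = new java.util.Random(42L);
    public double[] nextVector() {
        // Random restart points in the square [-5, 5] x [-5, 5].
        return new double[] { 10 * rng.nextDouble() - 5, 10 * rng.nextDouble() - 5 };
    }
};
MultiStartMultivariateOptimizer optimizer =
    new MultiStartMultivariateOptimizer(new SimplexOptimizer(1e-10, 1e-30), 5, generator);
PointValuePair best = optimizer.optimize(new MaxEval(20000),
                                         new ObjectiveFunction(f),
                                         GoalType.MINIMIZE,
                                         new InitialGuess(new double[] { 0, 0 }),
                                         new NelderMeadSimplex(2));
// best.getPoint() should be close to { 1, -2 }; getOptima() lists every run, best first.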

View File

@ -0,0 +1,295 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.function.Logit;
import org.apache.commons.math3.analysis.function.Sigmoid;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathUtils;
/**
* <p>Adapter for mapping bounded {@link MultivariateFunction} to unbounded ones.</p>
*
* <p>
* This adapter can be used to wrap functions subject to simple bounds on
* parameters so they can be used by optimizers that do <em>not</em> directly
* support simple bounds.
* </p>
* <p>
* The principle is that the wrapped user function will see its
* parameters bounded as required, i.e. when its {@code value} method is called
* with argument array {@code point}, the elements of the array will fulfill the
* requirement {@code lower[i] <= point[i] <= upper[i]} for all i. Some of the components
* may be unbounded, or bounded only on one side, if the corresponding bound is
* set to an infinite value. The optimizer will not manage the user function by
* itself, but it will handle this adapter, and it is this adapter that will take
* care that the bounds are fulfilled. The adapter's {@link #value(double[])} method will
* be called by the optimizer with unbounded parameters, and the adapter will map
* the unbounded values to the bounded range using appropriate functions such as
* {@link Sigmoid} for doubly bounded elements.
* </p>
* <p>
* As the optimizer sees only unbounded parameters, the
* start point or simplex expected by the optimizer should be unbounded, so the
* user is responsible for converting the bounded start point to an unbounded one by calling
* {@link #boundedToUnbounded(double[])} before providing it to the optimizer.
* For the same reason, the point returned by the {@link
* org.apache.commons.math3.optimization.BaseMultivariateOptimizer#optimize(int,
* MultivariateFunction, org.apache.commons.math3.optimization.GoalType, double[])}
* method is unbounded; to convert it back to bounded coordinates, users must call
* {@link #unboundedToBounded(double[])} themselves.</p>
* <p>
* This adapter is only a poor man's solution to simple bounds optimization constraints
* that can be used with simple optimizers like
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
* SimplexOptimizer}.
* A better solution is to use an optimizer that directly supports simple bounds like
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer
* CMAESOptimizer} or
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer
* BOBYQAOptimizer}.
* One caveat of this poor-man's solution is that behavior near the bounds may be
* numerically unstable as bounds are mapped from infinite values.
* Another caveat is that convergence values are evaluated by the optimizer with
* respect to unbounded variables, so there will be scale differences when
* converted to bounded variables.
* </p>
*
* @see MultivariateFunctionPenaltyAdapter
*
* @version $Id: MultivariateFunctionMappingAdapter.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.0
*/
public class MultivariateFunctionMappingAdapter
implements MultivariateFunction {
/** Underlying bounded function. */
private final MultivariateFunction bounded;
/** Mapping functions. */
private final Mapper[] mappers;
/** Simple constructor.
* @param bounded bounded function
* @param lower lower bounds for each element of the input parameters array
* (some elements may be set to {@code Double.NEGATIVE_INFINITY} for
* unbounded values)
* @param upper upper bounds for each element of the input parameters array
* (some elements may be set to {@code Double.POSITIVE_INFINITY} for
* unbounded values)
* @exception DimensionMismatchException if lower and upper bounds are not
* consistent, either according to dimension or to values
*/
public MultivariateFunctionMappingAdapter(final MultivariateFunction bounded,
final double[] lower, final double[] upper) {
// safety checks
MathUtils.checkNotNull(lower);
MathUtils.checkNotNull(upper);
if (lower.length != upper.length) {
throw new DimensionMismatchException(lower.length, upper.length);
}
for (int i = 0; i < lower.length; ++i) {
// note the following test is written in such a way it also fails for NaN
if (!(upper[i] >= lower[i])) {
throw new NumberIsTooSmallException(upper[i], lower[i], true);
}
}
this.bounded = bounded;
this.mappers = new Mapper[lower.length];
for (int i = 0; i < mappers.length; ++i) {
if (Double.isInfinite(lower[i])) {
if (Double.isInfinite(upper[i])) {
// element is unbounded, no transformation is needed
mappers[i] = new NoBoundsMapper();
} else {
// element is simple-bounded on the upper side
mappers[i] = new UpperBoundMapper(upper[i]);
}
} else {
if (Double.isInfinite(upper[i])) {
// element is simple-bounded on the lower side
mappers[i] = new LowerBoundMapper(lower[i]);
} else {
// element is double-bounded
mappers[i] = new LowerUpperBoundMapper(lower[i], upper[i]);
}
}
}
}
/**
* Maps an array from unbounded to bounded.
*
* @param point Unbounded values.
* @return the bounded values.
*/
public double[] unboundedToBounded(double[] point) {
// Map unbounded input point to bounded point.
final double[] mapped = new double[mappers.length];
for (int i = 0; i < mappers.length; ++i) {
mapped[i] = mappers[i].unboundedToBounded(point[i]);
}
return mapped;
}
/**
* Maps an array from bounded to unbounded.
*
* @param point Bounded values.
* @return the unbounded values.
*/
public double[] boundedToUnbounded(double[] point) {
// Map bounded input point to unbounded point.
final double[] mapped = new double[mappers.length];
for (int i = 0; i < mappers.length; ++i) {
mapped[i] = mappers[i].boundedToUnbounded(point[i]);
}
return mapped;
}
/**
* Compute the underlying function value from an unbounded point.
* <p>
* This method simply bounds the unbounded point using the mappings
* set up at construction and calls the underlying function using
* the bounded point.
* </p>
* @param point unbounded value
* @return underlying function value
* @see #unboundedToBounded(double[])
*/
public double value(double[] point) {
return bounded.value(unboundedToBounded(point));
}
/** Mapping interface. */
private interface Mapper {
/**
* Maps a value from unbounded to bounded.
*
* @param y Unbounded value.
* @return the bounded value.
*/
double unboundedToBounded(double y);
/**
* Maps a value from bounded to unbounded.
*
* @param x Bounded value.
* @return the unbounded value.
*/
double boundedToUnbounded(double x);
}
/** Local class for no bounds mapping. */
private static class NoBoundsMapper implements Mapper {
/** {@inheritDoc} */
public double unboundedToBounded(final double y) {
return y;
}
/** {@inheritDoc} */
public double boundedToUnbounded(final double x) {
return x;
}
}
/** Local class for lower bounds mapping. */
private static class LowerBoundMapper implements Mapper {
/** Low bound. */
private final double lower;
/**
* Simple constructor.
*
* @param lower lower bound
*/
public LowerBoundMapper(final double lower) {
this.lower = lower;
}
/** {@inheritDoc} */
public double unboundedToBounded(final double y) {
return lower + FastMath.exp(y);
}
/** {@inheritDoc} */
public double boundedToUnbounded(final double x) {
return FastMath.log(x - lower);
}
}
/** Local class for upper bounds mapping. */
private static class UpperBoundMapper implements Mapper {
/** Upper bound. */
private final double upper;
/** Simple constructor.
* @param upper upper bound
*/
public UpperBoundMapper(final double upper) {
this.upper = upper;
}
/** {@inheritDoc} */
public double unboundedToBounded(final double y) {
return upper - FastMath.exp(-y);
}
/** {@inheritDoc} */
public double boundedToUnbounded(final double x) {
return -FastMath.log(upper - x);
}
}
/** Local class for lower and upper bounds mapping. */
private static class LowerUpperBoundMapper implements Mapper {
/** Function from unbounded to bounded. */
private final UnivariateFunction boundingFunction;
/** Function from bounded to unbounded. */
private final UnivariateFunction unboundingFunction;
/**
* Simple constructor.
*
* @param lower lower bound
* @param upper upper bound
*/
public LowerUpperBoundMapper(final double lower, final double upper) {
boundingFunction = new Sigmoid(lower, upper);
unboundingFunction = new Logit(lower, upper);
}
/** {@inheritDoc} */
public double unboundedToBounded(final double y) {
return boundingFunction.value(y);
}
/** {@inheritDoc} */
public double boundedToUnbounded(final double x) {
return unboundingFunction.value(x);
}
}
}
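
A usage sketch (not part of this commit) of the intended call pattern for this adapter. It assumes the
SimplexOptimizer, NelderMeadSimplex, ObjectiveFunction, GoalType, MaxEval, InitialGuess and PointValuePair
classes introduced in other files of this change set (in particular, a SimplexOptimizer constructor taking
relative and absolute tolerances); the objective function, bounds, start point and tolerances are illustrative
values only. The optimizer works on unbounded variables, so the start point is converted with
boundedToUnbounded and the optimum is mapped back with unboundedToBounded.

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;

public class MappingAdapterSketch {
    public static void main(String[] args) {
        // Paraboloid whose minimum (1, 2) lies inside the box [0, 5] x [0, 5].
        MultivariateFunction objective = new MultivariateFunction() {
            public double value(double[] p) {
                final double dx = p[0] - 1;
                final double dy = p[1] - 2;
                return dx * dx + dy * dy;
            }
        };
        MultivariateFunctionMappingAdapter wrapped =
            new MultivariateFunctionMappingAdapter(objective,
                                                   new double[] { 0, 0 },
                                                   new double[] { 5, 5 });
        SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
        PointValuePair result =
            optimizer.optimize(new MaxEval(500),
                               new ObjectiveFunction(wrapped),
                               GoalType.MINIMIZE,
                               // The start point must be given in unbounded coordinates.
                               new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 3.0 })),
                               new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
        // The optimum returned by the optimizer is in unbounded coordinates too.
        final double[] best = wrapped.unboundedToBounded(result.getPoint());
        System.out.println(best[0] + " " + best[1]); // expected to be close to (1, 2)
    }
}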

View File

@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathUtils;
/**
 * <p>Adapter extending bounded {@link MultivariateFunction} to an unbounded
 * domain using a penalty function.</p>
*
* <p>
* This adapter can be used to wrap functions subject to simple bounds on
* parameters so they can be used by optimizers that do <em>not</em> directly
* support simple bounds.
* </p>
 * <p>
 * The principle is that the wrapped user function sees its parameters bounded
 * as required, i.e. when its {@code value} method is called with argument array
 * {@code point}, the elements of the array fulfill the requirement
 * {@code lower[i] <= point[i] <= upper[i]} for all i. Some of the components
 * may be unbounded, or bounded only on one side, if the corresponding bound is
 * set to an infinite value. The optimizer does not manage the user function by
 * itself; it handles this adapter, and it is the adapter that takes care that
 * the bounds are fulfilled. The adapter's {@link #value(double[])} method will
 * be called by the optimizer with unbounded parameters, and the adapter checks
 * whether the parameters are within range. If they are, the underlying user
 * function is called; if they are not, the value of a penalty function is
 * returned instead.
 * </p>
* <p>
* This adapter is only a poor-man's solution to simple bounds optimization
* constraints that can be used with simple optimizers like
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
* SimplexOptimizer}.
* A better solution is to use an optimizer that directly supports simple bounds like
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer
* CMAESOptimizer} or
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer
* BOBYQAOptimizer}.
 * One caveat of this poor-man's solution is that if the start point or start
 * simplex is completely outside of the allowed range, only the penalty function
 * is used, and the optimizer may converge without ever entering the range.
* </p>
*
* @see MultivariateFunctionMappingAdapter
*
* @version $Id: MultivariateFunctionPenaltyAdapter.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.0
*/
public class MultivariateFunctionPenaltyAdapter
implements MultivariateFunction {
/** Underlying bounded function. */
private final MultivariateFunction bounded;
/** Lower bounds. */
private final double[] lower;
/** Upper bounds. */
private final double[] upper;
/** Penalty offset. */
private final double offset;
/** Penalty scales. */
private final double[] scale;
/**
* Simple constructor.
* <p>
* When the optimizer provided points are out of range, the value of the
* penalty function will be used instead of the value of the underlying
* function. In order for this penalty to be effective in rejecting this
* point during the optimization process, the penalty function value should
* be defined with care. This value is computed as:
* <pre>
* penalty(point) = offset + &sum;<sub>i</sub>[scale[i] * &radic;|point[i]-boundary[i]|]
* </pre>
 * where indices i correspond to all the components that violate their boundaries.
* </p>
* <p>
 * So when attempting a function minimization, the offset should be larger than
 * the maximum expected value of the underlying function and the scale components
 * should all be positive. When attempting a function maximization, the offset
 * should be smaller than the minimum expected value of the underlying function
 * and the scale components should all be negative.
* </p>
* <p>
* These choices for the penalty function have two properties. First, all out
* of range points will return a function value that is worse than the value
 * returned by any in-range point. Second, the penalty is worse for large
 * boundary violations than for small ones, so the optimizer has a hint
 * about the direction in which it should search for acceptable points.
* </p>
* @param bounded bounded function
* @param lower lower bounds for each element of the input parameters array
* (some elements may be set to {@code Double.NEGATIVE_INFINITY} for
* unbounded values)
* @param upper upper bounds for each element of the input parameters array
* (some elements may be set to {@code Double.POSITIVE_INFINITY} for
* unbounded values)
* @param offset base offset of the penalty function
* @param scale scale of the penalty function
* @exception DimensionMismatchException if lower bounds, upper bounds and
 * scales are not consistent, either according to dimension or to boundary
 * values
*/
public MultivariateFunctionPenaltyAdapter(final MultivariateFunction bounded,
final double[] lower, final double[] upper,
final double offset, final double[] scale) {
// safety checks
MathUtils.checkNotNull(lower);
MathUtils.checkNotNull(upper);
MathUtils.checkNotNull(scale);
if (lower.length != upper.length) {
throw new DimensionMismatchException(lower.length, upper.length);
}
if (lower.length != scale.length) {
throw new DimensionMismatchException(lower.length, scale.length);
}
for (int i = 0; i < lower.length; ++i) {
// note the following test is written in such a way it also fails for NaN
if (!(upper[i] >= lower[i])) {
throw new NumberIsTooSmallException(upper[i], lower[i], true);
}
}
this.bounded = bounded;
this.lower = lower.clone();
this.upper = upper.clone();
this.offset = offset;
this.scale = scale.clone();
}
/**
* Computes the underlying function value from an unbounded point.
* <p>
 * This method simply returns the value of the underlying function
 * if the unbounded point already fulfills the bounds, and computes
 * a replacement value using the offset and scale if the bounds are
 * violated, without calling the function at all.
* </p>
* @param point unbounded point
* @return either underlying function value or penalty function value
*/
public double value(double[] point) {
for (int i = 0; i < scale.length; ++i) {
if ((point[i] < lower[i]) || (point[i] > upper[i])) {
// bound violation starting at this component
double sum = 0;
for (int j = i; j < scale.length; ++j) {
final double overshoot;
if (point[j] < lower[j]) {
overshoot = scale[j] * (lower[j] - point[j]);
} else if (point[j] > upper[j]) {
overshoot = scale[j] * (point[j] - upper[j]);
} else {
overshoot = 0;
}
sum += FastMath.sqrt(overshoot);
}
return offset + sum;
}
}
// all boundaries are fulfilled, we are in the expected
// domain of the underlying function
return bounded.value(point);
}
}
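
For comparison, a sketch (again not part of this commit, reusing the objective, optimizer and related names
from the mapping-adapter sketch above) of the same problem wrapped with the penalty adapter; the offset and
scale values are arbitrary and merely need to dominate the function values expected inside the box.

// Out-of-range points are replaced by offset + sum(scale[i] * sqrt(overshoot[i])),
// as described in the constructor Javadoc above.
MultivariateFunctionPenaltyAdapter penalized =
    new MultivariateFunctionPenaltyAdapter(objective,
                                           new double[] { 0, 0 },    // lower bounds
                                           new double[] { 5, 5 },    // upper bounds
                                           1.0e3,                     // offset, larger than any value inside the box
                                           new double[] { 10, 10 });  // positive scales (minimization)
PointValuePair result =
    optimizer.optimize(new MaxEval(500),
                       new ObjectiveFunction(penalized),
                       GoalType.MINIMIZE,
                       // No coordinate mapping: the guess is given directly.
                       new InitialGuess(new double[] { 1.5, 3.0 }),
                       new NelderMeadSimplex(new double[] { 0.2, 0.2 }));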

View File

@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.BaseMultivariateOptimizer;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
/**
* Base class for a multivariate scalar function optimizer.
*
* @version $Id$
* @since 3.1
*/
public abstract class MultivariateOptimizer
extends BaseMultivariateOptimizer<PointValuePair> {
/** Objective function. */
private MultivariateFunction function;
/** Type of optimization. */
private GoalType goal;
/**
* @param checker Convergence checker.
*/
protected MultivariateOptimizer(ConvergenceChecker<PointValuePair> checker) {
super(checker);
}
/**
* {@inheritDoc}
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
* <li>{@link ObjectiveFunction}</li>
* <li>{@link GoalType}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
*/
@Override
public PointValuePair optimize(OptimizationData... optData)
throws TooManyEvaluationsException {
// Retrieve settings.
parseOptimizationData(optData);
// Set up base class and perform computation.
return super.optimize(optData);
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link ObjectiveFunction}</li>
* <li>{@link GoalType}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof GoalType) {
goal = (GoalType) data;
continue;
}
if (data instanceof ObjectiveFunction) {
function = ((ObjectiveFunction) data).getObjectiveFunction();
continue;
}
}
}
/**
* @return the optimization type.
*/
public GoalType getGoalType() {
return goal;
}
/**
* Computes the objective function value.
* This method <em>must</em> be called by subclasses to enforce the
* evaluation counter limit.
*
* @param params Point at which the objective function must be evaluated.
* @return the objective function value at the specified point.
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
*/
protected double computeObjectiveValue(double[] params) {
super.incrementEvaluationCount();
return function.value(params);
}
}
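
To illustrate the contract documented on computeObjectiveValue, here is a hypothetical, deliberately naive
subclass; it is not part of this commit. It relies only on members visible in this section: the
ConvergenceChecker constructor argument, getStartPoint() (used the same way in
NonLinearConjugateGradientOptimizer below), getGoalType(), and computeObjectiveValue(), which enforces the
MaxEval limit. The required imports are the same as in the file above.

/** Hypothetical fixed-step coordinate search, for illustration only. */
public class FixedStepCoordinateSearch extends MultivariateOptimizer {
    /** Step size used for each coordinate move. */
    private final double step;

    public FixedStepCoordinateSearch(ConvergenceChecker<PointValuePair> checker,
                                     double step) {
        super(checker);
        this.step = step;
    }

    @Override
    protected PointValuePair doOptimize() {
        final boolean minimize = getGoalType() == GoalType.MINIMIZE;
        double[] current = getStartPoint().clone();
        // Always go through computeObjectiveValue so that the evaluation
        // counter (MaxEval) is enforced.
        double best = computeObjectiveValue(current);
        boolean improved = true;
        while (improved) {
            improved = false;
            // Try a move of +/- step along each coordinate, keep any improvement.
            for (int i = 0; i < current.length; i++) {
                for (double s : new double[] { step, -step }) {
                    final double[] trial = current.clone();
                    trial[i] += s;
                    final double value = computeObjectiveValue(trial);
                    if (minimize ? value < best : value > best) {
                        best = value;
                        current = trial;
                        improved = true;
                    }
                }
            }
        }
        return new PointValuePair(current, best);
    }
}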

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optim.OptimizationData;
/**
* Gradient of the scalar function to be optimized.
*
* @version $Id$
* @since 3.1
*/
public class ObjectiveFunctionGradient implements OptimizationData {
/** Function to be optimized. */
private final MultivariateVectorFunction gradient;
/**
* @param g Gradient of the function to be optimized.
*/
public ObjectiveFunctionGradient(MultivariateVectorFunction g) {
gradient = g;
}
/**
* Gets the gradient of the function to be optimized.
*
* @return the objective function gradient.
*/
public MultivariateVectorFunction getObjectiveFunctionGradient() {
return gradient;
}
}
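
A small sketch (not part of this commit) of how this holder is typically filled: the gradient of the
paraboloid used in the earlier sketches, wrapped so it can be passed to optimize(...) next to the
ObjectiveFunction (see the conjugate gradient sketch further below).

// Gradient of f(x, y) = (x - 1)^2 + (y - 2)^2.
ObjectiveFunctionGradient gradientData = new ObjectiveFunctionGradient(
        new MultivariateVectorFunction() {
            public double[] value(double[] p) {
                return new double[] { 2 * (p[0] - 1), 2 * (p[1] - 2) };
            }
        });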

View File

@ -0,0 +1,393 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.BrentSolver;
import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
import org.apache.commons.math3.exception.MathInternalError;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GradientMultivariateOptimizer;
import org.apache.commons.math3.util.FastMath;
/**
* Non-linear conjugate gradient optimizer.
* <p>
* This class supports both the Fletcher-Reeves and the Polak-Ribière
* update formulas for the conjugate search directions.
* It also supports optional preconditioning.
* </p>
*
* @version $Id: NonLinearConjugateGradientOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class NonLinearConjugateGradientOptimizer
extends GradientMultivariateOptimizer {
/** Update formula for the beta parameter. */
private final Formula updateFormula;
/** Preconditioner (may be null). */
private final Preconditioner preconditioner;
/** solver to use in the line search (may be null). */
private final UnivariateSolver solver;
/** Initial step used to bracket the optimum in line search. */
private double initialStep = 1;
/**
* Constructor with default {@link BrentSolver line search solver} and
* {@link IdentityPreconditioner preconditioner}.
*
* @param updateFormula formula to use for updating the &beta; parameter,
* must be one of {@link Formula#FLETCHER_REEVES} or
* {@link Formula#POLAK_RIBIERE}.
* @param checker Convergence checker.
*/
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
ConvergenceChecker<PointValuePair> checker) {
this(updateFormula,
checker,
new BrentSolver(),
new IdentityPreconditioner());
}
/**
 * Available choices of update formulas for updating the parameter
 * that is used to compute the successive conjugate search directions.
* For non-linear conjugate gradients, there are
* two formulas:
* <ul>
* <li>Fletcher-Reeves formula</li>
* <li>Polak-Ribière formula</li>
* </ul>
*
 * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
 * if the start point is close enough to the optimum, whereas the
 * Polak-Ribière formula may not converge in rare cases. On the
 * other hand, the Polak-Ribière formula is often faster when it
 * does converge. Polak-Ribière is therefore often used.
*
* @since 2.0
*/
public static enum Formula {
/** Fletcher-Reeves formula. */
FLETCHER_REEVES,
/** Polak-Ribière formula. */
POLAK_RIBIERE
}
/**
* The initial step is a factor with respect to the search direction
* (which itself is roughly related to the gradient of the function).
* <br/>
* It is used to find an interval that brackets the optimum in line
* search.
*
* @since 3.1
*/
public static class BracketingStep implements OptimizationData {
/** Initial step. */
private final double initialStep;
/**
* @param step Initial step for the bracket search.
*/
public BracketingStep(double step) {
initialStep = step;
}
/**
* Gets the initial step.
*
* @return the initial step.
*/
public double getBracketingStep() {
return initialStep;
}
}
/**
* Constructor with default {@link IdentityPreconditioner preconditioner}.
*
* @param updateFormula formula to use for updating the &beta; parameter,
* must be one of {@link Formula#FLETCHER_REEVES} or
* {@link Formula#POLAK_RIBIERE}.
* @param checker Convergence checker.
* @param lineSearchSolver Solver to use during line search.
*/
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
ConvergenceChecker<PointValuePair> checker,
final UnivariateSolver lineSearchSolver) {
this(updateFormula,
checker,
lineSearchSolver,
new IdentityPreconditioner());
}
/**
* @param updateFormula formula to use for updating the &beta; parameter,
* must be one of {@link Formula#FLETCHER_REEVES} or
* {@link Formula#POLAK_RIBIERE}.
* @param checker Convergence checker.
* @param lineSearchSolver Solver to use during line search.
* @param preconditioner Preconditioner.
*/
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
ConvergenceChecker<PointValuePair> checker,
final UnivariateSolver lineSearchSolver,
final Preconditioner preconditioner) {
super(checker);
this.updateFormula = updateFormula;
solver = lineSearchSolver;
this.preconditioner = preconditioner;
initialStep = 1;
}
/**
* {@inheritDoc}
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
* <li>{@link org.apache.commons.math3.optim.GoalType}</li>
* <li>{@link org.apache.commons.math3.optim.ObjectiveFunction}</li>
* <li>{@link org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient}</li>
* <li>{@link BracketingStep}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of
* evaluations (of the objective function) is exceeded.
*/
@Override
public PointValuePair optimize(OptimizationData... optData)
throws TooManyEvaluationsException {
// Retrieve settings.
parseOptimizationData(optData);
// Set up base class and perform computation.
return super.optimize(optData);
}
/** {@inheritDoc} */
@Override
protected PointValuePair doOptimize() {
final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
final double[] point = getStartPoint();
final GoalType goal = getGoalType();
final int n = point.length;
double[] r = computeObjectiveGradient(point);
if (goal == GoalType.MINIMIZE) {
for (int i = 0; i < n; i++) {
r[i] = -r[i];
}
}
// Initial search direction.
double[] steepestDescent = preconditioner.precondition(point, r);
double[] searchDirection = steepestDescent.clone();
double delta = 0;
for (int i = 0; i < n; ++i) {
delta += r[i] * searchDirection[i];
}
PointValuePair current = null;
int iter = 0;
int maxEval = getMaxEvaluations();
while (true) {
++iter;
final double objective = computeObjectiveValue(point);
PointValuePair previous = current;
current = new PointValuePair(point, objective);
if (previous != null) {
if (checker.converged(iter, previous, current)) {
// We have found an optimum.
return current;
}
}
// Find the optimal step in the search direction.
final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
final double uB = findUpperBound(lsf, 0, initialStep);
// XXX Last parameter is set to a value close to zero in order to
// work around the divergence problem in the "testCircleFitting"
// unit test (see MATH-439).
final double step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
maxEval -= solver.getEvaluations(); // Subtract used up evaluations.
// Validate new point.
for (int i = 0; i < point.length; ++i) {
point[i] += step * searchDirection[i];
}
r = computeObjectiveGradient(point);
if (goal == GoalType.MINIMIZE) {
for (int i = 0; i < n; ++i) {
r[i] = -r[i];
}
}
// Compute beta.
final double deltaOld = delta;
final double[] newSteepestDescent = preconditioner.precondition(point, r);
delta = 0;
for (int i = 0; i < n; ++i) {
delta += r[i] * newSteepestDescent[i];
}
final double beta;
switch (updateFormula) {
case FLETCHER_REEVES:
beta = delta / deltaOld;
break;
case POLAK_RIBIERE:
double deltaMid = 0;
for (int i = 0; i < r.length; ++i) {
deltaMid += r[i] * steepestDescent[i];
}
beta = (delta - deltaMid) / deltaOld;
break;
default:
// Should never happen.
throw new MathInternalError();
}
steepestDescent = newSteepestDescent;
// Compute conjugate search direction.
if (iter % n == 0 ||
beta < 0) {
// Break conjugation: reset search direction.
searchDirection = steepestDescent.clone();
} else {
// Compute new conjugate search direction.
for (int i = 0; i < n; ++i) {
searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
}
}
}
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
 * <li>{@link BracketingStep}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof BracketingStep) {
initialStep = ((BracketingStep) data).getBracketingStep();
// If more data must be parsed, this statement _must_ be
// changed to "continue".
break;
}
}
}
/**
* Finds the upper bound b ensuring bracketing of a root between a and b.
*
* @param f function whose root must be bracketed.
* @param a lower bound of the interval.
* @param h initial step to try.
* @return b such that f(a) and f(b) have opposite signs.
* @throws MathIllegalStateException if no bracket can be found.
*/
private double findUpperBound(final UnivariateFunction f,
final double a, final double h) {
final double yA = f.value(a);
double yB = yA;
for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
final double b = a + step;
yB = f.value(b);
if (yA * yB <= 0) {
return b;
}
}
throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
}
/** Default identity preconditioner. */
public static class IdentityPreconditioner implements Preconditioner {
/** {@inheritDoc} */
public double[] precondition(double[] variables, double[] r) {
return r.clone();
}
}
/**
* Internal class for line search.
* <p>
* The function represented by this class is the dot product of
* the objective function gradient and the search direction. Its
* value is zero when the gradient is orthogonal to the search
* direction, i.e. when the objective function value is a local
* extremum along the search direction.
* </p>
*/
private class LineSearchFunction implements UnivariateFunction {
/** Current point. */
private final double[] currentPoint;
/** Search direction. */
private final double[] searchDirection;
/**
* @param point Current point.
* @param direction Search direction.
*/
public LineSearchFunction(double[] point,
double[] direction) {
currentPoint = point.clone();
searchDirection = direction.clone();
}
/** {@inheritDoc} */
public double value(double x) {
// current point in the search direction
final double[] shiftedPoint = currentPoint.clone();
for (int i = 0; i < shiftedPoint.length; ++i) {
shiftedPoint[i] += x * searchDirection[i];
}
// gradient of the objective function
final double[] gradient = computeObjectiveGradient(shiftedPoint);
// dot product with the search direction
double dotProduct = 0;
for (int i = 0; i < gradient.length; ++i) {
dotProduct += gradient[i] * searchDirection[i];
}
return dotProduct;
}
}
}
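
A usage sketch (not part of this commit) tying together the optimization data listed in the Javadoc of
optimize above. It assumes the SimpleValueChecker convergence checker from the new optim package, reuses the
objective and gradientData names from the earlier sketches, and uses illustrative tolerances and step sizes.

NonLinearConjugateGradientOptimizer optimizer =
    new NonLinearConjugateGradientOptimizer(
            NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
            new SimpleValueChecker(1e-8, 1e-8));
PointValuePair result =
    optimizer.optimize(new MaxEval(200),
                       new ObjectiveFunction(objective),
                       gradientData,
                       GoalType.MINIMIZE,
                       new InitialGuess(new double[] { 10, -10 }),
                       // Initial step used to bracket the optimum in the line search.
                       new NonLinearConjugateGradientOptimizer.BracketingStep(0.1));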

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
/**
* This interface represents a preconditioner for differentiable scalar
* objective function optimizers.
* @version $Id: Preconditioner.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public interface Preconditioner {
/**
* Precondition a search direction.
* <p>
 * The returned preconditioned search direction must be computed quickly or
 * the algorithm performance will drop drastically. A classical approach
 * is to compute only the diagonal elements of the Hessian and to divide
 * the raw search direction by these elements if they are all positive.
 * If at least one of them is negative, it is safer to return a clone of
 * the raw search direction as if the Hessian were the identity matrix. The
* rationale for this simplified choice is that a negative diagonal element
* means the current point is far from the optimum and preconditioning will
* not be efficient anyway in this case.
* </p>
* @param point current point at which the search direction was computed
* @param r raw search direction (i.e. opposite of the gradient)
 * @return approximation of H<sup>-1</sup>r where H is the objective function Hessian
*/
double[] precondition(double[] point, double[] r);
}
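
As an illustration of the recipe suggested in the Javadoc above, a hypothetical diagonal preconditioner (not
part of this commit) might look like the following; it assumes the caller can supply an estimate of the
Hessian diagonal, which is kept fixed here for simplicity.

/** Hypothetical diagonal preconditioner, for illustration only. */
public class DiagonalPreconditioner implements Preconditioner {
    /** Precomputed estimate of the Hessian diagonal. */
    private final double[] hessianDiagonal;

    public DiagonalPreconditioner(double[] hessianDiagonal) {
        this.hessianDiagonal = hessianDiagonal.clone();
    }

    public double[] precondition(double[] point, double[] r) {
        for (double d : hessianDiagonal) {
            if (d <= 0) {
                // Non-positive curvature: safer to behave like the identity.
                return r.clone();
            }
        }
        // Divide the raw direction by the (positive) diagonal terms.
        final double[] out = new double[r.length];
        for (int i = 0; i < r.length; i++) {
            out[i] = r[i] / hessianDiagonal[i];
        }
        return out;
    }
}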

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
/**
* This package provides optimization algorithms that require derivatives.
*/

View File

@ -0,0 +1,346 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.ZeroException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.OptimizationData;
/**
* This class implements the simplex concept.
* It is intended to be used in conjunction with {@link SimplexOptimizer}.
* <br/>
* The initial configuration of the simplex is set by the constructors
* {@link #AbstractSimplex(double[])} or {@link #AbstractSimplex(double[][])}.
* The other {@link #AbstractSimplex(int) constructor} will set all steps
* to 1, thus building a default configuration from a unit hypercube.
* <br/>
* Users <em>must</em> call the {@link #build(double[]) build} method in order
* to create the data structure that will be acted on by the other methods of
* this class.
*
* @see SimplexOptimizer
* @version $Id: AbstractSimplex.java 1397759 2012-10-13 01:12:58Z erans $
* @since 3.0
*/
public abstract class AbstractSimplex implements OptimizationData {
/** Simplex. */
private PointValuePair[] simplex;
/** Start simplex configuration. */
private double[][] startConfiguration;
/** Simplex dimension (must be equal to {@code simplex.length - 1}). */
private final int dimension;
/**
* Build a unit hypercube simplex.
*
* @param n Dimension of the simplex.
*/
protected AbstractSimplex(int n) {
this(n, 1d);
}
/**
* Build a hypercube simplex with the given side length.
*
* @param n Dimension of the simplex.
* @param sideLength Length of the sides of the hypercube.
*/
protected AbstractSimplex(int n,
double sideLength) {
this(createHypercubeSteps(n, sideLength));
}
/**
 * The start configuration for the simplex is built from a box parallel to
* the canonical axes of the space. The simplex is the subset of vertices
* of a box parallel to the canonical axes. It is built as the path followed
* while traveling from one vertex of the box to the diagonally opposite
* vertex moving only along the box edges. The first vertex of the box will
* be located at the start point of the optimization.
* As an example, in dimension 3 a simplex has 4 vertices. Setting the
* steps to (1, 10, 2) and the start point to (1, 1, 1) would imply the
* start simplex would be: { (1, 1, 1), (2, 1, 1), (2, 11, 1), (2, 11, 3) }.
* The first vertex would be set to the start point at (1, 1, 1) and the
* last vertex would be set to the diagonally opposite vertex at (2, 11, 3).
*
* @param steps Steps along the canonical axes representing box edges. They
* may be negative but not zero.
* @throws NullArgumentException if {@code steps} is {@code null}.
* @throws ZeroException if one of the steps is zero.
*/
protected AbstractSimplex(final double[] steps) {
if (steps == null) {
throw new NullArgumentException();
}
if (steps.length == 0) {
throw new ZeroException();
}
dimension = steps.length;
// Only the relative positions of the n final vertices with respect
// to the first one are stored.
startConfiguration = new double[dimension][dimension];
for (int i = 0; i < dimension; i++) {
final double[] vertexI = startConfiguration[i];
for (int j = 0; j < i + 1; j++) {
if (steps[j] == 0) {
throw new ZeroException(LocalizedFormats.EQUAL_VERTICES_IN_SIMPLEX);
}
System.arraycopy(steps, 0, vertexI, 0, j + 1);
}
}
}
/**
* The real initial simplex will be set up by moving the reference
* simplex such that its first point is located at the start point of the
* optimization.
*
* @param referenceSimplex Reference simplex.
* @throws NotStrictlyPositiveException if the reference simplex does not
* contain at least one point.
* @throws DimensionMismatchException if there is a dimension mismatch
* in the reference simplex.
* @throws IllegalArgumentException if one of its vertices is duplicated.
*/
protected AbstractSimplex(final double[][] referenceSimplex) {
if (referenceSimplex.length <= 0) {
throw new NotStrictlyPositiveException(LocalizedFormats.SIMPLEX_NEED_ONE_POINT,
referenceSimplex.length);
}
dimension = referenceSimplex.length - 1;
// Only the relative positions of the n final vertices with respect
// to the first one are stored.
startConfiguration = new double[dimension][dimension];
final double[] ref0 = referenceSimplex[0];
// Loop over vertices.
for (int i = 0; i < referenceSimplex.length; i++) {
final double[] refI = referenceSimplex[i];
// Safety checks.
if (refI.length != dimension) {
throw new DimensionMismatchException(refI.length, dimension);
}
for (int j = 0; j < i; j++) {
final double[] refJ = referenceSimplex[j];
boolean allEquals = true;
for (int k = 0; k < dimension; k++) {
if (refI[k] != refJ[k]) {
allEquals = false;
break;
}
}
if (allEquals) {
throw new MathIllegalArgumentException(LocalizedFormats.EQUAL_VERTICES_IN_SIMPLEX,
i, j);
}
}
// Store vertex i position relative to vertex 0 position.
if (i > 0) {
final double[] confI = startConfiguration[i - 1];
for (int k = 0; k < dimension; k++) {
confI[k] = refI[k] - ref0[k];
}
}
}
}
/**
* Get simplex dimension.
*
* @return the dimension of the simplex.
*/
public int getDimension() {
return dimension;
}
/**
* Get simplex size.
 * After calling the {@link #build(double[]) build} method, the value returned
 * by this method will be equivalent to {@code getDimension() + 1}.
*
* @return the size of the simplex.
*/
public int getSize() {
return simplex.length;
}
/**
* Compute the next simplex of the algorithm.
*
* @param evaluationFunction Evaluation function.
* @param comparator Comparator to use to sort simplex vertices from best
* to worst.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the algorithm fails to converge.
*/
public abstract void iterate(final MultivariateFunction evaluationFunction,
final Comparator<PointValuePair> comparator);
/**
* Build an initial simplex.
*
* @param startPoint First point of the simplex.
* @throws DimensionMismatchException if the start point does not match
* simplex dimension.
*/
public void build(final double[] startPoint) {
if (dimension != startPoint.length) {
throw new DimensionMismatchException(dimension, startPoint.length);
}
// Set first vertex.
simplex = new PointValuePair[dimension + 1];
simplex[0] = new PointValuePair(startPoint, Double.NaN);
// Set remaining vertices.
for (int i = 0; i < dimension; i++) {
final double[] confI = startConfiguration[i];
final double[] vertexI = new double[dimension];
for (int k = 0; k < dimension; k++) {
vertexI[k] = startPoint[k] + confI[k];
}
simplex[i + 1] = new PointValuePair(vertexI, Double.NaN);
}
}
/**
* Evaluate all the non-evaluated points of the simplex.
*
* @param evaluationFunction Evaluation function.
* @param comparator Comparator to use to sort simplex vertices from best to worst.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the maximal number of evaluations is exceeded.
*/
public void evaluate(final MultivariateFunction evaluationFunction,
final Comparator<PointValuePair> comparator) {
// Evaluate the objective function at all non-evaluated simplex points.
for (int i = 0; i < simplex.length; i++) {
final PointValuePair vertex = simplex[i];
final double[] point = vertex.getPointRef();
if (Double.isNaN(vertex.getValue())) {
simplex[i] = new PointValuePair(point, evaluationFunction.value(point), false);
}
}
// Sort the simplex from best to worst.
Arrays.sort(simplex, comparator);
}
/**
* Replace the worst point of the simplex by a new point.
*
* @param pointValuePair Point to insert.
* @param comparator Comparator to use for sorting the simplex vertices
* from best to worst.
*/
protected void replaceWorstPoint(PointValuePair pointValuePair,
final Comparator<PointValuePair> comparator) {
for (int i = 0; i < dimension; i++) {
if (comparator.compare(simplex[i], pointValuePair) > 0) {
PointValuePair tmp = simplex[i];
simplex[i] = pointValuePair;
pointValuePair = tmp;
}
}
simplex[dimension] = pointValuePair;
}
/**
* Get the points of the simplex.
*
* @return all the simplex points.
*/
public PointValuePair[] getPoints() {
final PointValuePair[] copy = new PointValuePair[simplex.length];
System.arraycopy(simplex, 0, copy, 0, simplex.length);
return copy;
}
/**
* Get the simplex point stored at the requested {@code index}.
*
* @param index Location.
* @return the point at location {@code index}.
*/
public PointValuePair getPoint(int index) {
if (index < 0 ||
index >= simplex.length) {
throw new OutOfRangeException(index, 0, simplex.length - 1);
}
return simplex[index];
}
/**
* Store a new point at location {@code index}.
* Note that no deep-copy of {@code point} is performed.
*
* @param index Location.
* @param point New value.
*/
protected void setPoint(int index, PointValuePair point) {
if (index < 0 ||
index >= simplex.length) {
throw new OutOfRangeException(index, 0, simplex.length - 1);
}
simplex[index] = point;
}
/**
* Replace all points.
* Note that no deep-copy of {@code points} is performed.
*
* @param points New Points.
*/
protected void setPoints(PointValuePair[] points) {
if (points.length != simplex.length) {
throw new DimensionMismatchException(points.length, simplex.length);
}
simplex = points;
}
/**
* Create steps for a unit hypercube.
*
* @param n Dimension of the hypercube.
* @param sideLength Length of the sides of the hypercube.
* @return the steps.
*/
private static double[] createHypercubeSteps(int n,
double sideLength) {
final double[] steps = new double[n];
for (int i = 0; i < n; i++) {
steps[i] = sideLength;
}
return steps;
}
}
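
A quick sketch (not part of this commit) that reproduces the start-configuration example given in the Javadoc
of the double[] constructor above, using the NelderMeadSimplex subclass defined later in this change set.

// Steps (1, 10, 2) and start point (1, 1, 1), as in the Javadoc example.
AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1, 10, 2 });
simplex.build(new double[] { 1, 1, 1 });
for (PointValuePair vertex : simplex.getPoints()) {
    // Prints [1, 1, 1], [2, 1, 1], [2, 11, 1], [2, 11, 3]; the associated
    // values are still NaN since no evaluation function has been applied yet.
    System.out.println(java.util.Arrays.toString(vertex.getPoint()));
}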

View File

@ -0,0 +1,216 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import java.util.Comparator;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.PointValuePair;
/**
* This class implements the multi-directional direct search method.
*
* @version $Id: MultiDirectionalSimplex.java 1364392 2012-07-22 18:27:12Z tn $
* @since 3.0
*/
public class MultiDirectionalSimplex extends AbstractSimplex {
/** Default value for {@link #khi}: {@value}. */
private static final double DEFAULT_KHI = 2;
/** Default value for {@link #gamma}: {@value}. */
private static final double DEFAULT_GAMMA = 0.5;
/** Expansion coefficient. */
private final double khi;
/** Contraction coefficient. */
private final double gamma;
/**
* Build a multi-directional simplex with default coefficients.
* The default values are 2.0 for khi and 0.5 for gamma.
*
* @param n Dimension of the simplex.
*/
public MultiDirectionalSimplex(final int n) {
this(n, 1d);
}
/**
* Build a multi-directional simplex with default coefficients.
* The default values are 2.0 for khi and 0.5 for gamma.
*
* @param n Dimension of the simplex.
* @param sideLength Length of the sides of the default (hypercube)
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
*/
public MultiDirectionalSimplex(final int n, double sideLength) {
this(n, sideLength, DEFAULT_KHI, DEFAULT_GAMMA);
}
/**
* Build a multi-directional simplex with specified coefficients.
*
* @param n Dimension of the simplex. See
* {@link AbstractSimplex#AbstractSimplex(int,double)}.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
*/
public MultiDirectionalSimplex(final int n,
final double khi, final double gamma) {
this(n, 1d, khi, gamma);
}
/**
* Build a multi-directional simplex with specified coefficients.
*
* @param n Dimension of the simplex. See
* {@link AbstractSimplex#AbstractSimplex(int,double)}.
* @param sideLength Length of the sides of the default (hypercube)
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
*/
public MultiDirectionalSimplex(final int n, double sideLength,
final double khi, final double gamma) {
super(n, sideLength);
this.khi = khi;
this.gamma = gamma;
}
/**
* Build a multi-directional simplex with default coefficients.
* The default values are 2.0 for khi and 0.5 for gamma.
*
* @param steps Steps along the canonical axes representing box edges.
 * They may be negative but not zero. See
 * {@link AbstractSimplex#AbstractSimplex(double[])}.
*/
public MultiDirectionalSimplex(final double[] steps) {
this(steps, DEFAULT_KHI, DEFAULT_GAMMA);
}
/**
* Build a multi-directional simplex with specified coefficients.
*
* @param steps Steps along the canonical axes representing box edges.
* They may be negative but not zero. See
* {@link AbstractSimplex#AbstractSimplex(double[])}.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
*/
public MultiDirectionalSimplex(final double[] steps,
final double khi, final double gamma) {
super(steps);
this.khi = khi;
this.gamma = gamma;
}
/**
* Build a multi-directional simplex with default coefficients.
* The default values are 2.0 for khi and 0.5 for gamma.
*
* @param referenceSimplex Reference simplex. See
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
*/
public MultiDirectionalSimplex(final double[][] referenceSimplex) {
this(referenceSimplex, DEFAULT_KHI, DEFAULT_GAMMA);
}
/**
* Build a multi-directional simplex with specified coefficients.
*
* @param referenceSimplex Reference simplex. See
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
* if the reference simplex does not contain at least one point.
* @throws org.apache.commons.math3.exception.DimensionMismatchException
* if there is a dimension mismatch in the reference simplex.
*/
public MultiDirectionalSimplex(final double[][] referenceSimplex,
final double khi, final double gamma) {
super(referenceSimplex);
this.khi = khi;
this.gamma = gamma;
}
/** {@inheritDoc} */
@Override
public void iterate(final MultivariateFunction evaluationFunction,
final Comparator<PointValuePair> comparator) {
// Save the original simplex.
final PointValuePair[] original = getPoints();
final PointValuePair best = original[0];
// Perform a reflection step.
final PointValuePair reflected = evaluateNewSimplex(evaluationFunction,
original, 1, comparator);
if (comparator.compare(reflected, best) < 0) {
// Compute the expanded simplex.
final PointValuePair[] reflectedSimplex = getPoints();
final PointValuePair expanded = evaluateNewSimplex(evaluationFunction,
original, khi, comparator);
if (comparator.compare(reflected, expanded) <= 0) {
// Keep the reflected simplex.
setPoints(reflectedSimplex);
}
// Keep the expanded simplex.
return;
}
// Compute the contracted simplex.
evaluateNewSimplex(evaluationFunction, original, gamma, comparator);
}
/**
* Compute and evaluate a new simplex.
*
* @param evaluationFunction Evaluation function.
* @param original Original simplex (to be preserved).
* @param coeff Linear coefficient.
* @param comparator Comparator to use to sort simplex vertices from best
* to poorest.
* @return the best point in the transformed simplex.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the maximal number of evaluations is exceeded.
*/
private PointValuePair evaluateNewSimplex(final MultivariateFunction evaluationFunction,
final PointValuePair[] original,
final double coeff,
final Comparator<PointValuePair> comparator) {
final double[] xSmallest = original[0].getPointRef();
// Perform a linear transformation on all the simplex points,
// except the first one.
setPoint(0, original[0]);
final int dim = getDimension();
for (int i = 1; i < getSize(); i++) {
final double[] xOriginal = original[i].getPointRef();
final double[] xTransformed = new double[dim];
for (int j = 0; j < dim; j++) {
xTransformed[j] = xSmallest[j] + coeff * (xSmallest[j] - xOriginal[j]);
}
setPoint(i, new PointValuePair(xTransformed, Double.NaN, false));
}
// Evaluate the simplex.
evaluate(evaluationFunction, comparator);
return getPoint(0);
}
}
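
Usage (not part of this commit) is the same as for the Nelder-Mead simplex in the mapping-adapter sketch
above; only the simplex object passed as optimization data changes. Here f stands for any
MultivariateFunction, and the side length, tolerances and guess are illustrative.

SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
PointValuePair result =
    optimizer.optimize(new MaxEval(1000),
                       new ObjectiveFunction(f),
                       GoalType.MINIMIZE,
                       new InitialGuess(new double[] { 0, 0 }),
                       // Hypercube simplex of side 0.2 with the default
                       // khi = 2 and gamma = 0.5 coefficients.
                       new MultiDirectionalSimplex(2, 0.2));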

View File

@ -0,0 +1,281 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import java.util.Comparator;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.analysis.MultivariateFunction;
/**
* This class implements the Nelder-Mead simplex algorithm.
*
* @version $Id: NelderMeadSimplex.java 1364392 2012-07-22 18:27:12Z tn $
* @since 3.0
*/
public class NelderMeadSimplex extends AbstractSimplex {
/** Default value for {@link #rho}: {@value}. */
private static final double DEFAULT_RHO = 1;
/** Default value for {@link #khi}: {@value}. */
private static final double DEFAULT_KHI = 2;
/** Default value for {@link #gamma}: {@value}. */
private static final double DEFAULT_GAMMA = 0.5;
/** Default value for {@link #sigma}: {@value}. */
private static final double DEFAULT_SIGMA = 0.5;
/** Reflection coefficient. */
private final double rho;
/** Expansion coefficient. */
private final double khi;
/** Contraction coefficient. */
private final double gamma;
/** Shrinkage coefficient. */
private final double sigma;
/**
* Build a Nelder-Mead simplex with default coefficients.
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
* for both gamma and sigma.
*
* @param n Dimension of the simplex.
*/
public NelderMeadSimplex(final int n) {
this(n, 1d);
}
/**
* Build a Nelder-Mead simplex with default coefficients.
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
* for both gamma and sigma.
*
* @param n Dimension of the simplex.
* @param sideLength Length of the sides of the default (hypercube)
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
*/
public NelderMeadSimplex(final int n, double sideLength) {
this(n, sideLength,
DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
}
/**
* Build a Nelder-Mead simplex with specified coefficients.
*
* @param n Dimension of the simplex. See
* {@link AbstractSimplex#AbstractSimplex(int,double)}.
* @param sideLength Length of the sides of the default (hypercube)
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
* @param rho Reflection coefficient.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
* @param sigma Shrinkage coefficient.
*/
public NelderMeadSimplex(final int n, double sideLength,
final double rho, final double khi,
final double gamma, final double sigma) {
super(n, sideLength);
this.rho = rho;
this.khi = khi;
this.gamma = gamma;
this.sigma = sigma;
}
/**
* Build a Nelder-Mead simplex with specified coefficients.
*
* @param n Dimension of the simplex. See
* {@link AbstractSimplex#AbstractSimplex(int)}.
* @param rho Reflection coefficient.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
* @param sigma Shrinkage coefficient.
*/
public NelderMeadSimplex(final int n,
final double rho, final double khi,
final double gamma, final double sigma) {
this(n, 1d, rho, khi, gamma, sigma);
}
/**
* Build a Nelder-Mead simplex with default coefficients.
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
* for both gamma and sigma.
*
* @param steps Steps along the canonical axes representing box edges.
 * They may be negative but not zero. See
 * {@link AbstractSimplex#AbstractSimplex(double[])}.
*/
public NelderMeadSimplex(final double[] steps) {
this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
}
/**
* Build a Nelder-Mead simplex with specified coefficients.
*
* @param steps Steps along the canonical axes representing box edges.
* They may be negative but not zero. See
* {@link AbstractSimplex#AbstractSimplex(double[])}.
* @param rho Reflection coefficient.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
* @param sigma Shrinkage coefficient.
* @throws IllegalArgumentException if one of the steps is zero.
*/
public NelderMeadSimplex(final double[] steps,
final double rho, final double khi,
final double gamma, final double sigma) {
super(steps);
this.rho = rho;
this.khi = khi;
this.gamma = gamma;
this.sigma = sigma;
}
/**
* Build a Nelder-Mead simplex with default coefficients.
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
* for both gamma and sigma.
*
* @param referenceSimplex Reference simplex. See
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
*/
public NelderMeadSimplex(final double[][] referenceSimplex) {
this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
}
/**
* Build a Nelder-Mead simplex with specified coefficients.
*
* @param referenceSimplex Reference simplex. See
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
* @param rho Reflection coefficient.
* @param khi Expansion coefficient.
* @param gamma Contraction coefficient.
* @param sigma Shrinkage coefficient.
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
* if the reference simplex does not contain at least one point.
* @throws org.apache.commons.math3.exception.DimensionMismatchException
* if there is a dimension mismatch in the reference simplex.
*/
public NelderMeadSimplex(final double[][] referenceSimplex,
final double rho, final double khi,
final double gamma, final double sigma) {
super(referenceSimplex);
this.rho = rho;
this.khi = khi;
this.gamma = gamma;
this.sigma = sigma;
}
/** {@inheritDoc} */
@Override
public void iterate(final MultivariateFunction evaluationFunction,
final Comparator<PointValuePair> comparator) {
// The simplex has n + 1 points if dimension is n.
final int n = getDimension();
// Interesting values.
final PointValuePair best = getPoint(0);
final PointValuePair secondBest = getPoint(n - 1);
final PointValuePair worst = getPoint(n);
final double[] xWorst = worst.getPointRef();
// Compute the centroid of the best vertices (dismissing the worst
// point at index n).
final double[] centroid = new double[n];
for (int i = 0; i < n; i++) {
final double[] x = getPoint(i).getPointRef();
for (int j = 0; j < n; j++) {
centroid[j] += x[j];
}
}
final double scaling = 1.0 / n;
for (int j = 0; j < n; j++) {
centroid[j] *= scaling;
}
// compute the reflection point
final double[] xR = new double[n];
for (int j = 0; j < n; j++) {
xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
}
final PointValuePair reflected
= new PointValuePair(xR, evaluationFunction.value(xR), false);
if (comparator.compare(best, reflected) <= 0 &&
comparator.compare(reflected, secondBest) < 0) {
// Accept the reflected point.
replaceWorstPoint(reflected, comparator);
} else if (comparator.compare(reflected, best) < 0) {
// Compute the expansion point.
final double[] xE = new double[n];
for (int j = 0; j < n; j++) {
xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
}
final PointValuePair expanded
= new PointValuePair(xE, evaluationFunction.value(xE), false);
if (comparator.compare(expanded, reflected) < 0) {
// Accept the expansion point.
replaceWorstPoint(expanded, comparator);
} else {
// Accept the reflected point.
replaceWorstPoint(reflected, comparator);
}
} else {
if (comparator.compare(reflected, worst) < 0) {
// Perform an outside contraction.
final double[] xC = new double[n];
for (int j = 0; j < n; j++) {
xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
}
final PointValuePair outContracted
= new PointValuePair(xC, evaluationFunction.value(xC), false);
if (comparator.compare(outContracted, reflected) <= 0) {
// Accept the contraction point.
replaceWorstPoint(outContracted, comparator);
return;
}
} else {
// Perform an inside contraction.
final double[] xC = new double[n];
for (int j = 0; j < n; j++) {
xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
}
final PointValuePair inContracted
= new PointValuePair(xC, evaluationFunction.value(xC), false);
if (comparator.compare(inContracted, worst) < 0) {
// Accept the contraction point.
replaceWorstPoint(inContracted, comparator);
return;
}
}
// Perform a shrink.
final double[] xSmallest = getPoint(0).getPointRef();
for (int i = 1; i <= n; i++) {
final double[] x = getPoint(i).getPoint();
for (int j = 0; j < n; j++) {
x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
}
setPoint(i, new PointValuePair(x, Double.NaN, false));
}
evaluate(evaluationFunction, comparator);
}
}
}
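
For reference, a minimal usage sketch (not part of the committed file): building a simplex over a made-up 2-D reference configuration with the standard Nelder-Mead coefficients, which match the class defaults.

// Illustrative only: vertices and coefficients are example values.
double[][] vertices = { { 0, 0 }, { 1, 0 }, { 0, 1 } };   // 3 vertices in dimension 2
NelderMeadSimplex simplex =
    new NelderMeadSimplex(vertices, 1.0, 2.0, 0.5, 0.5);  // rho, khi, gamma, sigma
// The simplex is then passed to a SimplexOptimizer as OptimizationData.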

View File

@ -0,0 +1,356 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math3.optim.univariate.BracketFinder;
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;
import org.apache.commons.math3.optim.univariate.SimpleUnivariateValueChecker;
import org.apache.commons.math3.optim.univariate.SearchInterval;
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
/**
* Powell algorithm.
* This code is translated and adapted from the Python version of this
* algorithm (as implemented in module {@code optimize.py} v0.5 of
* <em>SciPy</em>).
* <br/>
* The default stopping criterion is based on the differences of the
* function value between two successive iterations. It is however possible
* to define a custom convergence checker that might terminate the algorithm
* earlier.
* <br/>
* The internal line search optimizer is a {@link BrentOptimizer} with a
* convergence checker set to {@link SimpleUnivariateValueChecker}.
*
* @version $Id: PowellOptimizer.java 1413594 2012-11-26 13:16:39Z erans $
* @since 2.2
*/
public class PowellOptimizer
extends MultivariateOptimizer {
/**
* Minimum relative tolerance.
*/
private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
/**
* Relative threshold.
*/
private final double relativeThreshold;
/**
* Absolute threshold.
*/
private final double absoluteThreshold;
/**
* Line search.
*/
private final LineSearch line;
/**
 * This constructor allows specifying a user-defined convergence checker,
* in addition to the parameters that control the default convergence
* checking procedure.
* <br/>
* The internal line search tolerances are set to the square-root of their
* corresponding value in the multivariate optimizer.
*
* @param rel Relative threshold.
* @param abs Absolute threshold.
* @param checker Convergence checker.
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
*/
public PowellOptimizer(double rel,
double abs,
ConvergenceChecker<PointValuePair> checker) {
this(rel, abs, FastMath.sqrt(rel), FastMath.sqrt(abs), checker);
}
/**
 * This constructor allows specifying a user-defined convergence checker,
* in addition to the parameters that control the default convergence
* checking procedure and the line search tolerances.
*
* @param rel Relative threshold for this optimizer.
* @param abs Absolute threshold for this optimizer.
* @param lineRel Relative threshold for the internal line search optimizer.
* @param lineAbs Absolute threshold for the internal line search optimizer.
* @param checker Convergence checker.
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
*/
public PowellOptimizer(double rel,
double abs,
double lineRel,
double lineAbs,
ConvergenceChecker<PointValuePair> checker) {
super(checker);
if (rel < MIN_RELATIVE_TOLERANCE) {
throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
}
if (abs <= 0) {
throw new NotStrictlyPositiveException(abs);
}
relativeThreshold = rel;
absoluteThreshold = abs;
// Create the line search optimizer.
line = new LineSearch(lineRel,
lineAbs);
}
/**
* The parameters control the default convergence checking procedure.
* <br/>
* The internal line search tolerances are set to the square-root of their
* corresponding value in the multivariate optimizer.
*
* @param rel Relative threshold.
* @param abs Absolute threshold.
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
*/
public PowellOptimizer(double rel,
double abs) {
this(rel, abs, null);
}
/**
* Builds an instance with the default convergence checking procedure.
*
* @param rel Relative threshold.
* @param abs Absolute threshold.
* @param lineRel Relative threshold for the internal line search optimizer.
* @param lineAbs Absolute threshold for the internal line search optimizer.
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
*/
public PowellOptimizer(double rel,
double abs,
double lineRel,
double lineAbs) {
this(rel, abs, lineRel, lineAbs, null);
}
/** {@inheritDoc} */
@Override
protected PointValuePair doOptimize() {
final GoalType goal = getGoalType();
final double[] guess = getStartPoint();
final int n = guess.length;
final double[][] direc = new double[n][n];
for (int i = 0; i < n; i++) {
direc[i][i] = 1;
}
final ConvergenceChecker<PointValuePair> checker
= getConvergenceChecker();
double[] x = guess;
double fVal = computeObjectiveValue(x);
double[] x1 = x.clone();
int iter = 0;
while (true) {
++iter;
double fX = fVal;
double fX2 = 0;
double delta = 0;
int bigInd = 0;
double alphaMin = 0;
for (int i = 0; i < n; i++) {
final double[] d = MathArrays.copyOf(direc[i]);
fX2 = fVal;
final UnivariatePointValuePair optimum = line.search(x, d);
fVal = optimum.getValue();
alphaMin = optimum.getPoint();
final double[][] result = newPointAndDirection(x, d, alphaMin);
x = result[0];
if ((fX2 - fVal) > delta) {
delta = fX2 - fVal;
bigInd = i;
}
}
// Default convergence check.
boolean stop = 2 * (fX - fVal) <=
(relativeThreshold * (FastMath.abs(fX) + FastMath.abs(fVal)) +
absoluteThreshold);
final PointValuePair previous = new PointValuePair(x1, fX);
final PointValuePair current = new PointValuePair(x, fVal);
if (!stop) { // User-defined stopping criteria.
if (checker != null) {
stop = checker.converged(iter, previous, current);
}
}
if (stop) {
if (goal == GoalType.MINIMIZE) {
return (fVal < fX) ? current : previous;
} else {
return (fVal > fX) ? current : previous;
}
}
final double[] d = new double[n];
final double[] x2 = new double[n];
for (int i = 0; i < n; i++) {
d[i] = x[i] - x1[i];
x2[i] = 2 * x[i] - x1[i];
}
x1 = x.clone();
fX2 = computeObjectiveValue(x2);
if (fX > fX2) {
double t = 2 * (fX + fX2 - 2 * fVal);
double temp = fX - fVal - delta;
t *= temp * temp;
temp = fX - fX2;
t -= delta * temp * temp;
if (t < 0.0) {
final UnivariatePointValuePair optimum = line.search(x, d);
fVal = optimum.getValue();
alphaMin = optimum.getPoint();
final double[][] result = newPointAndDirection(x, d, alphaMin);
x = result[0];
final int lastInd = n - 1;
direc[bigInd] = direc[lastInd];
direc[lastInd] = result[1];
}
}
}
}
/**
* Compute a new point (in the original space) and a new direction
* vector, resulting from the line search.
*
* @param p Point used in the line search.
* @param d Direction used in the line search.
* @param optimum Optimum found by the line search.
* @return a 2-element array containing the new point (at index 0) and
* the new direction (at index 1).
*/
private double[][] newPointAndDirection(double[] p,
double[] d,
double optimum) {
final int n = p.length;
final double[] nP = new double[n];
final double[] nD = new double[n];
for (int i = 0; i < n; i++) {
nD[i] = d[i] * optimum;
nP[i] = p[i] + nD[i];
}
final double[][] result = new double[2][];
result[0] = nP;
result[1] = nD;
return result;
}
/**
* Class for finding the minimum of the objective function along a given
* direction.
*/
private class LineSearch extends BrentOptimizer {
/**
* Value that will pass the precondition check for {@link BrentOptimizer}
* but will not pass the convergence check, so that the custom checker
* will always decide when to stop the line search.
*/
private static final double REL_TOL_UNUSED = 1e-15;
/**
* Value that will pass the precondition check for {@link BrentOptimizer}
* but will not pass the convergence check, so that the custom checker
* will always decide when to stop the line search.
*/
private static final double ABS_TOL_UNUSED = Double.MIN_VALUE;
/**
* Automatic bracketing.
*/
private final BracketFinder bracket = new BracketFinder();
/**
* The "BrentOptimizer" default stopping criterion uses the tolerances
* to check the domain (point) values, not the function values.
* We thus create a custom checker to use function values.
*
* @param rel Relative threshold.
* @param abs Absolute threshold.
*/
LineSearch(double rel,
double abs) {
super(REL_TOL_UNUSED,
ABS_TOL_UNUSED,
new SimpleUnivariateValueChecker(rel, abs));
}
/**
* Find the minimum of the function {@code f(p + alpha * d)}.
*
* @param p Starting point.
* @param d Search direction.
* @return the optimum.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the number of evaluations is exceeded.
*/
public UnivariatePointValuePair search(final double[] p, final double[] d) {
final int n = p.length;
final UnivariateFunction f = new UnivariateFunction() {
public double value(double alpha) {
final double[] x = new double[n];
for (int i = 0; i < n; i++) {
x[i] = p[i] + alpha * d[i];
}
final double obj = PowellOptimizer.this.computeObjectiveValue(x);
return obj;
}
};
final GoalType goal = PowellOptimizer.this.getGoalType();
bracket.search(f, goal, 0, 1);
// Passing "MAX_VALUE" as a dummy value because it is the enclosing
// class that counts the number of evaluations (and will eventually
// generate the exception).
return optimize(new MaxEval(Integer.MAX_VALUE),
new UnivariateObjectiveFunction(f),
goal,
new SearchInterval(bracket.getLo(),
bracket.getHi(),
bracket.getMid()));
}
}
}
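
A hedged usage sketch (not part of the committed file): minimizing a simple quadratic with the default convergence procedure. The thresholds, evaluation budget and start point are illustrative values; ObjectiveFunction, GoalType, InitialGuess and MaxEval are the optimization-data wrappers from the new optim packages.

PowellOptimizer optimizer = new PowellOptimizer(1e-10, 1e-30);   // rel, abs thresholds
MultivariateFunction f = new MultivariateFunction() {
    public double value(double[] x) {
        // Minimum at (1, -2).
        return (x[0] - 1) * (x[0] - 1) + (x[1] + 2) * (x[1] + 2);
    }
};
PointValuePair result = optimizer.optimize(new MaxEval(1000),
                                           new ObjectiveFunction(f),
                                           GoalType.MINIMIZE,
                                           new InitialGuess(new double[] { 0, 0 }));
// result.getPoint() should be close to { 1, -2 }.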

View File

@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import java.util.Comparator;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
/**
* This class implements simplex-based direct search optimization.
*
* <p>
 * Direct search methods only use objective function values; they do
 * not need derivatives and do not try to compute approximations
* of the derivatives. According to a 1996 paper by Margaret H. Wright
* (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct
* Search Methods: Once Scorned, Now Respectable</a>), they are used
* when either the computation of the derivative is impossible (noisy
* functions, unpredictable discontinuities) or difficult (complexity,
* computation cost). In the first cases, rather than an optimum, a
* <em>not too bad</em> point is desired. In the latter cases, an
* optimum is desired but cannot be reasonably found. In all cases
* direct search methods can be useful.
* </p>
* <p>
* Simplex-based direct search methods are based on comparison of
* the objective function values at the vertices of a simplex (which is a
 * set of n+1 points in dimension n) that is updated by the algorithm's
 * steps.
 * </p>
* <p>
* The simplex update procedure ({@link NelderMeadSimplex} or
* {@link MultiDirectionalSimplex}) must be passed to the
* {@code optimize} method.
* </p>
* <p>
* Each call to {@code optimize} will re-use the start configuration of
* the current simplex and move it such that its first vertex is at the
* provided start point of the optimization.
* If the {@code optimize} method is called to solve a different problem
 * and the number of parameters changes, the simplex must be re-initialized
* to one with the appropriate dimensions.
* </p>
* <p>
* Convergence is checked by providing the <em>worst</em> points of
* previous and current simplex to the convergence checker, not the best
* ones.
* </p>
* <p>
* This simplex optimizer implementation does not directly support constrained
* optimization with simple bounds; so, for such optimizations, either a more
* dedicated algorithm must be used like
* {@link CMAESOptimizer} or {@link BOBYQAOptimizer}, or the objective
* function must be wrapped in an adapter like
* {@link org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter
* MultivariateFunctionMappingAdapter} or
* {@link org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionPenaltyAdapter
* MultivariateFunctionPenaltyAdapter}.
* </p>
*
* @version $Id: SimplexOptimizer.java 1397759 2012-10-13 01:12:58Z erans $
* @since 3.0
*/
public class SimplexOptimizer extends MultivariateOptimizer {
/** Simplex update rule. */
private AbstractSimplex simplex;
/**
* @param checker Convergence checker.
*/
public SimplexOptimizer(ConvergenceChecker<PointValuePair> checker) {
super(checker);
}
/**
* @param rel Relative threshold.
* @param abs Absolute threshold.
*/
public SimplexOptimizer(double rel, double abs) {
this(new SimpleValueChecker(rel, abs));
}
/**
* {@inheritDoc}
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
* <li>{@link AbstractSimplex}</li>
* </ul>
* @return {@inheritDoc}
*/
@Override
public PointValuePair optimize(OptimizationData... optData) {
// Retrieve settings
parseOptimizationData(optData);
// Set up base class and perform computation.
return super.optimize(optData);
}
/** {@inheritDoc} */
@Override
protected PointValuePair doOptimize() {
if (simplex == null) {
throw new NullArgumentException();
}
// Indirect call to "computeObjectiveValue" in order to update the
// evaluations counter.
final MultivariateFunction evalFunc
= new MultivariateFunction() {
public double value(double[] point) {
return computeObjectiveValue(point);
}
};
final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
final Comparator<PointValuePair> comparator
= new Comparator<PointValuePair>() {
public int compare(final PointValuePair o1,
final PointValuePair o2) {
final double v1 = o1.getValue();
final double v2 = o2.getValue();
return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1);
}
};
// Initialize search.
simplex.build(getStartPoint());
simplex.evaluate(evalFunc, comparator);
PointValuePair[] previous = null;
int iteration = 0;
final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
while (true) {
if (iteration > 0) {
boolean converged = true;
for (int i = 0; i < simplex.getSize(); i++) {
PointValuePair prev = previous[i];
converged = converged &&
checker.converged(iteration, prev, simplex.getPoint(i));
}
if (converged) {
// We have found an optimum.
return simplex.getPoint(0);
}
}
// We still need to search.
previous = simplex.getPoints();
simplex.iterate(evalFunc, comparator);
++iteration;
}
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link AbstractSimplex}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof AbstractSimplex) {
simplex = (AbstractSimplex) data;
// If more data must be parsed, this statement _must_ be
// changed to "continue".
break;
}
}
}
}
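
A hedged usage sketch (not part of the committed file) showing the data that must be supplied to optimize: the simplex update rule is mandatory, while the objective function, goal, initial guess and evaluation budget are the usual optimization data. All values are illustrative.

SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
PointValuePair optimum = optimizer.optimize(
    new MaxEval(500),
    new ObjectiveFunction(new MultivariateFunction() {
        public double value(double[] x) {
            return x[0] * x[0] + x[1] * x[1] + 1;   // minimum value 1 at the origin
        }
    }),
    GoalType.MINIMIZE,
    new InitialGuess(new double[] { 3, -2 }),
    new NelderMeadSimplex(new double[][] { { 0, 0 }, { 1, 0 }, { 0, 1 } }));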

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
/**
* This package provides optimization algorithms that do not require derivatives.
*/

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
/**
* Algorithms for optimizing a scalar function.
*/

View File

@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointVectorValuePair;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.DimensionMismatchException;
/**
* Base class for implementing optimizers for multivariate vector
* differentiable functions.
* It contains boiler-plate code for dealing with Jacobian evaluation.
* It assumes that the rows of the Jacobian matrix iterate on the model
 * functions while the columns iterate on the parameters; thus, the number
 * of rows is equal to the dimension of the {@link Target} while the
* number of columns is equal to the dimension of the
* {@link org.apache.commons.math3.optim.InitialGuess InitialGuess}.
*
* @version $Id$
* @since 3.1
*/
public abstract class JacobianMultivariateVectorOptimizer
extends MultivariateVectorOptimizer {
/**
* Jacobian of the model function.
*/
private MultivariateMatrixFunction jacobian;
/**
* @param checker Convergence checker.
*/
protected JacobianMultivariateVectorOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
super(checker);
}
/**
* Computes the Jacobian matrix.
*
* @param params Point at which the Jacobian must be evaluated.
* @return the Jacobian at the specified point.
*/
protected double[][] computeJacobian(final double[] params) {
return jacobian.value(params);
}
/**
* {@inheritDoc}
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
* <li>{@link Target}</li>
* <li>{@link Weight}</li>
* <li>{@link ModelFunction}</li>
* <li>{@link ModelFunctionJacobian}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
* @throws DimensionMismatchException if the initial guess, target, and weight
* arguments have inconsistent dimensions.
*/
@Override
public PointVectorValuePair optimize(OptimizationData... optData)
throws TooManyEvaluationsException,
DimensionMismatchException {
// Retrieve settings.
parseOptimizationData(optData);
// Set up base class and perform computation.
return super.optimize(optData);
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link ModelFunctionJacobian}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof ModelFunctionJacobian) {
jacobian = ((ModelFunctionJacobian) data).getModelFunctionJacobian();
// If more data must be parsed, this statement _must_ be
// changed to "continue".
break;
}
}
}
}
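
To make the row/column convention concrete, here is a hedged sketch (not part of the committed file) of a straight-line model y = a + b*x and its Jacobian: one row per observation, one column per parameter. The abscissae are made-up values; the two functions are reused in the Gauss-Newton sketch further down.

final double[] xObs = { 0, 1, 2, 3 };                    // made-up abscissae
MultivariateVectorFunction model = new MultivariateVectorFunction() {
    public double[] value(double[] p) {                  // p = { a, b }
        final double[] y = new double[xObs.length];
        for (int i = 0; i < xObs.length; i++) {
            y[i] = p[0] + p[1] * xObs[i];
        }
        return y;
    }
};
MultivariateMatrixFunction jacobian = new MultivariateMatrixFunction() {
    public double[][] value(double[] p) {
        final double[][] j = new double[xObs.length][2]; // rows: observations, columns: a, b
        for (int i = 0; i < xObs.length; i++) {
            j[i][0] = 1;                                 // dy_i / da
            j[i][1] = xObs[i];                           // dy_i / db
        }
        return j;
    }
};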

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optim.OptimizationData;
/**
* Model (vector) function to be optimized.
*
* @version $Id$
* @since 3.1
*/
public class ModelFunction implements OptimizationData {
/** Function to be optimized. */
private final MultivariateVectorFunction model;
/**
* @param m Model function to be optimized.
*/
public ModelFunction(MultivariateVectorFunction m) {
model = m;
}
/**
* Gets the model function to be optimized.
*
* @return the model function.
*/
public MultivariateVectorFunction getModelFunction() {
return model;
}
}

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.optim.OptimizationData;
/**
* Jacobian of the model (vector) function to be optimized.
*
* @version $Id$
* @since 3.1
*/
public class ModelFunctionJacobian implements OptimizationData {
/** Function to be optimized. */
private final MultivariateMatrixFunction jacobian;
/**
* @param j Jacobian of the model function to be optimized.
*/
public ModelFunctionJacobian(MultivariateMatrixFunction j) {
jacobian = j;
}
/**
* Gets the Jacobian of the model function to be optimized.
*
* @return the model function Jacobian.
*/
public MultivariateMatrixFunction getModelFunctionJacobian() {
return jacobian;
}
}

View File

@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
import java.util.Collections;
import java.util.List;
import java.util.ArrayList;
import java.util.Comparator;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.random.RandomVectorGenerator;
import org.apache.commons.math3.optim.BaseMultiStartMultivariateOptimizer;
import org.apache.commons.math3.optim.PointVectorValuePair;
/**
* Multi-start optimizer for a (vector) model function.
*
* This class wraps an optimizer in order to use it several times in
* turn with different starting points (trying to avoid being trapped
* in a local extremum when looking for a global one).
*
* @version $Id$
* @since 3.0
*/
public class MultiStartMultivariateVectorOptimizer
extends BaseMultiStartMultivariateOptimizer<PointVectorValuePair> {
/** Underlying optimizer. */
private final MultivariateVectorOptimizer optimizer;
/** Found optima. */
private final List<PointVectorValuePair> optima = new ArrayList<PointVectorValuePair>();
/**
* Create a multi-start optimizer from a single-start optimizer.
*
* @param optimizer Single-start optimizer to wrap.
* @param starts Number of starts to perform.
 * If {@code starts == 1}, the result will be the same as if {@code optimizer}
* is called directly.
* @param generator Random vector generator to use for restarts.
* @throws NullArgumentException if {@code optimizer} or {@code generator}
* is {@code null}.
* @throws NotStrictlyPositiveException if {@code starts < 1}.
*/
public MultiStartMultivariateVectorOptimizer(final MultivariateVectorOptimizer optimizer,
final int starts,
final RandomVectorGenerator generator)
throws NullArgumentException,
NotStrictlyPositiveException {
super(optimizer, starts, generator);
this.optimizer = optimizer;
}
/**
* {@inheritDoc}
*/
@Override
public PointVectorValuePair[] getOptima() {
Collections.sort(optima, getPairComparator());
return optima.toArray(new PointVectorValuePair[0]);
}
/**
* {@inheritDoc}
*/
@Override
protected void store(PointVectorValuePair optimum) {
optima.add(optimum);
}
/**
* {@inheritDoc}
*/
@Override
protected void clear() {
optima.clear();
}
/**
* @return a comparator for sorting the optima.
*/
private Comparator<PointVectorValuePair> getPairComparator() {
return new Comparator<PointVectorValuePair>() {
private final RealVector target = new ArrayRealVector(optimizer.getTarget(), false);
private final RealMatrix weight = optimizer.getWeight();
public int compare(final PointVectorValuePair o1,
final PointVectorValuePair o2) {
if (o1 == null) {
return (o2 == null) ? 0 : 1;
} else if (o2 == null) {
return -1;
}
return Double.compare(weightedResidual(o1),
weightedResidual(o2));
}
private double weightedResidual(final PointVectorValuePair pv) {
final RealVector v = new ArrayRealVector(pv.getValueRef(), false);
final RealVector r = target.subtract(v);
return r.dotProduct(weight.operate(r));
}
};
}
}
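
A hedged sketch (not part of the committed file) of wrapping a single-start vector optimizer with random restarts. The random-vector generator classes are assumed to be the ones from o.a.c.m.random, SimpleVectorValueChecker is assumed to be the vector-valued convergence checker of the optim package, and the dimension and number of starts are illustrative.

MultivariateVectorOptimizer single =
    new GaussNewtonOptimizer(new SimpleVectorValueChecker(1e-6, 1e-6));
RandomVectorGenerator generator =
    new UncorrelatedRandomVectorGenerator(2,
        new GaussianRandomGenerator(new JDKRandomGenerator()));
MultiStartMultivariateVectorOptimizer multi =
    new MultiStartMultivariateVectorOptimizer(single, 10, generator);
// "multi.optimize(...)" accepts the same OptimizationData as "single";
// getOptima() then returns all found optima, best first.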

View File

@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.BaseMultivariateOptimizer;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointVectorValuePair;
import org.apache.commons.math3.linear.RealMatrix;
/**
* Base class for a multivariate vector function optimizer.
*
* @version $Id$
* @since 3.1
*/
public abstract class MultivariateVectorOptimizer
extends BaseMultivariateOptimizer<PointVectorValuePair> {
/** Target values for the model function at optimum. */
private double[] target;
/** Weight matrix. */
private RealMatrix weightMatrix;
/** Model function. */
private MultivariateVectorFunction model;
/**
* @param checker Convergence checker.
*/
protected MultivariateVectorOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
super(checker);
}
/**
* Computes the objective function value.
* This method <em>must</em> be called by subclasses to enforce the
* evaluation counter limit.
*
* @param params Point at which the objective function must be evaluated.
* @return the objective function value at the specified point.
* @throws TooManyEvaluationsException if the maximal number of evaluations
* (of the model vector function) is exceeded.
*/
protected double[] computeObjectiveValue(double[] params) {
super.incrementEvaluationCount();
return model.value(params);
}
/**
* {@inheritDoc}
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
* <li>{@link Target}</li>
* <li>{@link Weight}</li>
* <li>{@link ModelFunction}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
* @throws DimensionMismatchException if the initial guess, target, and weight
* arguments have inconsistent dimensions.
*/
public PointVectorValuePair optimize(OptimizationData... optData)
throws TooManyEvaluationsException,
DimensionMismatchException {
// Retrieve settings.
parseOptimizationData(optData);
// Check input consistency.
checkParameters();
// Set up base class and perform computation.
return super.optimize(optData);
}
/**
* Gets the weight matrix of the observations.
*
* @return the weight matrix.
*/
public RealMatrix getWeight() {
return weightMatrix.copy();
}
/**
* Gets the observed values to be matched by the objective vector
* function.
*
* @return the target values.
*/
public double[] getTarget() {
return target.clone();
}
/**
* Gets the number of observed values.
*
* @return the length of the target vector.
*/
public int getTargetSize() {
return target.length;
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link Target}</li>
* <li>{@link Weight}</li>
* <li>{@link ModelFunction}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof ModelFunction) {
model = ((ModelFunction) data).getModelFunction();
continue;
}
if (data instanceof Target) {
target = ((Target) data).getTarget();
continue;
}
if (data instanceof Weight) {
weightMatrix = ((Weight) data).getWeight();
continue;
}
}
}
/**
* Check parameters consistency.
*
* @throws DimensionMismatchException if {@link #target} and
* {@link #weightMatrix} have inconsistent dimensions.
*/
private void checkParameters() {
if (target.length != weightMatrix.getColumnDimension()) {
throw new DimensionMismatchException(target.length,
weightMatrix.getColumnDimension());
}
}
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
import org.apache.commons.math3.optim.OptimizationData;
/**
* Target of the optimization procedure.
 * These are the values which the objective vector function must reproduce
 * when the parameters of the model have been optimized.
* <br/>
* Immutable class.
*
* @version $Id: Target.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.1
*/
public class Target implements OptimizationData {
/** Target values (of the objective vector function). */
private final double[] target;
/**
* @param observations Target values.
*/
public Target(double[] observations) {
target = observations.clone();
}
/**
 * Gets the target values.
 *
 * @return the target values.
*/
public double[] getTarget() {
return target.clone();
}
}

View File

@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.NonSquareMatrixException;
/**
* Weight matrix of the residuals between model and observations.
* <br/>
* Immutable class.
*
* @version $Id: Weight.java 1416643 2012-12-03 19:37:14Z tn $
* @since 3.1
*/
public class Weight implements OptimizationData {
/** Weight matrix. */
private final RealMatrix weightMatrix;
/**
* Creates a diagonal weight matrix.
*
* @param weight List of the values of the diagonal.
*/
public Weight(double[] weight) {
final int dim = weight.length;
weightMatrix = MatrixUtils.createRealMatrix(dim, dim);
for (int i = 0; i < dim; i++) {
weightMatrix.setEntry(i, i, weight[i]);
}
}
/**
* @param weight Weight matrix.
* @throws NonSquareMatrixException if the argument is not
* a square matrix.
*/
public Weight(RealMatrix weight) {
if (weight.getColumnDimension() != weight.getRowDimension()) {
throw new NonSquareMatrixException(weight.getColumnDimension(),
weight.getRowDimension());
}
weightMatrix = weight.copy();
}
/**
 * Gets the weight matrix.
 *
 * @return the weight matrix.
*/
public RealMatrix getWeight() {
return weightMatrix.copy();
}
}
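
A short hedged sketch (not part of the committed file): observed values and a diagonal weight matrix as they would be handed to a vector optimizer; the numbers are made-up.

Target target = new Target(new double[] { 1.5, 2.1, 3.4 });   // observed values
Weight weight = new Weight(new double[] { 1, 1, 0.5 });        // diag(1, 1, 0.5)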

View File

@ -0,0 +1,269 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointVectorValuePair;
import org.apache.commons.math3.optim.nonlinear.vector.Weight;
import org.apache.commons.math3.optim.nonlinear.vector.JacobianMultivariateVectorOptimizer;
import org.apache.commons.math3.util.FastMath;
/**
* Base class for implementing least-squares optimizers.
* It provides methods for error estimation.
*
* @version $Id$
* @since 3.1
*/
public abstract class AbstractLeastSquaresOptimizer
extends JacobianMultivariateVectorOptimizer {
/** Square-root of the weight matrix. */
private RealMatrix weightMatrixSqrt;
/** Cost value (square root of the sum of the residuals). */
private double cost;
/**
* @param checker Convergence checker.
*/
protected AbstractLeastSquaresOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
super(checker);
}
/**
* Computes the weighted Jacobian matrix.
*
* @param params Model parameters at which to compute the Jacobian.
* @return the weighted Jacobian: W<sup>1/2</sup> J.
* @throws DimensionMismatchException if the Jacobian dimension does not
* match problem dimension.
*/
protected RealMatrix computeWeightedJacobian(double[] params) {
return weightMatrixSqrt.multiply(MatrixUtils.createRealMatrix(computeJacobian(params)));
}
/**
* Computes the cost.
*
* @param residuals Residuals.
* @return the cost.
* @see #computeResiduals(double[])
*/
protected double computeCost(double[] residuals) {
final ArrayRealVector r = new ArrayRealVector(residuals);
return FastMath.sqrt(r.dotProduct(getWeight().operate(r)));
}
/**
* Gets the root-mean-square (RMS) value.
*
 * The RMS is the root of the arithmetic mean of the squares of all weighted
 * residuals.
 * This is related to the criterion that is minimized by the optimizer
 * as follows: if <em>c</em> is the criterion, and <em>n</em> is the
 * number of measurements, then the RMS is <em>sqrt(c/n)</em>.
*
* @return the RMS value.
*/
public double getRMS() {
return FastMath.sqrt(getChiSquare() / getTargetSize());
}
/**
* Get a Chi-Square-like value assuming the N residuals follow N
* distinct normal distributions centered on 0 and whose variances are
* the reciprocal of the weights.
* @return chi-square value
*/
public double getChiSquare() {
return cost * cost;
}
/**
* Gets the square-root of the weight matrix.
*
* @return the square-root of the weight matrix.
*/
public RealMatrix getWeightSquareRoot() {
return weightMatrixSqrt.copy();
}
/**
* Sets the cost.
*
* @param cost Cost value.
*/
protected void setCost(double cost) {
this.cost = cost;
}
/**
* Get the covariance matrix of the optimized parameters.
* <br/>
* Note that this operation involves the inversion of the
* <code>J<sup>T</sup>J</code> matrix, where {@code J} is the
* Jacobian matrix.
* The {@code threshold} parameter is a way for the caller to specify
* that the result of this computation should be considered meaningless,
* and thus trigger an exception.
*
* @param params Model parameters.
* @param threshold Singularity threshold.
* @return the covariance matrix.
* @throws org.apache.commons.math3.linear.SingularMatrixException
* if the covariance matrix cannot be computed (singular problem).
*/
public double[][] computeCovariances(double[] params,
double threshold) {
// Set up the Jacobian.
final RealMatrix j = computeWeightedJacobian(params);
// Compute transpose(J)J.
final RealMatrix jTj = j.transpose().multiply(j);
// Compute the covariances matrix.
final DecompositionSolver solver
= new QRDecomposition(jTj, threshold).getSolver();
return solver.getInverse().getData();
}
/**
* Computes an estimate of the standard deviation of the parameters. The
* returned values are the square root of the diagonal coefficients of the
* covariance matrix, {@code sd(a[i]) ~= sqrt(C[i][i])}, where {@code a[i]}
* is the optimized value of the {@code i}-th parameter, and {@code C} is
* the covariance matrix.
*
* @param params Model parameters.
* @param covarianceSingularityThreshold Singularity threshold (see
* {@link #computeCovariances(double[],double) computeCovariances}).
* @return an estimate of the standard deviation of the optimized parameters
* @throws org.apache.commons.math3.linear.SingularMatrixException
* if the covariance matrix cannot be computed.
*/
public double[] computeSigma(double[] params,
double covarianceSingularityThreshold) {
final int nC = params.length;
final double[] sig = new double[nC];
final double[][] cov = computeCovariances(params, covarianceSingularityThreshold);
for (int i = 0; i < nC; ++i) {
sig[i] = FastMath.sqrt(cov[i][i]);
}
return sig;
}
/**
* {@inheritDoc}
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.Target}</li>
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.Weight}</li>
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunction}</li>
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
* @throws DimensionMismatchException if the initial guess, target, and weight
* arguments have inconsistent dimensions.
*/
@Override
public PointVectorValuePair optimize(OptimizationData... optData)
throws TooManyEvaluationsException {
// Retrieve settings.
parseOptimizationData(optData);
// Set up base class and perform computation.
return super.optimize(optData);
}
/**
* Computes the residuals.
* The residual is the difference between the observed (target)
* values and the model (objective function) value.
* There is one residual for each element of the vector-valued
* function.
*
 * @param objectiveValue Value of the objective function. This is
* the value returned from a call to
* {@link #computeObjectiveValue(double[]) computeObjectiveValue}
* (whose array argument contains the model parameters).
* @return the residuals.
 * @throws DimensionMismatchException if {@code objectiveValue} has a wrong
 * length.
*/
protected double[] computeResiduals(double[] objectiveValue) {
final double[] target = getTarget();
if (objectiveValue.length != target.length) {
throw new DimensionMismatchException(target.length,
objectiveValue.length);
}
final double[] residuals = new double[target.length];
for (int i = 0; i < target.length; i++) {
residuals[i] = target[i] - objectiveValue[i];
}
return residuals;
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
* If the weight matrix is specified, the {@link #weightMatrixSqrt}
* field is recomputed.
*
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link Weight}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof Weight) {
weightMatrixSqrt = squareRoot(((Weight) data).getWeight());
// If more data must be parsed, this statement _must_ be
// changed to "continue".
break;
}
}
}
/**
* Computes the square-root of the weight matrix.
*
* @param m Symmetric, positive-definite (weight) matrix.
* @return the square-root of the weight matrix.
*/
private RealMatrix squareRoot(RealMatrix m) {
final EigenDecomposition dec = new EigenDecomposition(m);
return dec.getSquareRoot();
}
}
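
A hedged sketch (not part of the committed file) of querying the error-estimation helpers after a fit; "optimizer" and "optimum" are assumed to come from a run of a concrete subclass such as the Gauss-Newton or Levenberg-Marquardt optimizer below, and the singularity threshold is an illustrative value.

double rms  = optimizer.getRMS();                                    // sqrt(chi2 / n)
double chi2 = optimizer.getChiSquare();                              // cost squared
double[] sigma = optimizer.computeSigma(optimum.getPoint(), 1e-14);  // parameter std. deviations
double[][] covariances = optimizer.computeCovariances(optimum.getPoint(), 1e-14);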

View File

@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.MathInternalError;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.PointVectorValuePair;
/**
* Gauss-Newton least-squares solver.
* <p>
 * This class solves a least-squares problem by solving the normal equations
* of the linearized problem at each iteration. Either LU decomposition or
* QR decomposition can be used to solve the normal equations. LU decomposition
* is faster but QR decomposition is more robust for difficult problems.
* </p>
*
* @version $Id: GaussNewtonOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*
*/
public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
/** Indicator for using LU decomposition. */
private final boolean useLU;
/**
* Simple constructor with default settings.
* The normal equations will be solved using LU decomposition.
*
* @param checker Convergence checker.
*/
public GaussNewtonOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
this(true, checker);
}
/**
* @param useLU If {@code true}, the normal equations will be solved
* using LU decomposition, otherwise they will be solved using QR
* decomposition.
* @param checker Convergence checker.
*/
public GaussNewtonOptimizer(final boolean useLU,
ConvergenceChecker<PointVectorValuePair> checker) {
super(checker);
this.useLU = useLU;
}
/** {@inheritDoc} */
@Override
public PointVectorValuePair doOptimize() {
final ConvergenceChecker<PointVectorValuePair> checker
= getConvergenceChecker();
// Computation will be useless without a checker (see "for-loop").
if (checker == null) {
throw new NullArgumentException();
}
final double[] targetValues = getTarget();
final int nR = targetValues.length; // Number of observed data.
final RealMatrix weightMatrix = getWeight();
// Diagonal of the weight matrix.
final double[] residualsWeights = new double[nR];
for (int i = 0; i < nR; i++) {
residualsWeights[i] = weightMatrix.getEntry(i, i);
}
final double[] currentPoint = getStartPoint();
final int nC = currentPoint.length;
// iterate until convergence is reached
PointVectorValuePair current = null;
int iter = 0;
for (boolean converged = false; !converged;) {
++iter;
// evaluate the objective function and its jacobian
PointVectorValuePair previous = current;
// Value of the objective function at "currentPoint".
final double[] currentObjective = computeObjectiveValue(currentPoint);
final double[] currentResiduals = computeResiduals(currentObjective);
final RealMatrix weightedJacobian = computeWeightedJacobian(currentPoint);
current = new PointVectorValuePair(currentPoint, currentObjective);
// build the linear problem
final double[] b = new double[nC];
final double[][] a = new double[nC][nC];
for (int i = 0; i < nR; ++i) {
final double[] grad = weightedJacobian.getRow(i);
final double weight = residualsWeights[i];
final double residual = currentResiduals[i];
// compute the normal equation
final double wr = weight * residual;
for (int j = 0; j < nC; ++j) {
b[j] += wr * grad[j];
}
// build the contribution matrix for measurement i
for (int k = 0; k < nC; ++k) {
double[] ak = a[k];
double wgk = weight * grad[k];
for (int l = 0; l < nC; ++l) {
ak[l] += wgk * grad[l];
}
}
}
try {
// solve the linearized least squares problem
RealMatrix mA = new BlockRealMatrix(a);
DecompositionSolver solver = useLU ?
new LUDecomposition(mA).getSolver() :
new QRDecomposition(mA).getSolver();
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
// update the estimated parameters
for (int i = 0; i < nC; ++i) {
currentPoint[i] += dX[i];
}
} catch (SingularMatrixException e) {
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
}
// Check convergence.
if (previous != null) {
converged = checker.converged(iter, previous, current);
if (converged) {
setCost(computeCost(currentResiduals));
return current;
}
}
}
// Must never happen.
throw new MathInternalError();
}
}
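
A hedged usage sketch (not part of the committed file): fitting the straight-line model from the Jacobian sketch above with the Gauss-Newton solver. Observations, weights, start point and thresholds are made-up values, and SimpleVectorValueChecker is assumed to be the vector-valued convergence checker of the optim package.

GaussNewtonOptimizer optimizer =
    new GaussNewtonOptimizer(new SimpleVectorValueChecker(1e-8, 1e-8));
PointVectorValuePair optimum = optimizer.optimize(
    new MaxEval(100),
    new Target(new double[] { 1.0, 2.9, 5.1, 7.0 }),       // observed y at xObs
    new Weight(new double[] { 1, 1, 1, 1 }),
    new InitialGuess(new double[] { 0, 0 }),                // start at a = 0, b = 0
    new ModelFunction(model),                               // from the sketch above
    new ModelFunctionJacobian(jacobian));
double[] fitted = optimum.getPoint();                       // estimated { a, b }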

View File

@ -0,0 +1,939 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
import java.util.Arrays;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.optim.PointVectorValuePair;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.Precision;
import org.apache.commons.math3.util.FastMath;
/**
* This class solves a least-squares problem using the Levenberg-Marquardt algorithm.
*
* <p>This implementation <em>should</em> work even for over-determined systems
 * (i.e. systems having more points than equations). Over-determined systems
 * are solved by ignoring the points which have the smallest impact according
* to their jacobian column norm. Only the rank of the matrix and some loop bounds
* are changed to implement this.</p>
*
* <p>The resolution engine is a simple translation of the MINPACK <a
* href="http://www.netlib.org/minpack/lmder.f">lmder</a> routine with minor
* changes. The changes include the over-determined resolution, the use of
* inherited convergence checker and the Q.R. decomposition which has been
* rewritten following the algorithm described in the
* P. Lascaux and R. Theodor book <i>Analyse num&eacute;rique matricielle
* appliqu&eacute;e &agrave; l'art de l'ing&eacute;nieur</i>, Masson 1986.</p>
* <p>The authors of the original fortran version are:
* <ul>
* <li>Argonne National Laboratory. MINPACK project. March 1980</li>
* <li>Burton S. Garbow</li>
* <li>Kenneth E. Hillstrom</li>
* <li>Jorge J. More</li>
* </ul>
* The redistribution policy for MINPACK is available <a
* href="http://www.netlib.org/minpack/disclaimer">here</a>, for convenience, it
* is reproduced below.</p>
*
* <table border="0" width="80%" cellpadding="10" align="center" bgcolor="#E0E0E0">
* <tr><td>
* Minpack Copyright Notice (1999) University of Chicago.
* All rights reserved
* </td></tr>
* <tr><td>
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* <ol>
* <li>Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.</li>
* <li>Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.</li>
* <li>The end-user documentation included with the redistribution, if any,
* must include the following acknowledgment:
* <code>This product includes software developed by the University of
* Chicago, as Operator of Argonne National Laboratory.</code>
* Alternately, this acknowledgment may appear in the software itself,
* if and wherever such third-party acknowledgments normally appear.</li>
* <li><strong>WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS"
* WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE
* UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND
* THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE
* OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY
* OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR
* USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF
* THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4)
* DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION
* UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL
* BE CORRECTED.</strong></li>
* <li><strong>LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT
* HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF
* ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT,
* INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF
* ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF
* PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER
* SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT
* (INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE,
* EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE
* POSSIBILITY OF SUCH LOSS OR DAMAGES.</strong></li>
* </ol></td></tr>
* </table>
*
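* <p>A minimal, illustrative usage sketch: it fits a straight-line model to two
* observed points, assuming the {@code MaxEval}, {@code ModelFunction},
* {@code ModelFunctionJacobian}, {@code Target}, {@code Weight} and
* {@code InitialGuess} optimization data described in the package documentation
* (values and tolerances are arbitrary).</p>
* <pre>
* LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
* PointVectorValuePair optimum =
*     optimizer.optimize(new MaxEval(1000),
*                        new ModelFunction(new MultivariateVectorFunction() {
*                            public double[] value(double[] p) {
*                                // Model y = p[0] * x + p[1], sampled at x = 1 and x = 2.
*                                return new double[] { p[0] + p[1], 2 * p[0] + p[1] };
*                            }
*                        }),
*                        new ModelFunctionJacobian(new MultivariateMatrixFunction() {
*                            public double[][] value(double[] p) {
*                                return new double[][] { { 1, 1 }, { 2, 1 } };
*                            }
*                        }),
*                        new Target(new double[] { 3, 5 }),
*                        new Weight(new double[] { 1, 1 }),
*                        new InitialGuess(new double[] { 0, 0 }));
* double[] fitted = optimum.getPoint(); // expected to be close to { 2, 1 }
* </pre>
*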
* @version $Id: LevenbergMarquardtOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class LevenbergMarquardtOptimizer
extends AbstractLeastSquaresOptimizer {
/** Number of solved points. */
private int solvedCols;
/** Diagonal elements of the R matrix in the Q.R. decomposition. */
private double[] diagR;
/** Norms of the columns of the jacobian matrix. */
private double[] jacNorm;
/** Coefficients of the Householder transforms vectors. */
private double[] beta;
/** Columns permutation array. */
private int[] permutation;
/** Rank of the jacobian matrix. */
private int rank;
/** Levenberg-Marquardt parameter. */
private double lmPar;
/** Parameters evolution direction associated with lmPar. */
private double[] lmDir;
/** Positive input variable used in determining the initial step bound. */
private final double initialStepBoundFactor;
/** Desired relative error in the sum of squares. */
private final double costRelativeTolerance;
/** Desired relative error in the approximate solution parameters. */
private final double parRelativeTolerance;
/** Desired max cosine on the orthogonality between the function vector
* and the columns of the jacobian. */
private final double orthoTolerance;
/** Threshold for QR ranking. */
private final double qrRankingThreshold;
/** Weighted residuals. */
private double[] weightedResidual;
/** Weighted Jacobian. */
private double[][] weightedJacobian;
/**
* Build an optimizer for least squares problems with default values
* for all the tuning parameters (see the {@link
* #LevenbergMarquardtOptimizer(double,double,double,double,double)
* other constructor}).
* The default values for the algorithm settings are:
* <ul>
* <li>Initial step bound factor: 100</li>
* <li>Cost relative tolerance: 1e-10</li>
* <li>Parameters relative tolerance: 1e-10</li>
* <li>Orthogonality tolerance: 1e-10</li>
* <li>QR ranking threshold: {@link Precision#SAFE_MIN}</li>
* </ul>
*/
public LevenbergMarquardtOptimizer() {
this(100, 1e-10, 1e-10, 1e-10, Precision.SAFE_MIN);
}
/**
* Constructor that allows the specification of a custom convergence
* checker.
* Note that all the usual convergence checks will be <em>disabled</em>.
* The default values for the algorithm settings are:
* <ul>
* <li>Initial step bound factor: 100</li>
* <li>Cost relative tolerance: 1e-10</li>
* <li>Parameters relative tolerance: 1e-10</li>
* <li>Orthogonality tolerance: 1e-10</li>
* <li>QR ranking threshold: {@link Precision#SAFE_MIN}</li>
* </ul>
*
* @param checker Convergence checker.
*/
public LevenbergMarquardtOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
this(100, checker, 1e-10, 1e-10, 1e-10, Precision.SAFE_MIN);
}
/**
* Constructor that allows the specification of a custom convergence
* checker, in addition to the standard ones.
*
* @param initialStepBoundFactor Positive input variable used in
* determining the initial step bound. This bound is set to the
* product of initialStepBoundFactor and the euclidean norm of
* {@code diag * x} if non-zero, or else to {@code initialStepBoundFactor}
* itself. In most cases factor should lie in the interval
* {@code (0.1, 100.0)}. {@code 100} is a generally recommended value.
* @param checker Convergence checker.
* @param costRelativeTolerance Desired relative error in the sum of
* squares.
* @param parRelativeTolerance Desired relative error in the approximate
* solution parameters.
* @param orthoTolerance Desired max cosine on the orthogonality between
* the function vector and the columns of the Jacobian.
* @param threshold Desired threshold for QR ranking. If the squared norm
* of a column vector is smaller or equal to this threshold during QR
* decomposition, it is considered to be a zero vector and hence the rank
* of the matrix is reduced.
*/
public LevenbergMarquardtOptimizer(double initialStepBoundFactor,
ConvergenceChecker<PointVectorValuePair> checker,
double costRelativeTolerance,
double parRelativeTolerance,
double orthoTolerance,
double threshold) {
super(checker);
this.initialStepBoundFactor = initialStepBoundFactor;
this.costRelativeTolerance = costRelativeTolerance;
this.parRelativeTolerance = parRelativeTolerance;
this.orthoTolerance = orthoTolerance;
this.qrRankingThreshold = threshold;
}
/**
* Build an optimizer for least squares problems with default values
* for some of the tuning parameters (see the {@link
* #LevenbergMarquardtOptimizer(double,double,double,double,double)
* other constructor}).
* The default values for the algorithm settings are:
* <ul>
* <li>Initial step bound factor: 100</li>
* <li>QR ranking threshold: {@link Precision#SAFE_MIN}</li>
* </ul>
*
* @param costRelativeTolerance Desired relative error in the sum of
* squares.
* @param parRelativeTolerance Desired relative error in the approximate
* solution parameters.
* @param orthoTolerance Desired max cosine on the orthogonality between
* the function vector and the columns of the Jacobian.
*/
public LevenbergMarquardtOptimizer(double costRelativeTolerance,
double parRelativeTolerance,
double orthoTolerance) {
this(100,
costRelativeTolerance, parRelativeTolerance, orthoTolerance,
Precision.SAFE_MIN);
}
/**
* The arguments control the behaviour of the default convergence checking
* procedure.
* Additional criteria can be defined through the setting of a {@link
* ConvergenceChecker}.
*
* @param initialStepBoundFactor Positive input variable used in
* determining the initial step bound. This bound is set to the
* product of initialStepBoundFactor and the euclidean norm of
* {@code diag * x} if non-zero, or else to {@code initialStepBoundFactor}
* itself. In most cases factor should lie in the interval
* {@code (0.1, 100.0)}. {@code 100} is a generally recommended value.
* @param costRelativeTolerance Desired relative error in the sum of
* squares.
* @param parRelativeTolerance Desired relative error in the approximate
* solution parameters.
* @param orthoTolerance Desired max cosine on the orthogonality between
* the function vector and the columns of the Jacobian.
* @param threshold Desired threshold for QR ranking. If the squared norm
* of a column vector is smaller or equal to this threshold during QR
* decomposition, it is considered to be a zero vector and hence the rank
* of the matrix is reduced.
*/
public LevenbergMarquardtOptimizer(double initialStepBoundFactor,
double costRelativeTolerance,
double parRelativeTolerance,
double orthoTolerance,
double threshold) {
super(null); // No custom convergence criterion.
this.initialStepBoundFactor = initialStepBoundFactor;
this.costRelativeTolerance = costRelativeTolerance;
this.parRelativeTolerance = parRelativeTolerance;
this.orthoTolerance = orthoTolerance;
this.qrRankingThreshold = threshold;
}
/** {@inheritDoc} */
@Override
protected PointVectorValuePair doOptimize() {
final int nR = getTarget().length; // Number of observed data.
final double[] currentPoint = getStartPoint();
final int nC = currentPoint.length; // Number of parameters.
// arrays shared with the other private methods
solvedCols = FastMath.min(nR, nC);
diagR = new double[nC];
jacNorm = new double[nC];
beta = new double[nC];
permutation = new int[nC];
lmDir = new double[nC];
// local point
double delta = 0;
double xNorm = 0;
double[] diag = new double[nC];
double[] oldX = new double[nC];
double[] oldRes = new double[nR];
double[] oldObj = new double[nR];
double[] qtf = new double[nR];
double[] work1 = new double[nC];
double[] work2 = new double[nC];
double[] work3 = new double[nC];
final RealMatrix weightMatrixSqrt = getWeightSquareRoot();
// Evaluate the function at the starting point and calculate its norm.
double[] currentObjective = computeObjectiveValue(currentPoint);
double[] currentResiduals = computeResiduals(currentObjective);
PointVectorValuePair current = new PointVectorValuePair(currentPoint, currentObjective);
double currentCost = computeCost(currentResiduals);
// Outer loop.
lmPar = 0;
boolean firstIteration = true;
int iter = 0;
final ConvergenceChecker<PointVectorValuePair> checker = getConvergenceChecker();
while (true) {
++iter;
final PointVectorValuePair previous = current;
// QR decomposition of the jacobian matrix
qrDecomposition(computeWeightedJacobian(currentPoint));
weightedResidual = weightMatrixSqrt.operate(currentResiduals);
for (int i = 0; i < nR; i++) {
qtf[i] = weightedResidual[i];
}
// compute Qt.res
qTy(qtf);
// now we don't need Q anymore,
// so let jacobian contain the R matrix with its diagonal elements
for (int k = 0; k < solvedCols; ++k) {
int pk = permutation[k];
weightedJacobian[k][pk] = diagR[pk];
}
if (firstIteration) {
// scale the point according to the norms of the columns
// of the initial jacobian
xNorm = 0;
for (int k = 0; k < nC; ++k) {
double dk = jacNorm[k];
if (dk == 0) {
dk = 1.0;
}
double xk = dk * currentPoint[k];
xNorm += xk * xk;
diag[k] = dk;
}
xNorm = FastMath.sqrt(xNorm);
// initialize the step bound delta
delta = (xNorm == 0) ? initialStepBoundFactor : (initialStepBoundFactor * xNorm);
}
// check orthogonality between function vector and jacobian columns
double maxCosine = 0;
if (currentCost != 0) {
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
double s = jacNorm[pj];
if (s != 0) {
double sum = 0;
for (int i = 0; i <= j; ++i) {
sum += weightedJacobian[i][pj] * qtf[i];
}
maxCosine = FastMath.max(maxCosine, FastMath.abs(sum) / (s * currentCost));
}
}
}
if (maxCosine <= orthoTolerance) {
// Convergence has been reached.
setCost(currentCost);
return current;
}
// rescale if necessary
for (int j = 0; j < nC; ++j) {
diag[j] = FastMath.max(diag[j], jacNorm[j]);
}
// Inner loop.
for (double ratio = 0; ratio < 1.0e-4;) {
// save the state
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
oldX[pj] = currentPoint[pj];
}
final double previousCost = currentCost;
double[] tmpVec = weightedResidual;
weightedResidual = oldRes;
oldRes = tmpVec;
tmpVec = currentObjective;
currentObjective = oldObj;
oldObj = tmpVec;
// determine the Levenberg-Marquardt parameter
determineLMParameter(qtf, delta, diag, work1, work2, work3);
// compute the new point and the norm of the evolution direction
double lmNorm = 0;
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
lmDir[pj] = -lmDir[pj];
currentPoint[pj] = oldX[pj] + lmDir[pj];
double s = diag[pj] * lmDir[pj];
lmNorm += s * s;
}
lmNorm = FastMath.sqrt(lmNorm);
// on the first iteration, adjust the initial step bound.
if (firstIteration) {
delta = FastMath.min(delta, lmNorm);
}
// Evaluate the function at x + p and calculate its norm.
currentObjective = computeObjectiveValue(currentPoint);
currentResiduals = computeResiduals(currentObjective);
current = new PointVectorValuePair(currentPoint, currentObjective);
currentCost = computeCost(currentResiduals);
// compute the scaled actual reduction
double actRed = -1.0;
if (0.1 * currentCost < previousCost) {
double r = currentCost / previousCost;
actRed = 1.0 - r * r;
}
// compute the scaled predicted reduction
// and the scaled directional derivative
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
double dirJ = lmDir[pj];
work1[j] = 0;
for (int i = 0; i <= j; ++i) {
work1[i] += weightedJacobian[i][pj] * dirJ;
}
}
double coeff1 = 0;
for (int j = 0; j < solvedCols; ++j) {
coeff1 += work1[j] * work1[j];
}
double pc2 = previousCost * previousCost;
coeff1 = coeff1 / pc2;
double coeff2 = lmPar * lmNorm * lmNorm / pc2;
double preRed = coeff1 + 2 * coeff2;
double dirDer = -(coeff1 + coeff2);
// ratio of the actual to the predicted reduction
ratio = (preRed == 0) ? 0 : (actRed / preRed);
// update the step bound
if (ratio <= 0.25) {
double tmp =
(actRed < 0) ? (0.5 * dirDer / (dirDer + 0.5 * actRed)) : 0.5;
if ((0.1 * currentCost >= previousCost) || (tmp < 0.1)) {
tmp = 0.1;
}
delta = tmp * FastMath.min(delta, 10.0 * lmNorm);
lmPar /= tmp;
} else if ((lmPar == 0) || (ratio >= 0.75)) {
delta = 2 * lmNorm;
lmPar *= 0.5;
}
// test for successful iteration.
if (ratio >= 1.0e-4) {
// successful iteration, update the norm
firstIteration = false;
xNorm = 0;
for (int k = 0; k < nC; ++k) {
double xK = diag[k] * currentPoint[k];
xNorm += xK * xK;
}
xNorm = FastMath.sqrt(xNorm);
// tests for convergence.
if (checker != null) {
// we use the vectorial convergence checker
if (checker.converged(iter, previous, current)) {
setCost(currentCost);
return current;
}
}
} else {
// failed iteration, reset the previous values
currentCost = previousCost;
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
currentPoint[pj] = oldX[pj];
}
tmpVec = weightedResidual;
weightedResidual = oldRes;
oldRes = tmpVec;
tmpVec = currentObjective;
currentObjective = oldObj;
oldObj = tmpVec;
// Reset "current" to previous values.
current = new PointVectorValuePair(currentPoint, currentObjective);
}
// Default convergence criteria.
if ((FastMath.abs(actRed) <= costRelativeTolerance &&
preRed <= costRelativeTolerance &&
ratio <= 2.0) ||
delta <= parRelativeTolerance * xNorm) {
setCost(currentCost);
return current;
}
// tests for termination and stringent tolerances
// (2.2204e-16 is the machine epsilon for IEEE754)
if ((FastMath.abs(actRed) <= 2.2204e-16) && (preRed <= 2.2204e-16) && (ratio <= 2.0)) {
throw new ConvergenceException(LocalizedFormats.TOO_SMALL_COST_RELATIVE_TOLERANCE,
costRelativeTolerance);
} else if (delta <= 2.2204e-16 * xNorm) {
throw new ConvergenceException(LocalizedFormats.TOO_SMALL_PARAMETERS_RELATIVE_TOLERANCE,
parRelativeTolerance);
} else if (maxCosine <= 2.2204e-16) {
throw new ConvergenceException(LocalizedFormats.TOO_SMALL_ORTHOGONALITY_TOLERANCE,
orthoTolerance);
}
}
}
}
/**
* Determine the Levenberg-Marquardt parameter.
* <p>This implementation is a translation in Java of the MINPACK
* <a href="http://www.netlib.org/minpack/lmpar.f">lmpar</a>
* routine.</p>
* <p>This method sets the lmPar and lmDir attributes.</p>
* <p>The authors of the original fortran function are:</p>
* <ul>
* <li>Argonne National Laboratory. MINPACK project. March 1980</li>
* <li>Burton S. Garbow</li>
* <li>Kenneth E. Hillstrom</li>
* <li>Jorge J. More</li>
* </ul>
* <p>Luc Maisonobe did the Java translation.</p>
*
* @param qy array containing qTy
* @param delta upper bound on the euclidean norm of diagR * lmDir
* @param diag diagonal matrix
* @param work1 work array
* @param work2 work array
* @param work3 work array
*/
private void determineLMParameter(double[] qy, double delta, double[] diag,
double[] work1, double[] work2, double[] work3) {
final int nC = weightedJacobian[0].length;
// compute and store in x the gauss-newton direction, if the
// jacobian is rank-deficient, obtain a least squares solution
for (int j = 0; j < rank; ++j) {
lmDir[permutation[j]] = qy[j];
}
for (int j = rank; j < nC; ++j) {
lmDir[permutation[j]] = 0;
}
for (int k = rank - 1; k >= 0; --k) {
int pk = permutation[k];
double ypk = lmDir[pk] / diagR[pk];
for (int i = 0; i < k; ++i) {
lmDir[permutation[i]] -= ypk * weightedJacobian[i][pk];
}
lmDir[pk] = ypk;
}
// evaluate the function at the origin, and test
// for acceptance of the Gauss-Newton direction
double dxNorm = 0;
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
double s = diag[pj] * lmDir[pj];
work1[pj] = s;
dxNorm += s * s;
}
dxNorm = FastMath.sqrt(dxNorm);
double fp = dxNorm - delta;
if (fp <= 0.1 * delta) {
lmPar = 0;
return;
}
// if the jacobian is not rank deficient, the Newton step provides
// a lower bound, parl, for the zero of the function,
// otherwise set this bound to zero
double sum2;
double parl = 0;
if (rank == solvedCols) {
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
work1[pj] *= diag[pj] / dxNorm;
}
sum2 = 0;
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
double sum = 0;
for (int i = 0; i < j; ++i) {
sum += weightedJacobian[i][pj] * work1[permutation[i]];
}
double s = (work1[pj] - sum) / diagR[pj];
work1[pj] = s;
sum2 += s * s;
}
parl = fp / (delta * sum2);
}
// calculate an upper bound, paru, for the zero of the function
sum2 = 0;
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
double sum = 0;
for (int i = 0; i <= j; ++i) {
sum += weightedJacobian[i][pj] * qy[i];
}
sum /= diag[pj];
sum2 += sum * sum;
}
double gNorm = FastMath.sqrt(sum2);
double paru = gNorm / delta;
if (paru == 0) {
// 2.2251e-308 is the smallest positive normalized double for IEEE754
paru = 2.2251e-308 / FastMath.min(delta, 0.1);
}
// if the input par lies outside of the interval (parl,paru),
// set par to the closer endpoint
lmPar = FastMath.min(paru, FastMath.max(lmPar, parl));
if (lmPar == 0) {
lmPar = gNorm / dxNorm;
}
for (int countdown = 10; countdown >= 0; --countdown) {
// evaluate the function at the current value of lmPar
if (lmPar == 0) {
lmPar = FastMath.max(2.2251e-308, 0.001 * paru);
}
double sPar = FastMath.sqrt(lmPar);
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
work1[pj] = sPar * diag[pj];
}
determineLMDirection(qy, work1, work2, work3);
dxNorm = 0;
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
double s = diag[pj] * lmDir[pj];
work3[pj] = s;
dxNorm += s * s;
}
dxNorm = FastMath.sqrt(dxNorm);
double previousFP = fp;
fp = dxNorm - delta;
// if the function is small enough, accept the current value
// of lmPar, also test for the exceptional cases where parl is zero
if ((FastMath.abs(fp) <= 0.1 * delta) ||
((parl == 0) && (fp <= previousFP) && (previousFP < 0))) {
return;
}
// compute the Newton correction
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
work1[pj] = work3[pj] * diag[pj] / dxNorm;
}
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
work1[pj] /= work2[j];
double tmp = work1[pj];
for (int i = j + 1; i < solvedCols; ++i) {
work1[permutation[i]] -= weightedJacobian[i][pj] * tmp;
}
}
sum2 = 0;
for (int j = 0; j < solvedCols; ++j) {
double s = work1[permutation[j]];
sum2 += s * s;
}
double correction = fp / (delta * sum2);
// depending on the sign of the function, update parl or paru.
if (fp > 0) {
parl = FastMath.max(parl, lmPar);
} else if (fp < 0) {
paru = FastMath.min(paru, lmPar);
}
// compute an improved estimate for lmPar
lmPar = FastMath.max(parl, lmPar + correction);
}
}
/**
* Solve a*x = b and d*x = 0 in the least squares sense.
* <p>This implementation is a translation in Java of the MINPACK
* <a href="http://www.netlib.org/minpack/qrsolv.f">qrsolv</a>
* routine.</p>
* <p>This method sets the lmDir and lmDiag attributes.</p>
* <p>The authors of the original fortran function are:</p>
* <ul>
* <li>Argonne National Laboratory. MINPACK project. March 1980</li>
* <li>Burton S. Garbow</li>
* <li>Kenneth E. Hillstrom</li>
* <li>Jorge J. More</li>
* </ul>
* <p>Luc Maisonobe did the Java translation.</p>
*
* @param qy array containing qTy
* @param diag diagonal matrix
* @param lmDiag diagonal elements associated with lmDir
* @param work work array
*/
private void determineLMDirection(double[] qy, double[] diag,
double[] lmDiag, double[] work) {
// copy R and Qty to preserve input and initialize s
// in particular, save the diagonal elements of R in lmDir
for (int j = 0; j < solvedCols; ++j) {
int pj = permutation[j];
for (int i = j + 1; i < solvedCols; ++i) {
weightedJacobian[i][pj] = weightedJacobian[j][permutation[i]];
}
lmDir[j] = diagR[pj];
work[j] = qy[j];
}
// eliminate the diagonal matrix d using a Givens rotation
for (int j = 0; j < solvedCols; ++j) {
// prepare the row of d to be eliminated, locating the
// diagonal element using p from the Q.R. factorization
int pj = permutation[j];
double dpj = diag[pj];
if (dpj != 0) {
Arrays.fill(lmDiag, j + 1, lmDiag.length, 0);
}
lmDiag[j] = dpj;
// the transformations to eliminate the row of d
// modify only a single element of Qty
// beyond the first n, which is initially zero.
double qtbpj = 0;
for (int k = j; k < solvedCols; ++k) {
int pk = permutation[k];
// determine a Givens rotation which eliminates the
// appropriate element in the current row of d
if (lmDiag[k] != 0) {
final double sin;
final double cos;
double rkk = weightedJacobian[k][pk];
if (FastMath.abs(rkk) < FastMath.abs(lmDiag[k])) {
final double cotan = rkk / lmDiag[k];
sin = 1.0 / FastMath.sqrt(1.0 + cotan * cotan);
cos = sin * cotan;
} else {
final double tan = lmDiag[k] / rkk;
cos = 1.0 / FastMath.sqrt(1.0 + tan * tan);
sin = cos * tan;
}
// compute the modified diagonal element of R and
// the modified element of (Qty,0)
weightedJacobian[k][pk] = cos * rkk + sin * lmDiag[k];
final double temp = cos * work[k] + sin * qtbpj;
qtbpj = -sin * work[k] + cos * qtbpj;
work[k] = temp;
// accumulate the transformation in the row of s
for (int i = k + 1; i < solvedCols; ++i) {
double rik = weightedJacobian[i][pk];
final double temp2 = cos * rik + sin * lmDiag[i];
lmDiag[i] = -sin * rik + cos * lmDiag[i];
weightedJacobian[i][pk] = temp2;
}
}
}
// store the diagonal element of s and restore
// the corresponding diagonal element of R
lmDiag[j] = weightedJacobian[j][permutation[j]];
weightedJacobian[j][permutation[j]] = lmDir[j];
}
// solve the triangular system for z, if the system is
// singular, then obtain a least squares solution
int nSing = solvedCols;
for (int j = 0; j < solvedCols; ++j) {
if ((lmDiag[j] == 0) && (nSing == solvedCols)) {
nSing = j;
}
if (nSing < solvedCols) {
work[j] = 0;
}
}
if (nSing > 0) {
for (int j = nSing - 1; j >= 0; --j) {
int pj = permutation[j];
double sum = 0;
for (int i = j + 1; i < nSing; ++i) {
sum += weightedJacobian[i][pj] * work[i];
}
work[j] = (work[j] - sum) / lmDiag[j];
}
}
// permute the components of z back to components of lmDir
for (int j = 0; j < lmDir.length; ++j) {
lmDir[permutation[j]] = work[j];
}
}
/**
* Decompose a matrix A as A.P = Q.R using Householder transforms.
* <p>As suggested in the P. Lascaux and R. Theodor book
* <i>Analyse num&eacute;rique matricielle appliqu&eacute;e &agrave;
* l'art de l'ing&eacute;nieur</i> (Masson, 1986), instead of representing
* the Householder transforms with u<sub>k</sub> unit vectors such that:
* <pre>
* H<sub>k</sub> = I - 2u<sub>k</sub>.u<sub>k</sub><sup>t</sup>
* </pre>
* we use v<sub>k</sub> non-unit vectors such that:
* <pre>
* H<sub>k</sub> = I - beta<sub>k</sub>v<sub>k</sub>.v<sub>k</sub><sup>t</sup>
* </pre>
* where v<sub>k</sub> = a<sub>k</sub> - alpha<sub>k</sub> e<sub>k</sub>.
* The beta<sub>k</sub> coefficients are provided upon exit as recomputing
* them from the v<sub>k</sub> vectors would be costly.</p>
* <p>This decomposition handles rank deficient cases since the transformations
* are performed in non-increasing column norm order thanks to column
* pivoting. The diagonal elements of the R matrix are therefore also in
* non-increasing absolute value order.</p>
*
* @param jacobian Weighted Jacobian matrix at the current point.
* @exception ConvergenceException if the decomposition cannot be performed
*/
private void qrDecomposition(RealMatrix jacobian) throws ConvergenceException {
// Code in this class assumes that the weighted Jacobian is -(W^(1/2) J),
// hence the multiplication by -1.
weightedJacobian = jacobian.scalarMultiply(-1).getData();
final int nR = weightedJacobian.length;
final int nC = weightedJacobian[0].length;
// initializations
for (int k = 0; k < nC; ++k) {
permutation[k] = k;
double norm2 = 0;
for (int i = 0; i < nR; ++i) {
double akk = weightedJacobian[i][k];
norm2 += akk * akk;
}
jacNorm[k] = FastMath.sqrt(norm2);
}
// transform the matrix column after column
for (int k = 0; k < nC; ++k) {
// select the column with the greatest norm on active components
int nextColumn = -1;
double ak2 = Double.NEGATIVE_INFINITY;
for (int i = k; i < nC; ++i) {
double norm2 = 0;
for (int j = k; j < nR; ++j) {
double aki = weightedJacobian[j][permutation[i]];
norm2 += aki * aki;
}
if (Double.isInfinite(norm2) || Double.isNaN(norm2)) {
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_PERFORM_QR_DECOMPOSITION_ON_JACOBIAN,
nR, nC);
}
if (norm2 > ak2) {
nextColumn = i;
ak2 = norm2;
}
}
if (ak2 <= qrRankingThreshold) {
rank = k;
return;
}
int pk = permutation[nextColumn];
permutation[nextColumn] = permutation[k];
permutation[k] = pk;
// choose alpha such that Hk.u = alpha ek
double akk = weightedJacobian[k][pk];
double alpha = (akk > 0) ? -FastMath.sqrt(ak2) : FastMath.sqrt(ak2);
double betak = 1.0 / (ak2 - akk * alpha);
beta[pk] = betak;
// transform the current column
diagR[pk] = alpha;
weightedJacobian[k][pk] -= alpha;
// transform the remaining columns
for (int dk = nC - 1 - k; dk > 0; --dk) {
double gamma = 0;
for (int j = k; j < nR; ++j) {
gamma += weightedJacobian[j][pk] * weightedJacobian[j][permutation[k + dk]];
}
gamma *= betak;
for (int j = k; j < nR; ++j) {
weightedJacobian[j][permutation[k + dk]] -= gamma * weightedJacobian[j][pk];
}
}
}
rank = solvedCols;
}
/**
* Compute the product Qt.y for some Q.R. decomposition.
*
* @param y vector to multiply (will be overwritten with the result)
*/
private void qTy(double[] y) {
final int nR = weightedJacobian.length;
final int nC = weightedJacobian[0].length;
for (int k = 0; k < nC; ++k) {
int pk = permutation[k];
double gamma = 0;
for (int i = k; i < nR; ++i) {
gamma += weightedJacobian[i][pk] * y[i];
}
gamma *= beta[pk];
for (int i = k; i < nR; ++i) {
y[i] -= gamma * weightedJacobian[i][pk];
}
}
}
}


@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
/**
* This package provides optimization algorithms that require derivatives.
*/


@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.vector;
/**
* Algorithms for optimizing a vector function.
*/


@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
/**
* <p>
* Generally, optimizers are algorithms that will either
* {@link GoalType#MINIMIZE minimize} or {@link GoalType#MAXIMIZE maximize}
* a scalar function, called the {@link ObjectiveFunction <em>objective
* function</em>}.
* <br/>
* For some scalar objective functions the gradient can be computed (analytically
* or numerically). Algorithms that use this knowledge are defined in the
* {@link org.apache.commons.math3.optim.nonlinear.scalar.gradient} package.
* The algorithms that do not need this additional information are located in
* the {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv} package.
* </p>
*
* <p>
* Some problems are solved more efficiently by algorithms that, instead of an
* objective function, need access to a
* {@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunction
* <em>model function</em>}: such a model predicts a set of values which the
* algorithm tries to match with a set of given
* {@link org.apache.commons.math3.optim.nonlinear.vector.Target target values}.
* Those algorithms are located in the
* {@link org.apache.commons.math3.optim.nonlinear.vector} package.
* <br/>
* Algorithms that also require the
* {@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian
* Jacobian matrix of the model} are located in the
* {@link org.apache.commons.math3.optim.nonlinear.vector.jacobian} package.
* <br/>
* The {@link org.apache.commons.math3.optim.nonlinear.vector.jacobian.AbstractLeastSquaresOptimizer
* non-linear least-squares optimizers} are a specialization of the latter,
* that minimize the distance (called <em>cost</em> or <em>&chi;<sup>2</sup></em>)
* between model and observations.
* <br/>
* For cases where the Jacobian cannot be provided, a utility class will
* {@link org.apache.commons.math3.optim.nonlinear.scalar.LeastSquaresConverter
* convert} a (vector) model into a (scalar) objective function.
* </p>
*
* <p>
* This package provides common functionality for the optimization algorithms.
* Abstract classes ({@link BaseOptimizer} and {@link BaseMultivariateOptimizer})
* define boiler-plate code for storing {@link MaxEval evaluations} and
* {@link MaxIter iterations} counters and a user-defined
* {@link ConvergenceChecker convergence checker}.
* </p>
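*
* <p>
* For illustration, a typical call passes all the data to a single
* {@code optimize} invocation (the scalar {@code SimplexOptimizer} from the
* {@code nonlinear.scalar.noderiv} package and its companion input classes are
* assumed here; the function and settings are arbitrary):
* </p>
* <pre>
* SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
* PointValuePair optimum =
*     optimizer.optimize(new MaxEval(1000),
*                        new ObjectiveFunction(new MultivariateFunction() {
*                            public double value(double[] x) {
*                                return (x[0] - 1) * (x[0] - 1) + (x[1] + 2) * (x[1] + 2);
*                            }
*                        }),
*                        GoalType.MINIMIZE,
*                        new InitialGuess(new double[] { 0, 0 }),
*                        new NelderMeadSimplex(2));
* </pre>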
*
* <p>
* For each of the optimizer types, there is a special implementation that
* wraps an optimizer instance and provides a "multi-start" feature: it calls
* the underlying optimizer several times with different starting points and
* returns the best optimum found, or all optima if so desired.
* This could be useful to avoid being trapped in a local extremum.
* </p>
*/


@ -0,0 +1,287 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import org.apache.commons.math3.util.Incrementor;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.optim.GoalType;
/**
* Provide an interval that brackets a local optimum of a function.
* This code is based on a Python implementation (from <em>SciPy</em>,
* module {@code optimize.py} v0.5).
*
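* <p>A minimal usage sketch (the quadratic function and the initial points are
* arbitrary illustrations):</p>
* <pre>
* BracketFinder bracketing = new BracketFinder();
* bracketing.search(new UnivariateFunction() {
*         public double value(double x) {
*             return (x - 2) * (x - 2); // minimum at x = 2
*         }
*     }, GoalType.MINIMIZE, 0, 1);
* double lo  = bracketing.getLo();  // lower bound of the bracket
* double mid = bracketing.getMid(); // point inside the bracket
* double hi  = bracketing.getHi();  // higher bound of the bracket
* </pre>
*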
* @version $Id: BracketFinder.java 1413186 2012-11-24 13:47:59Z erans $
* @since 2.2
*/
public class BracketFinder {
/** Tolerance to avoid division by zero. */
private static final double EPS_MIN = 1e-21;
/**
* Golden section.
*/
private static final double GOLD = 1.618034;
/**
* Factor for expanding the interval.
*/
private final double growLimit;
/**
* Counter for function evaluations.
*/
private final Incrementor evaluations = new Incrementor();
/**
* Lower bound of the bracket.
*/
private double lo;
/**
* Higher bound of the bracket.
*/
private double hi;
/**
* Point inside the bracket.
*/
private double mid;
/**
* Function value at {@link #lo}.
*/
private double fLo;
/**
* Function value at {@link #hi}.
*/
private double fHi;
/**
* Function value at {@link #mid}.
*/
private double fMid;
/**
* Constructor with default values {@code 100, 50} (see the
* {@link #BracketFinder(double,int) other constructor}).
*/
public BracketFinder() {
this(100, 50);
}
/**
* Create a bracketing interval finder.
*
* @param growLimit Expanding factor.
* @param maxEvaluations Maximum number of evaluations allowed for finding
* a bracketing interval.
*/
public BracketFinder(double growLimit,
int maxEvaluations) {
if (growLimit <= 0) {
throw new NotStrictlyPositiveException(growLimit);
}
if (maxEvaluations <= 0) {
throw new NotStrictlyPositiveException(maxEvaluations);
}
this.growLimit = growLimit;
evaluations.setMaximalCount(maxEvaluations);
}
/**
* Search new points that bracket a local optimum of the function.
*
* @param func Function whose optimum should be bracketed.
* @param goal {@link GoalType Goal type}.
* @param xA Initial point.
* @param xB Initial point.
* @throws TooManyEvaluationsException if the maximum number of evaluations
* is exceeded.
*/
public void search(UnivariateFunction func, GoalType goal, double xA, double xB) {
evaluations.resetCount();
final boolean isMinim = goal == GoalType.MINIMIZE;
double fA = eval(func, xA);
double fB = eval(func, xB);
if (isMinim ?
fA < fB :
fA > fB) {
double tmp = xA;
xA = xB;
xB = tmp;
tmp = fA;
fA = fB;
fB = tmp;
}
double xC = xB + GOLD * (xB - xA);
double fC = eval(func, xC);
while (isMinim ? fC < fB : fC > fB) {
double tmp1 = (xB - xA) * (fB - fC);
double tmp2 = (xB - xC) * (fB - fA);
double val = tmp2 - tmp1;
double denom = Math.abs(val) < EPS_MIN ? 2 * EPS_MIN : 2 * val;
double w = xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom;
double wLim = xB + growLimit * (xC - xB);
double fW;
if ((w - xC) * (xB - w) > 0) {
fW = eval(func, w);
if (isMinim ?
fW < fC :
fW > fC) {
xA = xB;
xB = w;
fA = fB;
fB = fW;
break;
} else if (isMinim ?
fW > fB :
fW < fB) {
xC = w;
fC = fW;
break;
}
w = xC + GOLD * (xC - xB);
fW = eval(func, w);
} else if ((w - wLim) * (wLim - xC) >= 0) {
w = wLim;
fW = eval(func, w);
} else if ((w - wLim) * (xC - w) > 0) {
fW = eval(func, w);
if (isMinim ?
fW < fC :
fW > fC) {
xB = xC;
xC = w;
w = xC + GOLD * (xC - xB);
fB = fC;
fC = fW;
fW = eval(func, w);
}
} else {
w = xC + GOLD * (xC - xB);
fW = eval(func, w);
}
xA = xB;
fA = fB;
xB = xC;
fB = fC;
xC = w;
fC = fW;
}
lo = xA;
fLo = fA;
mid = xB;
fMid = fB;
hi = xC;
fHi = fC;
if (lo > hi) {
double tmp = lo;
lo = hi;
hi = tmp;
tmp = fLo;
fLo = fHi;
fHi = tmp;
}
}
/**
* @return the maximum number of evaluations.
*/
public int getMaxEvaluations() {
return evaluations.getMaximalCount();
}
/**
* @return the number of evaluations.
*/
public int getEvaluations() {
return evaluations.getCount();
}
/**
* @return the lower bound of the bracket.
* @see #getFLo()
*/
public double getLo() {
return lo;
}
/**
* Get function value at {@link #getLo()}.
* @return function value at {@link #getLo()}
*/
public double getFLo() {
return fLo;
}
/**
* @return the higher bound of the bracket.
* @see #getFHi()
*/
public double getHi() {
return hi;
}
/**
* Get function value at {@link #getHi()}.
* @return function value at {@link #getHi()}
*/
public double getFHi() {
return fHi;
}
/**
* @return a point in the middle of the bracket.
* @see #getFMid()
*/
public double getMid() {
return mid;
}
/**
* Get function value at {@link #getMid()}.
* @return function value at {@link #getMid()}
*/
public double getFMid() {
return fMid;
}
/**
* @param f Function.
* @param x Argument.
* @return {@code f(x)}
* @throws TooManyEvaluationsException if the maximal number of evaluations is
* exceeded.
*/
private double eval(UnivariateFunction f, double x) {
try {
evaluations.incrementCount();
} catch (MaxCountExceededException e) {
throw new TooManyEvaluationsException(e.getMax());
}
return f.value(x);
}
}


@ -0,0 +1,317 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import org.apache.commons.math3.util.Precision;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.GoalType;
/**
* For a function defined on some interval {@code (lo, hi)}, this class
* finds an approximation {@code x} to the point at which the function
* attains its minimum.
* It implements Richard Brent's algorithm (from his book "Algorithms for
* Minimization without Derivatives", p. 79) for finding minima of real
* univariate functions.
* <br/>
* This code is an adaptation, partly based on the Python code from SciPy
* (module "optimize.py" v0.5); the original algorithm is also modified
* <ul>
* <li>to use an initial guess provided by the user,</li>
* <li>to ensure that the best point encountered is the one returned.</li>
* </ul>
*
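* <p>A minimal usage sketch (tolerances and interval are arbitrary; the
* {@code UnivariateObjectiveFunction} wrapper is assumed to be available in
* this package):</p>
* <pre>
* UnivariateOptimizer optimizer = new BrentOptimizer(1e-10, 1e-14);
* UnivariatePointValuePair optimum =
*     optimizer.optimize(new MaxEval(200),
*                        new UnivariateObjectiveFunction(new UnivariateFunction() {
*                            public double value(double x) {
*                                return (x - 2) * (x - 2); // minimum at x = 2
*                            }
*                        }),
*                        GoalType.MINIMIZE,
*                        new SearchInterval(0, 5));
* double best = optimum.getPoint(); // expected to be close to 2
* </pre>
*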
* @version $Id: BrentOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
* @since 2.0
*/
public class BrentOptimizer extends UnivariateOptimizer {
/**
* Golden section.
*/
private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
/**
* Minimum relative tolerance.
*/
private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
/**
* Relative threshold.
*/
private final double relativeThreshold;
/**
* Absolute threshold.
*/
private final double absoluteThreshold;
/**
* The arguments are used to implement the original stopping criterion
* of Brent's algorithm.
* {@code abs} and {@code rel} define a tolerance
* {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
* <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
* where <em>macheps</em> is the relative machine precision. {@code abs} must
* be positive.
*
* @param rel Relative threshold.
* @param abs Absolute threshold.
* @param checker Additional, user-defined, convergence checking
* procedure.
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
*/
public BrentOptimizer(double rel,
double abs,
ConvergenceChecker<UnivariatePointValuePair> checker) {
super(checker);
if (rel < MIN_RELATIVE_TOLERANCE) {
throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
}
if (abs <= 0) {
throw new NotStrictlyPositiveException(abs);
}
relativeThreshold = rel;
absoluteThreshold = abs;
}
/**
* The arguments are used for implementing the original stopping criterion
* of Brent's algorithm.
* {@code abs} and {@code rel} define a tolerance
* {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
* <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
* where <em>macheps</em> is the relative machine precision. {@code abs} must
* be positive.
*
* @param rel Relative threshold.
* @param abs Absolute threshold.
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
*/
public BrentOptimizer(double rel,
double abs) {
this(rel, abs, null);
}
/** {@inheritDoc} */
@Override
protected UnivariatePointValuePair doOptimize() {
final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
final double lo = getMin();
final double mid = getStartValue();
final double hi = getMax();
// Optional additional convergence criteria.
final ConvergenceChecker<UnivariatePointValuePair> checker
= getConvergenceChecker();
double a;
double b;
if (lo < hi) {
a = lo;
b = hi;
} else {
a = hi;
b = lo;
}
double x = mid;
double v = x;
double w = x;
double d = 0;
double e = 0;
double fx = computeObjectiveValue(x);
if (!isMinim) {
fx = -fx;
}
double fv = fx;
double fw = fx;
UnivariatePointValuePair previous = null;
UnivariatePointValuePair current
= new UnivariatePointValuePair(x, isMinim ? fx : -fx);
// Best point encountered so far (which is the initial guess).
UnivariatePointValuePair best = current;
int iter = 0;
while (true) {
final double m = 0.5 * (a + b);
final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
final double tol2 = 2 * tol1;
// Default stopping criterion.
final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
if (!stop) {
double p = 0;
double q = 0;
double r = 0;
double u = 0;
if (FastMath.abs(e) > tol1) { // Fit parabola.
r = (x - w) * (fx - fv);
q = (x - v) * (fx - fw);
p = (x - v) * q - (x - w) * r;
q = 2 * (q - r);
if (q > 0) {
p = -p;
} else {
q = -q;
}
r = e;
e = d;
if (p > q * (a - x) &&
p < q * (b - x) &&
FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
// Parabolic interpolation step.
d = p / q;
u = x + d;
// f must not be evaluated too close to a or b.
if (u - a < tol2 || b - u < tol2) {
if (x <= m) {
d = tol1;
} else {
d = -tol1;
}
}
} else {
// Golden section step.
if (x < m) {
e = b - x;
} else {
e = a - x;
}
d = GOLDEN_SECTION * e;
}
} else {
// Golden section step.
if (x < m) {
e = b - x;
} else {
e = a - x;
}
d = GOLDEN_SECTION * e;
}
// Update by at least "tol1".
if (FastMath.abs(d) < tol1) {
if (d >= 0) {
u = x + tol1;
} else {
u = x - tol1;
}
} else {
u = x + d;
}
double fu = computeObjectiveValue(u);
if (!isMinim) {
fu = -fu;
}
// User-defined convergence checker.
previous = current;
current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
best = best(best,
best(previous,
current,
isMinim),
isMinim);
if (checker != null) {
if (checker.converged(iter, previous, current)) {
return best;
}
}
// Update a, b, v, w and x.
if (fu <= fx) {
if (u < x) {
b = x;
} else {
a = x;
}
v = w;
fv = fw;
w = x;
fw = fx;
x = u;
fx = fu;
} else {
if (u < x) {
a = u;
} else {
b = u;
}
if (fu <= fw ||
Precision.equals(w, x)) {
v = w;
fv = fw;
w = u;
fw = fu;
} else if (fu <= fv ||
Precision.equals(v, x) ||
Precision.equals(v, w)) {
v = u;
fv = fu;
}
}
} else { // Default termination (Brent's criterion).
return best(best,
best(previous,
current,
isMinim),
isMinim);
}
++iter;
}
}
/**
* Selects the best of two points.
*
* @param a Point and value.
* @param b Point and value.
* @param isMinim {@code true} if the selected point must be the one with
* the lowest value.
* @return the best point, or {@code null} if {@code a} and {@code b} are
* both {@code null}. When {@code a} and {@code b} have the same function
* value, {@code a} is returned.
*/
private UnivariatePointValuePair best(UnivariatePointValuePair a,
UnivariatePointValuePair b,
boolean isMinim) {
if (a == null) {
return b;
}
if (b == null) {
return a;
}
if (isMinim) {
return a.getValue() <= b.getValue() ? a : b;
} else {
return a.getValue() >= b.getValue() ? a : b;
}
}
}


@ -0,0 +1,235 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.OptimizationData;
/**
* Special implementation of the {@link UnivariateOptimizer} class,
* adding multi-start features to an existing optimizer.
* <br/>
* This class wraps an optimizer in order to use it several times in
* turn with different starting points (trying to avoid being trapped
* in a local extremum when looking for a global one).
*
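* <p>A minimal usage sketch (settings are arbitrary; {@code UnivariateObjectiveFunction}
* and {@code JDKRandomGenerator} are assumed to be available). Note that {@code MaxEval}
* and {@code SearchInterval} must be part of the optimization data:</p>
* <pre>
* UnivariateOptimizer underlying = new BrentOptimizer(1e-10, 1e-14);
* MultiStartUnivariateOptimizer optimizer =
*     new MultiStartUnivariateOptimizer(underlying, 5, new JDKRandomGenerator());
* UnivariatePointValuePair optimum =
*     optimizer.optimize(new MaxEval(500),
*                        new UnivariateObjectiveFunction(new UnivariateFunction() {
*                            public double value(double x) {
*                                return FastMath.sin(x) + FastMath.sin(10 * x / 3);
*                            }
*                        }),
*                        GoalType.MINIMIZE,
*                        new SearchInterval(2.7, 7.5));
* UnivariatePointValuePair[] allOptima = optimizer.getOptima(); // best first
* </pre>
*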
* @version $Id$
* @since 3.0
*/
public class MultiStartUnivariateOptimizer
extends UnivariateOptimizer {
/** Underlying classical optimizer. */
private final UnivariateOptimizer optimizer;
/** Number of evaluations already performed for all starts. */
private int totalEvaluations;
/** Number of starts to go. */
private int starts;
/** Random generator for multi-start. */
private RandomGenerator generator;
/** Found optima. */
private UnivariatePointValuePair[] optima;
/** Optimization data. */
private OptimizationData[] optimData;
/**
* Location in {@link #optimData} where the updated maximum
* number of evaluations will be stored.
*/
private int maxEvalIndex = -1;
/**
* Location in {@link #optimData} where the updated start value
* will be stored.
*/
private int searchIntervalIndex = -1;
/**
* Create a multi-start optimizer from a single-start optimizer.
*
* @param optimizer Single-start optimizer to wrap.
* @param starts Number of starts to perform. If {@code starts == 1},
* the {@code optimize} methods will return the same solution as
* {@code optimizer} would.
* @param generator Random generator to use for restarts.
* @throws NullArgumentException if {@code optimizer} or {@code generator}
* is {@code null}.
* @throws NotStrictlyPositiveException if {@code starts < 1}.
*/
public MultiStartUnivariateOptimizer(final UnivariateOptimizer optimizer,
final int starts,
final RandomGenerator generator) {
super(optimizer.getConvergenceChecker());
if (optimizer == null ||
generator == null) {
throw new NullArgumentException();
}
if (starts < 1) {
throw new NotStrictlyPositiveException(starts);
}
this.optimizer = optimizer;
this.starts = starts;
this.generator = generator;
}
/** {@inheritDoc} */
@Override
public int getEvaluations() {
return totalEvaluations;
}
/**
* Gets all the optima found during the last call to {@code optimize}.
* The optimizer stores all the optima found during a set of
* restarts. The {@code optimize} method returns the best point only.
* This method returns all the points found at the end of each start,
* including the best one already returned by the {@code optimize} method.
* <br/>
* The returned array has one element for each start as specified
* in the constructor. It is ordered with the results from the
* runs that did converge first, sorted from best to worst
* objective value (i.e. in ascending order if minimizing and in
* descending order if maximizing), followed by {@code null} elements
* corresponding to the runs that did not converge. This means all
* elements will be {@code null} if the {@code optimize} method did throw
* an exception.
* This also means that if the first element is not {@code null}, it is
* the best point found across all starts.
*
* @return an array containing the optima.
* @throws MathIllegalStateException if {@link #optimize(OptimizationData[])
* optimize} has not been called.
*/
public UnivariatePointValuePair[] getOptima() {
if (optima == null) {
throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET);
}
return optima.clone();
}
/**
* {@inheritDoc}
*
* @throws MathIllegalStateException if {@code optData} does not contain an
* instance of {@link MaxEval} or {@link SearchInterval}.
*/
@Override
public UnivariatePointValuePair optimize(OptimizationData... optData) {
// Store arguments in order to pass them to the internal optimizer.
optimData = optData;
// Set up base class and perform computations.
return super.optimize(optData);
}
/** {@inheritDoc} */
@Override
protected UnivariatePointValuePair doOptimize() {
// Remove all instances of "MaxEval" and "SearchInterval" from the
// array that will be passed to the internal optimizer.
// The former is to enforce smaller numbers of allowed evaluations
// (according to how many have been used up already), and the latter
// to impose a different start value for each start.
for (int i = 0; i < optimData.length; i++) {
if (optimData[i] instanceof MaxEval) {
optimData[i] = null;
maxEvalIndex = i;
continue;
}
if (optimData[i] instanceof SearchInterval) {
optimData[i] = null;
searchIntervalIndex = i;
continue;
}
}
if (maxEvalIndex == -1) {
throw new MathIllegalStateException();
}
if (searchIntervalIndex == -1) {
throw new MathIllegalStateException();
}
RuntimeException lastException = null;
optima = new UnivariatePointValuePair[starts];
totalEvaluations = 0;
final int maxEval = getMaxEvaluations();
final double min = getMin();
final double max = getMax();
final double startValue = getStartValue();
// Multi-start loop.
for (int i = 0; i < starts; i++) {
// CHECKSTYLE: stop IllegalCatch
try {
// Decrease number of allowed evaluations.
optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
// New start value.
final double s = (i == 0) ?
startValue :
min + generator.nextDouble() * (max - min);
optimData[searchIntervalIndex] = new SearchInterval(min, max, s);
// Optimize.
optima[i] = optimizer.optimize(optimData);
} catch (RuntimeException mue) {
lastException = mue;
optima[i] = null;
}
// CHECKSTYLE: resume IllegalCatch
totalEvaluations += optimizer.getEvaluations();
}
sortPairs(getGoalType());
if (optima[0] == null) {
throw lastException; // Cannot be null if starts >= 1.
}
// Return the point with the best objective function value.
return optima[0];
}
/**
* Sort the optima from best to worst, followed by {@code null} elements.
*
* @param goal Goal type.
*/
private void sortPairs(final GoalType goal) {
Arrays.sort(optima, new Comparator<UnivariatePointValuePair>() {
public int compare(final UnivariatePointValuePair o1,
final UnivariatePointValuePair o2) {
if (o1 == null) {
return (o2 == null) ? 0 : 1;
} else if (o2 == null) {
return -1;
}
final double v1 = o1.getValue();
final double v2 = o2.getValue();
return (goal == GoalType.MINIMIZE) ?
Double.compare(v1, v2) : Double.compare(v2, v1);
}
});
}
}


@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
/**
* Search interval and (optional) start value.
* <br/>
* Immutable class.
*
* @version $Id$
* @since 3.1
*/
public class SearchInterval implements OptimizationData {
/** Lower bound. */
private final double lower;
/** Upper bound. */
private final double upper;
/** Start value. */
private final double start;
/**
* @param lo Lower bound.
* @param hi Upper bound.
* @param init Start value.
* @throws NumberIsTooLargeException if {@code lo >= hi}.
* @throws OutOfRangeException if {@code init < lo} or {@code init > hi}.
*/
public SearchInterval(double lo,
double hi,
double init) {
if (lo >= hi) {
throw new NumberIsTooLargeException(lo, hi, false);
}
if (init < lo ||
init > hi) {
throw new OutOfRangeException(init, lo, hi);
}
lower = lo;
upper = hi;
start = init;
}
/**
* @param lo Lower bound.
* @param hi Upper bound.
* @throws NumberIsTooLargeException if {@code lo >= hi}.
*/
public SearchInterval(double lo,
double hi) {
this(lo, hi, 0.5 * (lo + hi));
}
/**
* Gets the lower bound.
*
* @return the lower bound.
*/
public double getMin() {
return lower;
}
/**
* Gets the upper bound.
*
* @return the upper bound.
*/
public double getMax() {
return upper;
}
/**
* Gets the start value.
*
* @return the start value.
*/
public double getStartValue() {
return start;
}
}


@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.optim.AbstractConvergenceChecker;
/**
* Simple implementation of the
* {@link org.apache.commons.math3.optim.ConvergenceChecker} interface
* that uses only objective function values.
*
* Convergence is considered to have been reached if either the relative
* difference between the objective function values is smaller than a
* threshold or if the absolute difference between the objective
* function values is smaller than another threshold.
* <br/>
* The {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)
* converged} method will also return {@code true} once the maximal number of
* iterations has been reached, if such a limit has been set (see
* {@link #SimpleUnivariateValueChecker(double,double,int) this constructor}).
*
* @version $Id: SimpleUnivariateValueChecker.java 1413171 2012-11-24 11:11:10Z erans $
* @since 3.1
*/
public class SimpleUnivariateValueChecker
extends AbstractConvergenceChecker<UnivariatePointValuePair> {
/**
* If {@link #maxIterationCount} is set to this value, the number of
* iterations will never cause
* {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)}
* to return {@code true}.
*/
private static final int ITERATION_CHECK_DISABLED = -1;
/**
* Number of iterations after which the
* {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)}
* method will return true (unless the check is disabled).
*/
private final int maxIterationCount;
/**
* Build an instance with specified thresholds.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
*/
public SimpleUnivariateValueChecker(final double relativeThreshold,
final double absoluteThreshold) {
super(relativeThreshold, absoluteThreshold);
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/**
* Builds an instance with specified thresholds.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
* @param maxIter Maximum iteration count.
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
*
* @since 3.1
*/
public SimpleUnivariateValueChecker(final double relativeThreshold,
final double absoluteThreshold,
final int maxIter) {
super(relativeThreshold, absoluteThreshold);
if (maxIter <= 0) {
throw new NotStrictlyPositiveException(maxIter);
}
maxIterationCount = maxIter;
}
/**
* Check if the optimization algorithm has converged considering the
* last two points.
* This method may be called several times from the same algorithm
* iteration with different points. This can be detected by checking the
* iteration number at each call if needed. Each time this method is
* called, the previous and current point correspond to points with the
* same role at each iteration, so they can be compared. As an example,
* simplex-based algorithms call this method for all points of the simplex,
* not only for the best or worst ones.
*
* @param iteration Index of current iteration
* @param previous Best point in the previous iteration.
* @param current Best point in the current iteration.
* @return {@code true} if the algorithm has converged.
*/
@Override
public boolean converged(final int iteration,
final UnivariatePointValuePair previous,
final UnivariatePointValuePair current) {
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
if (iteration >= maxIterationCount) {
return true;
}
}
final double p = previous.getValue();
final double c = current.getValue();
final double difference = FastMath.abs(p - c);
final double size = FastMath.max(FastMath.abs(p), FastMath.abs(c));
return difference <= size * getRelativeThreshold() ||
difference <= getAbsoluteThreshold();
}
}
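
To make the convergence criterion implemented above concrete, a small hedged sketch (the numbers are illustrative only):

// The checker declares convergence when
//   |p - c| <= max(|p|, |c|) * relativeThreshold  or  |p - c| <= absoluteThreshold.
SimpleUnivariateValueChecker checker = new SimpleUnivariateValueChecker(1e-10, 1e-12);
UnivariatePointValuePair previous = new UnivariatePointValuePair(1.0, 2.0000000001);
UnivariatePointValuePair current = new UnivariatePointValuePair(1.0, 2.0);
boolean done = checker.converged(5, previous, current); // true: relative difference ~5e-11 < 1e-10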

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.optim.OptimizationData;
/**
* Scalar function to be optimized.
*
* @version $Id$
* @since 3.1
*/
public class UnivariateObjectiveFunction implements OptimizationData {
/** Function to be optimized. */
private final UnivariateFunction function;
/**
* @param f Function to be optimized.
*/
public UnivariateObjectiveFunction(UnivariateFunction f) {
function = f;
}
/**
* Gets the function to be optimized.
*
* @return the objective function.
*/
public UnivariateFunction getObjectiveFunction() {
return function;
}
}

View File

@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.optim.BaseOptimizer;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
/**
* Base class for a univariate scalar function optimizer.
*
* @version $Id$
* @since 3.1
*/
public abstract class UnivariateOptimizer
extends BaseOptimizer<UnivariatePointValuePair> {
/** Objective function. */
private UnivariateFunction function;
/** Type of optimization. */
private GoalType goal;
/** Initial guess. */
private double start;
/** Lower bound. */
private double min;
/** Upper bound. */
private double max;
/**
* @param checker Convergence checker.
*/
protected UnivariateOptimizer(ConvergenceChecker<UnivariatePointValuePair> checker) {
super(checker);
}
/**
* {@inheritDoc}
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link GoalType}</li>
* <li>{@link SearchInterval}</li>
* <li>{@link UnivariateObjectiveFunction}</li>
* </ul>
* @return {@inheritDoc}
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
*/
public UnivariatePointValuePair optimize(OptimizationData... optData)
throws TooManyEvaluationsException {
// Retrieve settings.
parseOptimizationData(optData);
// Perform computation.
return super.optimize(optData);
}
/**
* @return the optimization type.
*/
public GoalType getGoalType() {
return goal;
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link GoalType}</li>
* <li>{@link SearchInterval}</li>
* <li>{@link UnivariateObjectiveFunction}</li>
* </ul>
*/
private void parseOptimizationData(OptimizationData... optData) {
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData) {
if (data instanceof SearchInterval) {
final SearchInterval interval = (SearchInterval) data;
min = interval.getMin();
max = interval.getMax();
start = interval.getStartValue();
continue;
}
if (data instanceof UnivariateObjectiveFunction) {
function = ((UnivariateObjectiveFunction) data).getObjectiveFunction();
continue;
}
if (data instanceof GoalType) {
goal = (GoalType) data;
continue;
}
}
}
/**
* @return the initial guess.
*/
public double getStartValue() {
return start;
}
/**
* @return the lower bound.
*/
public double getMin() {
return min;
}
/**
* @return the upper bound.
*/
public double getMax() {
return max;
}
/**
* Computes the objective function value.
* This method <em>must</em> be called by subclasses to enforce the
* evaluation counter limit.
*
* @param x Point at which the objective function must be evaluated.
* @return the objective function value at the specified point.
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
*/
protected double computeObjectiveValue(double x) {
super.incrementEvaluationCount();
return function.value(x);
}
}
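
An end-to-end sketch of the varargs "optimize(OptimizationData...)" API documented above. BrentOptimizer, MaxEval and the Sin function are assumed to be available from the new "optim" and existing "analysis.function" packages; they are not part of this file:

// Hedged sketch: minimizing sin(x) on [1, 5] with the OptimizationData varargs API.
UnivariateFunction f = new org.apache.commons.math3.analysis.function.Sin();
UnivariateOptimizer optimizer = new BrentOptimizer(1e-10, 1e-14); // assumed optimizer class
UnivariatePointValuePair result =
    optimizer.optimize(new org.apache.commons.math3.optim.MaxEval(200), // assumed data class
                       new UnivariateObjectiveFunction(f),
                       GoalType.MINIMIZE,
                       new SearchInterval(1, 5));
// result.getPoint() should be close to 3*pi/2 (~4.712), where sin reaches its minimum on [1, 5].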

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.univariate;
import java.io.Serializable;
/**
* This class holds a point and the value of an objective function at this
* point.
* This is a simple immutable container.
*
* @version $Id: UnivariatePointValuePair.java 1364392 2012-07-22 18:27:12Z tn $
* @since 3.0
*/
public class UnivariatePointValuePair implements Serializable {
/** Serializable version identifier. */
private static final long serialVersionUID = 1003888396256744753L;
/** Point. */
private final double point;
/** Value of the objective function at the point. */
private final double value;
/**
* Build a point/objective function value pair.
*
* @param point Point.
* @param value Value of an objective function at the point
*/
public UnivariatePointValuePair(final double point,
final double value) {
this.point = point;
this.value = value;
}
/**
* Get the point.
*
* @return the point.
*/
public double getPoint() {
return point;
}
/**
* Get the value of the objective function.
*
* @return the stored value of the objective function.
*/
public double getValue() {
return value;
}
}

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* One-dimensional optimization algorithms.
*/
package org.apache.commons.math3.optim.univariate;

View File

@ -119,6 +119,7 @@ INVALID_REGRESSION_ARRAY= la longueur du tableau de donn\u00e9es = {0} ne corres
INVALID_REGRESSION_OBSERVATION = la longueur du tableau de variables explicatives ({0}) ne correspond pas au nombre de variables dans le mod\u00e8le ({1})
INVALID_ROUNDING_METHOD = m\u00e9thode d''arondi {0} invalide, m\u00e9thodes valides : {1} ({2}), {3} ({4}), {5} ({6}), {7} ({8}), {9} ({10}), {11} ({12}), {13} ({14}), {15} ({16})
ITERATOR_EXHAUSTED = it\u00e9ration achev\u00e9e
ITERATIONS = it\u00e9rations
LCM_OVERFLOW_32_BITS = d\u00e9passement de capacit\u00e9 : le MCM de {0} et {1} vaut 2^31
LCM_OVERFLOW_64_BITS = d\u00e9passement de capacit\u00e9 : le MCM de {0} et {1} vaut 2^63
LIST_OF_CHROMOSOMES_BIGGER_THAN_POPULATION_SIZE = la liste des chromosomes d\u00e9passe maxPopulationSize

View File

@ -30,7 +30,7 @@ public class LocalizedFormatsTest {
@Test
public void testMessageNumber() {
Assert.assertEquals(311, LocalizedFormats.values().length);
Assert.assertEquals(312, LocalizedFormats.values().length);
}
@Test

View File

@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
public class CurveFitterTest {
@Test
public void testMath303() {
LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
CurveFitter<ParametricUnivariateFunction> fitter = new CurveFitter<ParametricUnivariateFunction>(optimizer);
fitter.addObservedPoint(2.805d, 0.6934785852953367d);
fitter.addObservedPoint(2.74333333333333d, 0.6306772025518496d);
fitter.addObservedPoint(1.655d, 0.9474675497289684);
fitter.addObservedPoint(1.725d, 0.9013594835804194d);
ParametricUnivariateFunction sif = new SimpleInverseFunction();
double[] initialguess1 = new double[1];
initialguess1[0] = 1.0d;
Assert.assertEquals(1, fitter.fit(sif, initialguess1).length);
double[] initialguess2 = new double[2];
initialguess2[0] = 1.0d;
initialguess2[1] = .5d;
Assert.assertEquals(2, fitter.fit(sif, initialguess2).length);
}
@Test
public void testMath304() {
LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
CurveFitter<ParametricUnivariateFunction> fitter = new CurveFitter<ParametricUnivariateFunction>(optimizer);
fitter.addObservedPoint(2.805d, 0.6934785852953367d);
fitter.addObservedPoint(2.74333333333333d, 0.6306772025518496d);
fitter.addObservedPoint(1.655d, 0.9474675497289684);
fitter.addObservedPoint(1.725d, 0.9013594835804194d);
ParametricUnivariateFunction sif = new SimpleInverseFunction();
double[] initialguess1 = new double[1];
initialguess1[0] = 1.0d;
Assert.assertEquals(1.6357215104109237, fitter.fit(sif, initialguess1)[0], 1.0e-14);
double[] initialguess2 = new double[1];
initialguess2[0] = 10.0d;
// Check that a different initial guess converges to the same parameter value.
Assert.assertEquals(1.6357215104109237, fitter.fit(sif, initialguess2)[0], 1.0e-14);
}
@Test
public void testMath372() {
LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
CurveFitter<ParametricUnivariateFunction> curveFitter = new CurveFitter<ParametricUnivariateFunction>(optimizer);
curveFitter.addObservedPoint( 15, 4443);
curveFitter.addObservedPoint( 31, 8493);
curveFitter.addObservedPoint( 62, 17586);
curveFitter.addObservedPoint(125, 30582);
curveFitter.addObservedPoint(250, 45087);
curveFitter.addObservedPoint(500, 50683);
ParametricUnivariateFunction f = new ParametricUnivariateFunction() {
public double value(double x, double ... parameters) {
double a = parameters[0];
double b = parameters[1];
double c = parameters[2];
double d = parameters[3];
return d + ((a - d) / (1 + FastMath.pow(x / c, b)));
}
public double[] gradient(double x, double ... parameters) {
double a = parameters[0];
double b = parameters[1];
double c = parameters[2];
double d = parameters[3];
double[] gradients = new double[4];
double den = 1 + FastMath.pow(x / c, b);
// derivative with respect to a
gradients[0] = 1 / den;
// derivative with respect to b
// in the reported (invalid) issue, there was a sign error here
gradients[1] = -((a - d) * FastMath.pow(x / c, b) * FastMath.log(x / c)) / (den * den);
// derivative with respect to c
gradients[2] = (b * FastMath.pow(x / c, b - 1) * (x / (c * c)) * (a - d)) / (den * den);
// derivative with respect to d
gradients[3] = 1 - (1 / den);
return gradients;
}
};
double[] initialGuess = new double[] { 1500, 0.95, 65, 35000 };
double[] estimatedParameters = curveFitter.fit(f, initialGuess);
Assert.assertEquals( 2411.00, estimatedParameters[0], 500.00);
Assert.assertEquals( 1.62, estimatedParameters[1], 0.04);
Assert.assertEquals( 111.22, estimatedParameters[2], 0.30);
Assert.assertEquals(55347.47, estimatedParameters[3], 300.00);
Assert.assertTrue(optimizer.getRMS() < 600.0);
}
private static class SimpleInverseFunction implements ParametricUnivariateFunction {
public double value(double x, double ... parameters) {
return parameters[0] / x + (parameters.length < 2 ? 0 : parameters[1]);
}
public double[] gradient(double x, double ... doubles) {
double[] gradientVector = new double[doubles.length];
gradientVector[0] = 1 / x;
if (doubles.length >= 2) {
gradientVector[1] = 1;
}
return gradientVector;
}
}
}
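
The testMath372 comment above notes that the partial derivative with respect to b is easy to get wrong. A standalone hedged sketch of a central-difference check of that derivative (not part of the test file; values taken from the test's initial guess):

// Verify d/db [ d + (a - d) / (1 + (x/c)^b) ] against a central finite difference.
double a = 1500, b = 0.95, c = 65, d = 35000, x = 125, h = 1e-6;
double den = 1 + Math.pow(x / c, b);
double analytic = -((a - d) * Math.pow(x / c, b) * Math.log(x / c)) / (den * den);
double fPlus = d + (a - d) / (1 + Math.pow(x / c, b + h));
double fMinus = d + (a - d) / (1 + Math.pow(x / c, b - h));
double numeric = (fPlus - fMinus) / (2 * h);
// analytic and numeric should agree to several significant digits; a sign error would flip one of them.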

View File

@ -0,0 +1,363 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.junit.Assert;
import org.junit.Test;
/**
* Tests {@link GaussianFitter}.
*
* @since 2.2
* @version $Id: GaussianFitterTest.java 1349707 2012-06-13 09:30:56Z erans $
*/
public class GaussianFitterTest {
/** Good data. */
protected static final double[][] DATASET1 = new double[][] {
{4.0254623, 531026.0},
{4.02804905, 664002.0},
{4.02934242, 787079.0},
{4.03128248, 984167.0},
{4.03386923, 1294546.0},
{4.03580929, 1560230.0},
{4.03839603, 1887233.0},
{4.0396894, 2113240.0},
{4.04162946, 2375211.0},
{4.04421621, 2687152.0},
{4.04550958, 2862644.0},
{4.04744964, 3078898.0},
{4.05003639, 3327238.0},
{4.05132976, 3461228.0},
{4.05326982, 3580526.0},
{4.05585657, 3576946.0},
{4.05779662, 3439750.0},
{4.06038337, 3220296.0},
{4.06167674, 3070073.0},
{4.0636168, 2877648.0},
{4.06620355, 2595848.0},
{4.06749692, 2390157.0},
{4.06943698, 2175960.0},
{4.07202373, 1895104.0},
{4.0733171, 1687576.0},
{4.07525716, 1447024.0},
{4.0778439, 1130879.0},
{4.07978396, 904900.0},
{4.08237071, 717104.0},
{4.08366408, 620014.0}
};
/** Poor data: right of peak not symmetric with left of peak. */
protected static final double[][] DATASET2 = new double[][] {
{-20.15, 1523.0},
{-19.65, 1566.0},
{-19.15, 1592.0},
{-18.65, 1927.0},
{-18.15, 3089.0},
{-17.65, 6068.0},
{-17.15, 14239.0},
{-16.65, 34124.0},
{-16.15, 64097.0},
{-15.65, 110352.0},
{-15.15, 164742.0},
{-14.65, 209499.0},
{-14.15, 267274.0},
{-13.65, 283290.0},
{-13.15, 275363.0},
{-12.65, 258014.0},
{-12.15, 225000.0},
{-11.65, 200000.0},
{-11.15, 190000.0},
{-10.65, 185000.0},
{-10.15, 180000.0},
{ -9.65, 179000.0},
{ -9.15, 178000.0},
{ -8.65, 177000.0},
{ -8.15, 176000.0},
{ -7.65, 175000.0},
{ -7.15, 174000.0},
{ -6.65, 173000.0},
{ -6.15, 172000.0},
{ -5.65, 171000.0},
{ -5.15, 170000.0}
};
/** Poor data: long tails. */
protected static final double[][] DATASET3 = new double[][] {
{-90.15, 1513.0},
{-80.15, 1514.0},
{-70.15, 1513.0},
{-60.15, 1514.0},
{-50.15, 1513.0},
{-40.15, 1514.0},
{-30.15, 1513.0},
{-20.15, 1523.0},
{-19.65, 1566.0},
{-19.15, 1592.0},
{-18.65, 1927.0},
{-18.15, 3089.0},
{-17.65, 6068.0},
{-17.15, 14239.0},
{-16.65, 34124.0},
{-16.15, 64097.0},
{-15.65, 110352.0},
{-15.15, 164742.0},
{-14.65, 209499.0},
{-14.15, 267274.0},
{-13.65, 283290.0},
{-13.15, 275363.0},
{-12.65, 258014.0},
{-12.15, 214073.0},
{-11.65, 182244.0},
{-11.15, 136419.0},
{-10.65, 97823.0},
{-10.15, 58930.0},
{ -9.65, 35404.0},
{ -9.15, 16120.0},
{ -8.65, 9823.0},
{ -8.15, 5064.0},
{ -7.65, 2575.0},
{ -7.15, 1642.0},
{ -6.65, 1101.0},
{ -6.15, 812.0},
{ -5.65, 690.0},
{ -5.15, 565.0},
{ 5.15, 564.0},
{ 15.15, 565.0},
{ 25.15, 564.0},
{ 35.15, 565.0},
{ 45.15, 564.0},
{ 55.15, 565.0},
{ 65.15, 564.0},
{ 75.15, 565.0}
};
/** Poor data: right of peak is missing. */
protected static final double[][] DATASET4 = new double[][] {
{-20.15, 1523.0},
{-19.65, 1566.0},
{-19.15, 1592.0},
{-18.65, 1927.0},
{-18.15, 3089.0},
{-17.65, 6068.0},
{-17.15, 14239.0},
{-16.65, 34124.0},
{-16.15, 64097.0},
{-15.65, 110352.0},
{-15.15, 164742.0},
{-14.65, 209499.0},
{-14.15, 267274.0},
{-13.65, 283290.0}
};
/** Good data, but few points. */
protected static final double[][] DATASET5 = new double[][] {
{4.0254623, 531026.0},
{4.03128248, 984167.0},
{4.03839603, 1887233.0},
{4.04421621, 2687152.0},
{4.05132976, 3461228.0},
{4.05326982, 3580526.0},
{4.05779662, 3439750.0},
{4.0636168, 2877648.0},
{4.06943698, 2175960.0},
{4.07525716, 1447024.0},
{4.08237071, 717104.0},
{4.08366408, 620014.0}
};
/**
* Basic.
*/
@Test
public void testFit01() {
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
addDatasetToGaussianFitter(DATASET1, fitter);
double[] parameters = fitter.fit();
Assert.assertEquals(3496978.1837704973, parameters[0], 1e-4);
Assert.assertEquals(4.054933085999146, parameters[1], 1e-4);
Assert.assertEquals(0.015039355620304326, parameters[2], 1e-4);
}
/**
* Zero observed points are not enough.
*/
@Test(expected=MathIllegalArgumentException.class)
public void testFit02() {
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
fitter.fit();
}
/**
* Two observed points are not enough.
*/
@Test(expected=MathIllegalArgumentException.class)
public void testFit03() {
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
addDatasetToGaussianFitter(new double[][] {
{4.0254623, 531026.0},
{4.02804905, 664002.0}},
fitter);
fitter.fit();
}
/**
* Poor data: right of peak not symmetric with left of peak.
*/
@Test
public void testFit04() {
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
addDatasetToGaussianFitter(DATASET2, fitter);
double[] parameters = fitter.fit();
Assert.assertEquals(233003.2967252038, parameters[0], 1e-4);
Assert.assertEquals(-10.654887521095983, parameters[1], 1e-4);
Assert.assertEquals(4.335937353196641, parameters[2], 1e-4);
}
/**
* Poor data: long tails.
*/
@Test
public void testFit05() {
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
addDatasetToGaussianFitter(DATASET3, fitter);
double[] parameters = fitter.fit();
Assert.assertEquals(283863.81929180305, parameters[0], 1e-4);
Assert.assertEquals(-13.29641995105174, parameters[1], 1e-4);
Assert.assertEquals(1.7297330293549908, parameters[2], 1e-4);
}
/**
* Poor data: right of peak is missing.
*/
@Test
public void testFit06() {
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
addDatasetToGaussianFitter(DATASET4, fitter);
double[] parameters = fitter.fit();
Assert.assertEquals(285250.66754309234, parameters[0], 1e-4);
Assert.assertEquals(-13.528375695228455, parameters[1], 1e-4);
Assert.assertEquals(1.5204344894331614, parameters[2], 1e-4);
}
/**
* Basic with smaller dataset.
*/
@Test
public void testFit07() {
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
addDatasetToGaussianFitter(DATASET5, fitter);
double[] parameters = fitter.fit();
Assert.assertEquals(3514384.729342235, parameters[0], 1e-4);
Assert.assertEquals(4.054970307455625, parameters[1], 1e-4);
Assert.assertEquals(0.015029412832160017, parameters[2], 1e-4);
}
@Test
public void testMath519() {
// The optimizer will try negative sigma values but "GaussianFitter"
// will catch the raised exceptions and return NaN values instead.
final double[] data = {
1.1143831578403364E-29,
4.95281403484594E-28,
1.1171347211930288E-26,
1.7044813962636277E-25,
1.9784716574832164E-24,
1.8630236407866774E-23,
1.4820532905097742E-22,
1.0241963854632831E-21,
6.275077366673128E-21,
3.461808994532493E-20,
1.7407124684715706E-19,
8.056687953553974E-19,
3.460193945992071E-18,
1.3883326374011525E-17,
5.233894983671116E-17,
1.8630791465263745E-16,
6.288759227922111E-16,
2.0204433920597856E-15,
6.198768938576155E-15,
1.821419346860626E-14,
5.139176445538471E-14,
1.3956427429045787E-13,
3.655705706448139E-13,
9.253753324779779E-13,
2.267636001476696E-12,
5.3880460095836855E-12,
1.2431632654852931E-11
};
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
for (int i = 0; i < data.length; i++) {
fitter.addObservedPoint(i, data[i]);
}
final double[] p = fitter.fit();
Assert.assertEquals(53.1572792, p[1], 1e-7);
Assert.assertEquals(5.75214622, p[2], 1e-8);
}
@Test
public void testMath798() {
final GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
// If the zero-valued data points below are included (i.e. not commented
// out), the fit stalls. This is expected, since the whole dataset then
// hardly looks like a Gaussian.
// With them commented out, the fit proceeds fine.
fitter.addObservedPoint(0.23, 395.0);
//fitter.addObservedPoint(0.68, 0.0);
fitter.addObservedPoint(1.14, 376.0);
//fitter.addObservedPoint(1.59, 0.0);
fitter.addObservedPoint(2.05, 163.0);
//fitter.addObservedPoint(2.50, 0.0);
fitter.addObservedPoint(2.95, 49.0);
//fitter.addObservedPoint(3.41, 0.0);
fitter.addObservedPoint(3.86, 16.0);
//fitter.addObservedPoint(4.32, 0.0);
fitter.addObservedPoint(4.77, 1.0);
final double[] p = fitter.fit();
// Values are copied from a previous run of this test.
Assert.assertEquals(420.8397296167364, p[0], 1e-12);
Assert.assertEquals(0.603770729862231, p[1], 1e-15);
Assert.assertEquals(1.0786447936766612, p[2], 1e-14);
}
/**
* Adds the specified points to specified <code>GaussianFitter</code>
* instance.
*
* @param points data points where first dimension is a point index and
* second dimension is an array of length two representing the point
* with the first value corresponding to X and the second value
* corresponding to Y
* @param fitter fitter to which the points in <code>points</code> should be
* added as observed points
*/
protected static void addDatasetToGaussianFitter(double[][] points,
GaussianFitter fitter) {
for (int i = 0; i < points.length; i++) {
fitter.addObservedPoint(points[i][0], points[i][1]);
}
}
}
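
For readers of the assertions above: the fitted array holds, in order, the normalization, mean and standard deviation of a Gaussian. A hedged sketch reconstructing the testFit01 curve with the Gaussian class from the analysis.function package (an assumption as far as this file is concerned):

// Rebuild the fitted curve from the parameters asserted in testFit01.
double norm = 3496978.1837704973;    // parameters[0]
double mean = 4.054933085999146;     // parameters[1]
double sigma = 0.015039355620304326; // parameters[2]
org.apache.commons.math3.analysis.function.Gaussian g =
    new org.apache.commons.math3.analysis.function.Gaussian(norm, mean, sigma);
double peak = g.value(mean); // equals norm: the fitted peak height near x = 4.055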

View File

@ -0,0 +1,185 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import java.util.Random;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.analysis.function.HarmonicOscillator;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathUtils;
import org.junit.Test;
import org.junit.Assert;
public class HarmonicFitterTest {
@Test(expected=NumberIsTooSmallException.class)
public void testPreconditions1() {
HarmonicFitter fitter =
new HarmonicFitter(new LevenbergMarquardtOptimizer());
fitter.fit();
}
@Test
public void testNoError() {
final double a = 0.2;
final double w = 3.4;
final double p = 4.1;
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
HarmonicFitter fitter =
new HarmonicFitter(new LevenbergMarquardtOptimizer());
for (double x = 0.0; x < 1.3; x += 0.01) {
fitter.addObservedPoint(1, x, f.value(x));
}
final double[] fitted = fitter.fit();
Assert.assertEquals(a, fitted[0], 1.0e-13);
Assert.assertEquals(w, fitted[1], 1.0e-13);
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1e-13);
HarmonicOscillator ff = new HarmonicOscillator(fitted[0], fitted[1], fitted[2]);
for (double x = -1.0; x < 1.0; x += 0.01) {
Assert.assertTrue(FastMath.abs(f.value(x) - ff.value(x)) < 1e-13);
}
}
@Test
public void test1PercentError() {
Random randomizer = new Random(64925784252L);
final double a = 0.2;
final double w = 3.4;
final double p = 4.1;
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
HarmonicFitter fitter =
new HarmonicFitter(new LevenbergMarquardtOptimizer());
for (double x = 0.0; x < 10.0; x += 0.1) {
fitter.addObservedPoint(1, x,
f.value(x) + 0.01 * randomizer.nextGaussian());
}
final double[] fitted = fitter.fit();
Assert.assertEquals(a, fitted[0], 7.6e-4);
Assert.assertEquals(w, fitted[1], 2.7e-3);
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1.3e-2);
}
@Test
public void testTinyVariationsData() {
Random randomizer = new Random(64925784252L);
HarmonicFitter fitter =
new HarmonicFitter(new LevenbergMarquardtOptimizer());
for (double x = 0.0; x < 10.0; x += 0.1) {
fitter.addObservedPoint(1, x, 1e-7 * randomizer.nextGaussian());
}
fitter.fit();
// This test serves to cover the part of the code of "guessAOmega"
// when the algorithm using integrals fails.
}
@Test
public void testInitialGuess() {
Random randomizer = new Random(45314242l);
final double a = 0.2;
final double w = 3.4;
final double p = 4.1;
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
HarmonicFitter fitter =
new HarmonicFitter(new LevenbergMarquardtOptimizer());
for (double x = 0.0; x < 10.0; x += 0.1) {
fitter.addObservedPoint(1, x,
f.value(x) + 0.01 * randomizer.nextGaussian());
}
final double[] fitted = fitter.fit(new double[] { 0.15, 3.6, 4.5 });
Assert.assertEquals(a, fitted[0], 1.2e-3);
Assert.assertEquals(w, fitted[1], 3.3e-3);
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1.7e-2);
}
@Test
public void testUnsorted() {
Random randomizer = new Random(64925784252L);
final double a = 0.2;
final double w = 3.4;
final double p = 4.1;
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
HarmonicFitter fitter =
new HarmonicFitter(new LevenbergMarquardtOptimizer());
// build a regularly spaced array of measurements
int size = 100;
double[] xTab = new double[size];
double[] yTab = new double[size];
for (int i = 0; i < size; ++i) {
xTab[i] = 0.1 * i;
yTab[i] = f.value(xTab[i]) + 0.01 * randomizer.nextGaussian();
}
// shake it
for (int i = 0; i < size; ++i) {
int i1 = randomizer.nextInt(size);
int i2 = randomizer.nextInt(size);
double xTmp = xTab[i1];
double yTmp = yTab[i1];
xTab[i1] = xTab[i2];
yTab[i1] = yTab[i2];
xTab[i2] = xTmp;
yTab[i2] = yTmp;
}
// pass it to the fitter
for (int i = 0; i < size; ++i) {
fitter.addObservedPoint(1, xTab[i], yTab[i]);
}
final double[] fitted = fitter.fit();
Assert.assertEquals(a, fitted[0], 7.6e-4);
Assert.assertEquals(w, fitted[1], 3.5e-3);
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1.5e-2);
}
@Test(expected=MathIllegalStateException.class)
public void testMath844() {
final double[] y = { 0, 1, 2, 3, 2, 1,
0, -1, -2, -3, -2, -1,
0, 1, 2, 3, 2, 1,
0, -1, -2, -3, -2, -1,
0, 1, 2, 3, 2, 1, 0 };
final int len = y.length;
final WeightedObservedPoint[] points = new WeightedObservedPoint[len];
for (int i = 0; i < len; i++) {
points[i] = new WeightedObservedPoint(1, i, y[i]);
}
// The guesser fails because the function is far from a harmonic
// function: it is a triangular periodic function with amplitude 3
// and period 12, and all sample points are taken at integer abscissae
// so function values all belong to the integer subset {-3, -2, -1, 0,
// 1, 2, 3}.
final HarmonicFitter.ParameterGuesser guesser
= new HarmonicFitter.ParameterGuesser(points);
}
}
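
In the fits above, fitted[0], fitted[1] and fitted[2] are the amplitude, angular frequency and phase of a*cos(w*x + p). A hedged sketch (values chosen to match the tests' target parameters):

// The fitted parameters define a HarmonicOscillator a * cos(w * x + p).
double[] fitted = { 0.2, 3.4, 4.1 };
HarmonicOscillator model = new HarmonicOscillator(fitted[0], fitted[1], fitted[2]);
double y0 = model.value(0.0); // a * cos(p) = 0.2 * cos(4.1) ~ -0.115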

View File

@ -0,0 +1,256 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting;
import java.util.Random;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction.Parametric;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.GaussNewtonOptimizer;
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.TestUtils;
import org.junit.Test;
import org.junit.Assert;
/**
* Test for class {@link CurveFitter} where the function to fit is a
* polynomial.
*/
public class PolynomialFitterTest {
@Test
public void testFit() {
final RealDistribution rng = new UniformRealDistribution(-100, 100);
rng.reseedRandomGenerator(64925784252L);
final LevenbergMarquardtOptimizer optim = new LevenbergMarquardtOptimizer();
final PolynomialFitter fitter = new PolynomialFitter(optim);
final double[] coeff = { 12.9, -3.4, 2.1 }; // 12.9 - 3.4 x + 2.1 x^2
final PolynomialFunction f = new PolynomialFunction(coeff);
// Collect data from a known polynomial.
for (int i = 0; i < 100; i++) {
final double x = rng.sample();
fitter.addObservedPoint(x, f.value(x));
}
// Start fit from initial guesses that are far from the optimal values.
final double[] best = fitter.fit(new double[] { -1e-20, 3e15, -5e25 });
TestUtils.assertEquals("best != coeff", coeff, best, 1e-12);
}
@Test
public void testNoError() {
Random randomizer = new Random(64925784252L);
for (int degree = 1; degree < 10; ++degree) {
PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
PolynomialFitter fitter = new PolynomialFitter(new LevenbergMarquardtOptimizer());
for (int i = 0; i <= degree; ++i) {
fitter.addObservedPoint(1.0, i, p.value(i));
}
final double[] init = new double[degree + 1];
PolynomialFunction fitted = new PolynomialFunction(fitter.fit(init));
for (double x = -1.0; x < 1.0; x += 0.01) {
double error = FastMath.abs(p.value(x) - fitted.value(x)) /
(1.0 + FastMath.abs(p.value(x)));
Assert.assertEquals(0.0, error, 1.0e-6);
}
}
}
@Test
public void testSmallError() {
Random randomizer = new Random(53882150042L);
double maxError = 0;
for (int degree = 0; degree < 10; ++degree) {
PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
PolynomialFitter fitter = new PolynomialFitter(new LevenbergMarquardtOptimizer());
for (double x = -1.0; x < 1.0; x += 0.01) {
fitter.addObservedPoint(1.0, x,
p.value(x) + 0.1 * randomizer.nextGaussian());
}
final double[] init = new double[degree + 1];
PolynomialFunction fitted = new PolynomialFunction(fitter.fit(init));
for (double x = -1.0; x < 1.0; x += 0.01) {
double error = FastMath.abs(p.value(x) - fitted.value(x)) /
(1.0 + FastMath.abs(p.value(x)));
maxError = FastMath.max(maxError, error);
Assert.assertTrue(FastMath.abs(error) < 0.1);
}
}
Assert.assertTrue(maxError > 0.01);
}
@Test
public void testMath798() {
final double tol = 1e-14;
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol);
final double[] init = new double[] { 0, 0 };
final int maxEval = 3;
final double[] lm = doMath798(new LevenbergMarquardtOptimizer(checker), maxEval, init);
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
for (int i = 0; i <= 1; i++) {
Assert.assertEquals(lm[i], gn[i], tol);
}
}
/**
* This test shows that the user can set the maximum number of function
* evaluations to avoid running for too long.
* But in the test case, the real problem is that the tolerance is way too
* stringent.
*/
@Test(expected=TooManyEvaluationsException.class)
public void testMath798WithToleranceTooLow() {
final double tol = 1e-100;
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol);
final double[] init = new double[] { 0, 0 };
final int maxEval = 10000; // Trying hard to fit.
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
}
/**
* This test shows that the user can set the maximum number of iterations
* to avoid running for too long.
* Even if the real problem is that the tolerance is way too stringent, it
* is possible to get the best solution so far, i.e. a checker will return
* the point when the maximum iteration count has been reached.
*/
@Test
public void testMath798WithToleranceTooLowButNoException() {
final double tol = 1e-100;
final double[] init = new double[] { 0, 0 };
final int maxEval = 10000; // Trying hard to fit.
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol, maxEval);
final double[] lm = doMath798(new LevenbergMarquardtOptimizer(checker), maxEval, init);
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
for (int i = 0; i <= 1; i++) {
Assert.assertEquals(lm[i], gn[i], 1e-15);
}
}
/**
* @param optimizer Optimizer.
* @param maxEval Maximum number of function evaluations.
* @param init First guess.
* @return the solution found by the given optimizer.
*/
private double[] doMath798(MultivariateVectorOptimizer optimizer,
int maxEval,
double[] init) {
final CurveFitter<Parametric> fitter = new CurveFitter<Parametric>(optimizer);
fitter.addObservedPoint(-0.2, -7.12442E-13);
fitter.addObservedPoint(-0.199, -4.33397E-13);
fitter.addObservedPoint(-0.198, -2.823E-13);
fitter.addObservedPoint(-0.197, -1.40405E-13);
fitter.addObservedPoint(-0.196, -7.80821E-15);
fitter.addObservedPoint(-0.195, 6.20484E-14);
fitter.addObservedPoint(-0.194, 7.24673E-14);
fitter.addObservedPoint(-0.193, 1.47152E-13);
fitter.addObservedPoint(-0.192, 1.9629E-13);
fitter.addObservedPoint(-0.191, 2.12038E-13);
fitter.addObservedPoint(-0.19, 2.46906E-13);
fitter.addObservedPoint(-0.189, 2.77495E-13);
fitter.addObservedPoint(-0.188, 2.51281E-13);
fitter.addObservedPoint(-0.187, 2.64001E-13);
fitter.addObservedPoint(-0.186, 2.8882E-13);
fitter.addObservedPoint(-0.185, 3.13604E-13);
fitter.addObservedPoint(-0.184, 3.14248E-13);
fitter.addObservedPoint(-0.183, 3.1172E-13);
fitter.addObservedPoint(-0.182, 3.12912E-13);
fitter.addObservedPoint(-0.181, 3.06761E-13);
fitter.addObservedPoint(-0.18, 2.8559E-13);
fitter.addObservedPoint(-0.179, 2.86806E-13);
fitter.addObservedPoint(-0.178, 2.985E-13);
fitter.addObservedPoint(-0.177, 2.67148E-13);
fitter.addObservedPoint(-0.176, 2.94173E-13);
fitter.addObservedPoint(-0.175, 3.27528E-13);
fitter.addObservedPoint(-0.174, 3.33858E-13);
fitter.addObservedPoint(-0.173, 2.97511E-13);
fitter.addObservedPoint(-0.172, 2.8615E-13);
fitter.addObservedPoint(-0.171, 2.84624E-13);
final double[] coeff = fitter.fit(maxEval,
new PolynomialFunction.Parametric(),
init);
return coeff;
}
@Test
public void testRedundantSolvable() {
// Levenberg-Marquardt should handle redundant information gracefully
checkUnsolvableProblem(new LevenbergMarquardtOptimizer(), true);
}
@Test
public void testRedundantUnsolvable() {
// Gauss-Newton should not be able to solve redundant information
checkUnsolvableProblem(new GaussNewtonOptimizer(true, new SimpleVectorValueChecker(1e-15, 1e-15)), false);
}
private void checkUnsolvableProblem(MultivariateVectorOptimizer optimizer,
boolean solvable) {
Random randomizer = new Random(1248788532L);
for (int degree = 0; degree < 10; ++degree) {
PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
PolynomialFitter fitter = new PolynomialFitter(optimizer);
// Reusing the same point over and over again does not bring
// additional information; the problem cannot be solved in this
// case for degrees greater than 0 (one point is sufficient only
// for degree 0).
for (double x = -1.0; x < 1.0; x += 0.01) {
fitter.addObservedPoint(1.0, 0.0, p.value(0.0));
}
try {
final double[] init = new double[degree + 1];
fitter.fit(init);
Assert.assertTrue(solvable || (degree == 0));
} catch(ConvergenceException e) {
Assert.assertTrue((! solvable) && (degree > 0));
}
}
}
private PolynomialFunction buildRandomPolynomial(int degree, Random randomizer) {
final double[] coefficients = new double[degree + 1];
for (int i = 0; i <= degree; ++i) {
coefficients[i] = randomizer.nextGaussian();
}
return new PolynomialFunction(coefficients);
}
}
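
A hedged note on the coefficient convention used by the assertions above: the fitted array is in increasing-degree order, matching PolynomialFunction. For example:

// 12.9 - 3.4 x + 2.1 x^2, as in testFit.
double[] best = { 12.9, -3.4, 2.1 };
PolynomialFunction fittedPoly = new PolynomialFunction(best);
double y = fittedPoly.value(2.0); // 12.9 - 6.8 + 8.4 = 14.5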

View File

@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.TestUtils;
import org.junit.Assert;
import org.junit.Test;
public class PointValuePairTest {
@Test
public void testSerial() {
PointValuePair pv1 = new PointValuePair(new double[] { 1.0, 2.0, 3.0 }, 4.0);
PointValuePair pv2 = (PointValuePair) TestUtils.serializeAndRecover(pv1);
Assert.assertEquals(pv1.getKey().length, pv2.getKey().length);
for (int i = 0; i < pv1.getKey().length; ++i) {
Assert.assertEquals(pv1.getKey()[i], pv2.getKey()[i], 1.0e-15);
}
Assert.assertEquals(pv1.getValue(), pv2.getValue(), 1.0e-15);
}
}

View File

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.TestUtils;
import org.junit.Assert;
import org.junit.Test;
public class PointVectorValuePairTest {
@Test
public void testSerial() {
PointVectorValuePair pv1 = new PointVectorValuePair(new double[] { 1.0, 2.0, 3.0 },
new double[] { 4.0, 5.0 });
PointVectorValuePair pv2 = (PointVectorValuePair) TestUtils.serializeAndRecover(pv1);
Assert.assertEquals(pv1.getKey().length, pv2.getKey().length);
for (int i = 0; i < pv1.getKey().length; ++i) {
Assert.assertEquals(pv1.getKey()[i], pv2.getKey()[i], 1.0e-15);
}
Assert.assertEquals(pv1.getValue().length, pv2.getValue().length);
for (int i = 0; i < pv1.getValue().length; ++i) {
Assert.assertEquals(pv1.getValue()[i], pv2.getValue()[i], 1.0e-15);
}
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.junit.Test;
import org.junit.Assert;
public class SimplePointCheckerTest {
@Test(expected=NotStrictlyPositiveException.class)
public void testIterationCheckPrecondition() {
new SimplePointChecker<PointValuePair>(1e-1, 1e-2, 0);
}
@Test
public void testIterationCheck() {
final int max = 10;
final SimplePointChecker<PointValuePair> checker
= new SimplePointChecker<PointValuePair>(1e-1, 1e-2, max);
Assert.assertTrue(checker.converged(max, null, null));
Assert.assertTrue(checker.converged(max + 1, null, null));
}
@Test
public void testIterationCheckDisabled() {
final SimplePointChecker<PointValuePair> checker
= new SimplePointChecker<PointValuePair>(1e-8, 1e-8);
final PointValuePair a = new PointValuePair(new double[] { 1d }, 1d);
final PointValuePair b = new PointValuePair(new double[] { 10d }, 10d);
Assert.assertFalse(checker.converged(-1, a, b));
Assert.assertFalse(checker.converged(0, a, b));
Assert.assertFalse(checker.converged(1000000, a, b));
Assert.assertTrue(checker.converged(-1, a, a));
Assert.assertTrue(checker.converged(-1, b, b));
}
}

View File

@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.junit.Test;
import org.junit.Assert;
public class SimpleValueCheckerTest {
@Test(expected=NotStrictlyPositiveException.class)
public void testIterationCheckPrecondition() {
new SimpleValueChecker(1e-1, 1e-2, 0);
}
@Test
public void testIterationCheck() {
final int max = 10;
final SimpleValueChecker checker = new SimpleValueChecker(1e-1, 1e-2, max);
Assert.assertTrue(checker.converged(max, null, null));
Assert.assertTrue(checker.converged(max + 1, null, null));
}
@Test
public void testIterationCheckDisabled() {
final SimpleValueChecker checker = new SimpleValueChecker(1e-8, 1e-8);
final PointValuePair a = new PointValuePair(new double[] { 1d }, 1d);
final PointValuePair b = new PointValuePair(new double[] { 10d }, 10d);
Assert.assertFalse(checker.converged(-1, a, b));
Assert.assertFalse(checker.converged(0, a, b));
Assert.assertFalse(checker.converged(1000000, a, b));
Assert.assertTrue(checker.converged(-1, a, a));
Assert.assertTrue(checker.converged(-1, b, b));
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.junit.Test;
import org.junit.Assert;
public class SimpleVectorValueCheckerTest {
@Test(expected=NotStrictlyPositiveException.class)
public void testIterationCheckPrecondition() {
new SimpleVectorValueChecker(1e-1, 1e-2, 0);
}
@Test
public void testIterationCheck() {
final int max = 10;
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(1e-1, 1e-2, max);
Assert.assertTrue(checker.converged(max, null, null));
Assert.assertTrue(checker.converged(max + 1, null, null));
}
@Test
public void testIterationCheckDisabled() {
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(1e-8, 1e-8);
final PointVectorValuePair a = new PointVectorValuePair(new double[] { 1d },
new double[] { 1d });
final PointVectorValuePair b = new PointVectorValuePair(new double[] { 10d },
new double[] { 10d });
Assert.assertFalse(checker.converged(-1, a, b));
Assert.assertFalse(checker.converged(0, a, b));
Assert.assertFalse(checker.converged(1000000, a, b));
Assert.assertTrue(checker.converged(-1, a, a));
Assert.assertTrue(checker.converged(-1, b, b));
}
}

View File

@ -0,0 +1,664 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.util.Precision;
import org.junit.Test;
import org.junit.Assert;
public class SimplexSolverTest {
private static final MaxIter DEFAULT_MAX_ITER = new MaxIter(100);
@Test
public void testMath828() {
LinearObjectiveFunction f = new LinearObjectiveFunction(
new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 0.0);
ArrayList <LinearConstraint>constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {0.0, 39.0, 23.0, 96.0, 15.0, 48.0, 9.0, 21.0, 48.0, 36.0, 76.0, 19.0, 88.0, 17.0, 16.0, 36.0,}, Relationship.GEQ, 15.0));
constraints.add(new LinearConstraint(new double[] {0.0, 59.0, 93.0, 12.0, 29.0, 78.0, 73.0, 87.0, 32.0, 70.0, 68.0, 24.0, 11.0, 26.0, 65.0, 25.0,}, Relationship.GEQ, 29.0));
constraints.add(new LinearConstraint(new double[] {0.0, 74.0, 5.0, 82.0, 6.0, 97.0, 55.0, 44.0, 52.0, 54.0, 5.0, 93.0, 91.0, 8.0, 20.0, 97.0,}, Relationship.GEQ, 6.0));
constraints.add(new LinearConstraint(new double[] {8.0, -3.0, -28.0, -72.0, -8.0, -31.0, -31.0, -74.0, -47.0, -59.0, -24.0, -57.0, -56.0, -16.0, -92.0, -59.0,}, Relationship.GEQ, 0.0));
constraints.add(new LinearConstraint(new double[] {25.0, -7.0, -99.0, -78.0, -25.0, -14.0, -16.0, -89.0, -39.0, -56.0, -53.0, -9.0, -18.0, -26.0, -11.0, -61.0,}, Relationship.GEQ, 0.0));
constraints.add(new LinearConstraint(new double[] {33.0, -95.0, -15.0, -4.0, -33.0, -3.0, -20.0, -96.0, -27.0, -13.0, -80.0, -24.0, -3.0, -13.0, -57.0, -76.0,}, Relationship.GEQ, 0.0));
constraints.add(new LinearConstraint(new double[] {7.0, -95.0, -39.0, -93.0, -7.0, -94.0, -94.0, -62.0, -76.0, -26.0, -53.0, -57.0, -31.0, -76.0, -53.0, -52.0,}, Relationship.GEQ, 0.0));
double epsilon = 1e-6;
PointValuePair solution = new SimplexSolver().optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(1.0d, solution.getValue(), epsilon);
Assert.assertTrue(validSolution(solution, constraints, epsilon));
}
@Test
public void testMath828Cycle() {
LinearObjectiveFunction f = new LinearObjectiveFunction(
new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 0.0);
ArrayList <LinearConstraint>constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {0.0, 16.0, 14.0, 69.0, 1.0, 85.0, 52.0, 43.0, 64.0, 97.0, 14.0, 74.0, 89.0, 28.0, 94.0, 58.0, 13.0, 22.0, 21.0, 17.0, 30.0, 25.0, 1.0, 59.0, 91.0, 78.0, 12.0, 74.0, 56.0, 3.0, 88.0,}, Relationship.GEQ, 91.0));
constraints.add(new LinearConstraint(new double[] {0.0, 60.0, 40.0, 81.0, 71.0, 72.0, 46.0, 45.0, 38.0, 48.0, 40.0, 17.0, 33.0, 85.0, 64.0, 32.0, 84.0, 3.0, 54.0, 44.0, 71.0, 67.0, 90.0, 95.0, 54.0, 99.0, 99.0, 29.0, 52.0, 98.0, 9.0,}, Relationship.GEQ, 54.0));
constraints.add(new LinearConstraint(new double[] {0.0, 41.0, 12.0, 86.0, 90.0, 61.0, 31.0, 41.0, 23.0, 89.0, 17.0, 74.0, 44.0, 27.0, 16.0, 47.0, 80.0, 32.0, 11.0, 56.0, 68.0, 82.0, 11.0, 62.0, 62.0, 53.0, 39.0, 16.0, 48.0, 1.0, 63.0,}, Relationship.GEQ, 62.0));
constraints.add(new LinearConstraint(new double[] {83.0, -76.0, -94.0, -19.0, -15.0, -70.0, -72.0, -57.0, -63.0, -65.0, -22.0, -94.0, -22.0, -88.0, -86.0, -89.0, -72.0, -16.0, -80.0, -49.0, -70.0, -93.0, -95.0, -17.0, -83.0, -97.0, -31.0, -47.0, -31.0, -13.0, -23.0,}, Relationship.GEQ, 0.0));
constraints.add(new LinearConstraint(new double[] {41.0, -96.0, -41.0, -48.0, -70.0, -43.0, -43.0, -43.0, -97.0, -37.0, -85.0, -70.0, -45.0, -67.0, -87.0, -69.0, -94.0, -54.0, -54.0, -92.0, -79.0, -10.0, -35.0, -20.0, -41.0, -41.0, -65.0, -25.0, -12.0, -8.0, -46.0,}, Relationship.GEQ, 0.0));
constraints.add(new LinearConstraint(new double[] {27.0, -42.0, -65.0, -49.0, -53.0, -42.0, -17.0, -2.0, -61.0, -31.0, -76.0, -47.0, -8.0, -93.0, -86.0, -62.0, -65.0, -63.0, -22.0, -43.0, -27.0, -23.0, -32.0, -74.0, -27.0, -63.0, -47.0, -78.0, -29.0, -95.0, -73.0,}, Relationship.GEQ, 0.0));
constraints.add(new LinearConstraint(new double[] {15.0, -46.0, -41.0, -83.0, -98.0, -99.0, -21.0, -35.0, -7.0, -14.0, -80.0, -63.0, -18.0, -42.0, -5.0, -34.0, -56.0, -70.0, -16.0, -18.0, -74.0, -61.0, -47.0, -41.0, -15.0, -79.0, -18.0, -47.0, -88.0, -68.0, -55.0,}, Relationship.GEQ, 0.0));
double epsilon = 1e-6;
PointValuePair solution = new SimplexSolver().optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(1.0d, solution.getValue(), epsilon);
Assert.assertTrue(validSolution(solution, constraints, epsilon));
}
@Test
public void testMath781() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 2, 6, 7 }, 0);
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 2, 1 }, Relationship.LEQ, 2));
constraints.add(new LinearConstraint(new double[] { -1, 1, 1 }, Relationship.LEQ, -1));
constraints.add(new LinearConstraint(new double[] { 2, -3, 1 }, Relationship.LEQ, -1));
double epsilon = 1e-6;
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0], 0.0d, epsilon) > 0);
Assert.assertTrue(Precision.compareTo(solution.getPoint()[1], 0.0d, epsilon) > 0);
Assert.assertTrue(Precision.compareTo(solution.getPoint()[2], 0.0d, epsilon) < 0);
Assert.assertEquals(2.0d, solution.getValue(), epsilon);
}
@Test
public void testMath713NegativeVariable() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {1.0, 1.0}, 0.0d);
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {1, 0}, Relationship.EQ, 1));
double epsilon = 1e-6;
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0], 0.0d, epsilon) >= 0);
Assert.assertTrue(Precision.compareTo(solution.getPoint()[1], 0.0d, epsilon) >= 0);
}
@Test
public void testMath434NegativeVariable() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {0.0, 0.0, 1.0}, 0.0d);
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {1, 1, 0}, Relationship.EQ, 5));
constraints.add(new LinearConstraint(new double[] {0, 0, 1}, Relationship.GEQ, -10));
double epsilon = 1e-6;
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(false));
Assert.assertEquals(5.0, solution.getPoint()[0] + solution.getPoint()[1], epsilon);
Assert.assertEquals(-10.0, solution.getPoint()[2], epsilon);
Assert.assertEquals(-10.0, solution.getValue(), epsilon);
}
@Test(expected = NoFeasibleSolutionException.class)
public void testMath434UnfeasibleSolution() {
double epsilon = 1e-6;
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {1.0, 0.0}, 0.0);
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {epsilon/2, 0.5}, Relationship.EQ, 0));
constraints.add(new LinearConstraint(new double[] {1e-3, 0.1}, Relationship.EQ, 10));
SimplexSolver solver = new SimplexSolver();
// with only non-negative values allowed, no feasible solution should be found
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
}
@Test
public void testMath434PivotRowSelection() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {1.0}, 0.0);
double epsilon = 1e-6;
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {200}, Relationship.GEQ, 1));
constraints.add(new LinearConstraint(new double[] {100}, Relationship.GEQ, 0.499900001));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(false));
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0] * 200.d, 1.d, epsilon) >= 0);
Assert.assertEquals(0.0050, solution.getValue(), epsilon);
}
@Test
public void testMath434PivotRowSelection2() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {0.0d, 1.0d, 1.0d, 0.0d, 0.0d, 0.0d, 0.0d}, 0.0d);
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {1.0d, -0.1d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d}, Relationship.EQ, -0.1d));
constraints.add(new LinearConstraint(new double[] {1.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, -1e-18d));
constraints.add(new LinearConstraint(new double[] {0.0d, 1.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 0.0d, 1.0d, 0.0d, -0.0128588d, 1e-5d}, Relationship.EQ, 0.0d));
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 0.0d, 0.0d, 1.0d, 1e-5d, -0.0128586d}, Relationship.EQ, 1e-10d));
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, -1.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, 1.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, 0.0d, -1.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, 0.0d, 1.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
double epsilon = 1e-7;
SimplexSolver simplex = new SimplexSolver();
PointValuePair solution = simplex.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(false));
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0], -1e-18d, epsilon) >= 0);
Assert.assertEquals(1.0d, solution.getPoint()[1], epsilon);
Assert.assertEquals(0.0d, solution.getPoint()[2], epsilon);
Assert.assertEquals(1.0d, solution.getValue(), epsilon);
}
@Test
public void testMath272() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 2, 2, 1 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 1, 0 }, Relationship.GEQ, 1));
constraints.add(new LinearConstraint(new double[] { 1, 0, 1 }, Relationship.GEQ, 1));
constraints.add(new LinearConstraint(new double[] { 0, 1, 0 }, Relationship.GEQ, 1));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(0.0, solution.getPoint()[0], .0000001);
Assert.assertEquals(1.0, solution.getPoint()[1], .0000001);
Assert.assertEquals(1.0, solution.getPoint()[2], .0000001);
Assert.assertEquals(3.0, solution.getValue(), .0000001);
}
@Test
public void testMath286() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 0.8, 0.2, 0.7, 0.3, 0.6, 0.4 }, 0 );
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 0, 1, 0, 1, 0 }, Relationship.EQ, 23.0));
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 1, 0, 1 }, Relationship.EQ, 23.0));
constraints.add(new LinearConstraint(new double[] { 1, 0, 0, 0, 0, 0 }, Relationship.GEQ, 10.0));
constraints.add(new LinearConstraint(new double[] { 0, 0, 1, 0, 0, 0 }, Relationship.GEQ, 8.0));
constraints.add(new LinearConstraint(new double[] { 0, 0, 0, 0, 1, 0 }, Relationship.GEQ, 5.0));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(25.8, solution.getValue(), .0000001);
Assert.assertEquals(23.0, solution.getPoint()[0] + solution.getPoint()[2] + solution.getPoint()[4], 0.0000001);
Assert.assertEquals(23.0, solution.getPoint()[1] + solution.getPoint()[3] + solution.getPoint()[5], 0.0000001);
Assert.assertTrue(solution.getPoint()[0] >= 10.0 - 0.0000001);
Assert.assertTrue(solution.getPoint()[2] >= 8.0 - 0.0000001);
Assert.assertTrue(solution.getPoint()[4] >= 5.0 - 0.0000001);
}
@Test
public void testDegeneracy() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 0.8, 0.7 }, 0 );
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.LEQ, 18.0));
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.GEQ, 10.0));
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.GEQ, 8.0));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(13.6, solution.getValue(), .0000001);
}
@Test
public void testMath288() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 7, 3, 0, 0 }, 0 );
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 3, 0, -5, 0 }, Relationship.LEQ, 0.0));
constraints.add(new LinearConstraint(new double[] { 2, 0, 0, -5 }, Relationship.LEQ, 0.0));
constraints.add(new LinearConstraint(new double[] { 0, 3, 0, -5 }, Relationship.LEQ, 0.0));
constraints.add(new LinearConstraint(new double[] { 1, 0, 0, 0 }, Relationship.LEQ, 1.0));
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 0 }, Relationship.LEQ, 1.0));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(10.0, solution.getValue(), .0000001);
}
@Test
public void testMath290GEQ() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 1, 5 }, 0 );
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 2, 0 }, Relationship.GEQ, -1.0));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(0, solution.getValue(), .0000001);
Assert.assertEquals(0, solution.getPoint()[0], .0000001);
Assert.assertEquals(0, solution.getPoint()[1], .0000001);
}
@Test(expected=NoFeasibleSolutionException.class)
public void testMath290LEQ() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 1, 5 }, 0 );
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 2, 0 }, Relationship.LEQ, -1.0));
SimplexSolver solver = new SimplexSolver();
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
}
@Test
public void testMath293() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 0.8, 0.2, 0.7, 0.3, 0.4, 0.6}, 0 );
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 0, 1, 0, 1, 0 }, Relationship.EQ, 30.0));
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 1, 0, 1 }, Relationship.EQ, 30.0));
constraints.add(new LinearConstraint(new double[] { 0.8, 0.2, 0.0, 0.0, 0.0, 0.0 }, Relationship.GEQ, 10.0));
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.7, 0.3, 0.0, 0.0 }, Relationship.GEQ, 10.0));
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.0, 0.0, 0.4, 0.6 }, Relationship.GEQ, 10.0));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution1 = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(15.7143, solution1.getPoint()[0], .0001);
Assert.assertEquals(0.0, solution1.getPoint()[1], .0001);
Assert.assertEquals(14.2857, solution1.getPoint()[2], .0001);
Assert.assertEquals(0.0, solution1.getPoint()[3], .0001);
Assert.assertEquals(0.0, solution1.getPoint()[4], .0001);
Assert.assertEquals(30.0, solution1.getPoint()[5], .0001);
Assert.assertEquals(40.57143, solution1.getValue(), .0001);
double valA = 0.8 * solution1.getPoint()[0] + 0.2 * solution1.getPoint()[1];
double valB = 0.7 * solution1.getPoint()[2] + 0.3 * solution1.getPoint()[3];
double valC = 0.4 * solution1.getPoint()[4] + 0.6 * solution1.getPoint()[5];
f = new LinearObjectiveFunction(new double[] { 0.8, 0.2, 0.7, 0.3, 0.4, 0.6}, 0 );
constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 0, 1, 0, 1, 0 }, Relationship.EQ, 30.0));
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 1, 0, 1 }, Relationship.EQ, 30.0));
constraints.add(new LinearConstraint(new double[] { 0.8, 0.2, 0.0, 0.0, 0.0, 0.0 }, Relationship.GEQ, valA));
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.7, 0.3, 0.0, 0.0 }, Relationship.GEQ, valB));
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.0, 0.0, 0.4, 0.6 }, Relationship.GEQ, valC));
PointValuePair solution2 = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(40.57143, solution2.getValue(), .0001);
}
@Test
public void testSimplexSolver() {
LinearObjectiveFunction f =
new LinearObjectiveFunction(new double[] { 15, 10 }, 7);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.LEQ, 2));
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.LEQ, 3));
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.EQ, 4));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(2.0, solution.getPoint()[0], 0.0);
Assert.assertEquals(2.0, solution.getPoint()[1], 0.0);
Assert.assertEquals(57.0, solution.getValue(), 0.0);
}
@Test
public void testSingleVariableAndConstraint() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 3 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1 }, Relationship.LEQ, 10));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
Assert.assertEquals(10.0, solution.getPoint()[0], 0.0);
Assert.assertEquals(30.0, solution.getValue(), 0.0);
}
/**
 * With no artificial variables needed (no equality and no greater-than
 * constraints), the solver can go straight to Phase 2.
 */
@Test
public void testModelWithNoArtificialVars() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 15, 10 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.LEQ, 2));
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.LEQ, 3));
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.LEQ, 4));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
Assert.assertEquals(2.0, solution.getPoint()[0], 0.0);
Assert.assertEquals(2.0, solution.getPoint()[1], 0.0);
Assert.assertEquals(50.0, solution.getValue(), 0.0);
}
@Test
public void testMinimization() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { -2, 1 }, -5);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 2 }, Relationship.LEQ, 6));
constraints.add(new LinearConstraint(new double[] { 3, 2 }, Relationship.LEQ, 12));
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.GEQ, 0));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(false));
Assert.assertEquals(4.0, solution.getPoint()[0], 0.0);
Assert.assertEquals(0.0, solution.getPoint()[1], 0.0);
Assert.assertEquals(-13.0, solution.getValue(), 0.0);
}
@Test
public void testSolutionWithNegativeDecisionVariable() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { -2, 1 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.GEQ, 6));
constraints.add(new LinearConstraint(new double[] { 1, 2 }, Relationship.LEQ, 14));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
Assert.assertEquals(-2.0, solution.getPoint()[0], 0.0);
Assert.assertEquals(8.0, solution.getPoint()[1], 0.0);
Assert.assertEquals(12.0, solution.getValue(), 0.0);
}
@Test(expected = NoFeasibleSolutionException.class)
public void testInfeasibleSolution() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 15 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1 }, Relationship.LEQ, 1));
constraints.add(new LinearConstraint(new double[] { 1 }, Relationship.GEQ, 3));
SimplexSolver solver = new SimplexSolver();
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
}
@Test(expected = UnboundedSolutionException.class)
public void testUnboundedSolution() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 15, 10 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.EQ, 2));
SimplexSolver solver = new SimplexSolver();
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
}
@Test
public void testRestrictVariablesToNonNegative() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 409, 523, 70, 204, 339 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 43, 56, 345, 56, 5 }, Relationship.LEQ, 4567456));
constraints.add(new LinearConstraint(new double[] { 12, 45, 7, 56, 23 }, Relationship.LEQ, 56454));
constraints.add(new LinearConstraint(new double[] { 8, 768, 0, 34, 7456 }, Relationship.LEQ, 1923421));
constraints.add(new LinearConstraint(new double[] { 12342, 2342, 34, 678, 2342 }, Relationship.GEQ, 4356));
constraints.add(new LinearConstraint(new double[] { 45, 678, 76, 52, 23 }, Relationship.EQ, 456356));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(2902.92783505155, solution.getPoint()[0], .0000001);
Assert.assertEquals(480.419243986254, solution.getPoint()[1], .0000001);
Assert.assertEquals(0.0, solution.getPoint()[2], .0000001);
Assert.assertEquals(0.0, solution.getPoint()[3], .0000001);
Assert.assertEquals(0.0, solution.getPoint()[4], .0000001);
Assert.assertEquals(1438556.7491409, solution.getValue(), .0000001);
}
@Test
public void testEpsilon() {
LinearObjectiveFunction f =
new LinearObjectiveFunction(new double[] { 10, 5, 1 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 9, 8, 0 }, Relationship.EQ, 17));
constraints.add(new LinearConstraint(new double[] { 0, 7, 8 }, Relationship.LEQ, 7));
constraints.add(new LinearConstraint(new double[] { 10, 0, 2 }, Relationship.LEQ, 10));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
Assert.assertEquals(1.0, solution.getPoint()[0], 0.0);
Assert.assertEquals(1.0, solution.getPoint()[1], 0.0);
Assert.assertEquals(0.0, solution.getPoint()[2], 0.0);
Assert.assertEquals(15.0, solution.getValue(), 0.0);
}
@Test
public void testTrivialModel() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 1, 1 }, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.EQ, 0));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(0, solution.getValue(), .0000001);
}
@Test
public void testLargeModel() {
double[] objective = new double[] {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 12, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
12, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 12, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 12, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 12, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 12, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1};
LinearObjectiveFunction f = new LinearObjectiveFunction(objective, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(equationFromString(objective.length, "x0 + x1 + x2 + x3 - x12 = 0"));
constraints.add(equationFromString(objective.length, "x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 - x13 = 0"));
constraints.add(equationFromString(objective.length, "x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 >= 49"));
constraints.add(equationFromString(objective.length, "x0 + x1 + x2 + x3 >= 42"));
constraints.add(equationFromString(objective.length, "x14 + x15 + x16 + x17 - x26 = 0"));
constraints.add(equationFromString(objective.length, "x18 + x19 + x20 + x21 + x22 + x23 + x24 + x25 - x27 = 0"));
constraints.add(equationFromString(objective.length, "x14 + x15 + x16 + x17 - x12 = 0"));
constraints.add(equationFromString(objective.length, "x18 + x19 + x20 + x21 + x22 + x23 + x24 + x25 - x13 = 0"));
constraints.add(equationFromString(objective.length, "x28 + x29 + x30 + x31 - x40 = 0"));
constraints.add(equationFromString(objective.length, "x32 + x33 + x34 + x35 + x36 + x37 + x38 + x39 - x41 = 0"));
constraints.add(equationFromString(objective.length, "x32 + x33 + x34 + x35 + x36 + x37 + x38 + x39 >= 49"));
constraints.add(equationFromString(objective.length, "x28 + x29 + x30 + x31 >= 42"));
constraints.add(equationFromString(objective.length, "x42 + x43 + x44 + x45 - x54 = 0"));
constraints.add(equationFromString(objective.length, "x46 + x47 + x48 + x49 + x50 + x51 + x52 + x53 - x55 = 0"));
constraints.add(equationFromString(objective.length, "x42 + x43 + x44 + x45 - x40 = 0"));
constraints.add(equationFromString(objective.length, "x46 + x47 + x48 + x49 + x50 + x51 + x52 + x53 - x41 = 0"));
constraints.add(equationFromString(objective.length, "x56 + x57 + x58 + x59 - x68 = 0"));
constraints.add(equationFromString(objective.length, "x60 + x61 + x62 + x63 + x64 + x65 + x66 + x67 - x69 = 0"));
constraints.add(equationFromString(objective.length, "x60 + x61 + x62 + x63 + x64 + x65 + x66 + x67 >= 51"));
constraints.add(equationFromString(objective.length, "x56 + x57 + x58 + x59 >= 44"));
constraints.add(equationFromString(objective.length, "x70 + x71 + x72 + x73 - x82 = 0"));
constraints.add(equationFromString(objective.length, "x74 + x75 + x76 + x77 + x78 + x79 + x80 + x81 - x83 = 0"));
constraints.add(equationFromString(objective.length, "x70 + x71 + x72 + x73 - x68 = 0"));
constraints.add(equationFromString(objective.length, "x74 + x75 + x76 + x77 + x78 + x79 + x80 + x81 - x69 = 0"));
constraints.add(equationFromString(objective.length, "x84 + x85 + x86 + x87 - x96 = 0"));
constraints.add(equationFromString(objective.length, "x88 + x89 + x90 + x91 + x92 + x93 + x94 + x95 - x97 = 0"));
constraints.add(equationFromString(objective.length, "x88 + x89 + x90 + x91 + x92 + x93 + x94 + x95 >= 51"));
constraints.add(equationFromString(objective.length, "x84 + x85 + x86 + x87 >= 44"));
constraints.add(equationFromString(objective.length, "x98 + x99 + x100 + x101 - x110 = 0"));
constraints.add(equationFromString(objective.length, "x102 + x103 + x104 + x105 + x106 + x107 + x108 + x109 - x111 = 0"));
constraints.add(equationFromString(objective.length, "x98 + x99 + x100 + x101 - x96 = 0"));
constraints.add(equationFromString(objective.length, "x102 + x103 + x104 + x105 + x106 + x107 + x108 + x109 - x97 = 0"));
constraints.add(equationFromString(objective.length, "x112 + x113 + x114 + x115 - x124 = 0"));
constraints.add(equationFromString(objective.length, "x116 + x117 + x118 + x119 + x120 + x121 + x122 + x123 - x125 = 0"));
constraints.add(equationFromString(objective.length, "x116 + x117 + x118 + x119 + x120 + x121 + x122 + x123 >= 49"));
constraints.add(equationFromString(objective.length, "x112 + x113 + x114 + x115 >= 42"));
constraints.add(equationFromString(objective.length, "x126 + x127 + x128 + x129 - x138 = 0"));
constraints.add(equationFromString(objective.length, "x130 + x131 + x132 + x133 + x134 + x135 + x136 + x137 - x139 = 0"));
constraints.add(equationFromString(objective.length, "x126 + x127 + x128 + x129 - x124 = 0"));
constraints.add(equationFromString(objective.length, "x130 + x131 + x132 + x133 + x134 + x135 + x136 + x137 - x125 = 0"));
constraints.add(equationFromString(objective.length, "x140 + x141 + x142 + x143 - x152 = 0"));
constraints.add(equationFromString(objective.length, "x144 + x145 + x146 + x147 + x148 + x149 + x150 + x151 - x153 = 0"));
constraints.add(equationFromString(objective.length, "x144 + x145 + x146 + x147 + x148 + x149 + x150 + x151 >= 59"));
constraints.add(equationFromString(objective.length, "x140 + x141 + x142 + x143 >= 42"));
constraints.add(equationFromString(objective.length, "x154 + x155 + x156 + x157 - x166 = 0"));
constraints.add(equationFromString(objective.length, "x158 + x159 + x160 + x161 + x162 + x163 + x164 + x165 - x167 = 0"));
constraints.add(equationFromString(objective.length, "x154 + x155 + x156 + x157 - x152 = 0"));
constraints.add(equationFromString(objective.length, "x158 + x159 + x160 + x161 + x162 + x163 + x164 + x165 - x153 = 0"));
constraints.add(equationFromString(objective.length, "x83 + x82 - x168 = 0"));
constraints.add(equationFromString(objective.length, "x111 + x110 - x169 = 0"));
constraints.add(equationFromString(objective.length, "x170 - x182 = 0"));
constraints.add(equationFromString(objective.length, "x171 - x183 = 0"));
constraints.add(equationFromString(objective.length, "x172 - x184 = 0"));
constraints.add(equationFromString(objective.length, "x173 - x185 = 0"));
constraints.add(equationFromString(objective.length, "x174 - x186 = 0"));
constraints.add(equationFromString(objective.length, "x175 + x176 - x187 = 0"));
constraints.add(equationFromString(objective.length, "x177 - x188 = 0"));
constraints.add(equationFromString(objective.length, "x178 - x189 = 0"));
constraints.add(equationFromString(objective.length, "x179 - x190 = 0"));
constraints.add(equationFromString(objective.length, "x180 - x191 = 0"));
constraints.add(equationFromString(objective.length, "x181 - x192 = 0"));
constraints.add(equationFromString(objective.length, "x170 - x26 = 0"));
constraints.add(equationFromString(objective.length, "x171 - x27 = 0"));
constraints.add(equationFromString(objective.length, "x172 - x54 = 0"));
constraints.add(equationFromString(objective.length, "x173 - x55 = 0"));
constraints.add(equationFromString(objective.length, "x174 - x168 = 0"));
constraints.add(equationFromString(objective.length, "x177 - x169 = 0"));
constraints.add(equationFromString(objective.length, "x178 - x138 = 0"));
constraints.add(equationFromString(objective.length, "x179 - x139 = 0"));
constraints.add(equationFromString(objective.length, "x180 - x166 = 0"));
constraints.add(equationFromString(objective.length, "x181 - x167 = 0"));
constraints.add(equationFromString(objective.length, "x193 - x205 = 0"));
constraints.add(equationFromString(objective.length, "x194 - x206 = 0"));
constraints.add(equationFromString(objective.length, "x195 - x207 = 0"));
constraints.add(equationFromString(objective.length, "x196 - x208 = 0"));
constraints.add(equationFromString(objective.length, "x197 - x209 = 0"));
constraints.add(equationFromString(objective.length, "x198 + x199 - x210 = 0"));
constraints.add(equationFromString(objective.length, "x200 - x211 = 0"));
constraints.add(equationFromString(objective.length, "x201 - x212 = 0"));
constraints.add(equationFromString(objective.length, "x202 - x213 = 0"));
constraints.add(equationFromString(objective.length, "x203 - x214 = 0"));
constraints.add(equationFromString(objective.length, "x204 - x215 = 0"));
constraints.add(equationFromString(objective.length, "x193 - x182 = 0"));
constraints.add(equationFromString(objective.length, "x194 - x183 = 0"));
constraints.add(equationFromString(objective.length, "x195 - x184 = 0"));
constraints.add(equationFromString(objective.length, "x196 - x185 = 0"));
constraints.add(equationFromString(objective.length, "x197 - x186 = 0"));
constraints.add(equationFromString(objective.length, "x198 + x199 - x187 = 0"));
constraints.add(equationFromString(objective.length, "x200 - x188 = 0"));
constraints.add(equationFromString(objective.length, "x201 - x189 = 0"));
constraints.add(equationFromString(objective.length, "x202 - x190 = 0"));
constraints.add(equationFromString(objective.length, "x203 - x191 = 0"));
constraints.add(equationFromString(objective.length, "x204 - x192 = 0"));
SimplexSolver solver = new SimplexSolver();
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
GoalType.MINIMIZE, new NonNegativeConstraint(true));
Assert.assertEquals(7518.0, solution.getValue(), .0000001);
}
/**
 * Converts a test equation string into a {@link LinearConstraint},
 * e.g. "x0 + x1 + x2 + x3 - x12 = 0".
 */
private LinearConstraint equationFromString(int numCoefficients, String s) {
Relationship relationship;
if (s.contains(">=")) {
relationship = Relationship.GEQ;
} else if (s.contains("<=")) {
relationship = Relationship.LEQ;
} else if (s.contains("=")) {
relationship = Relationship.EQ;
} else {
throw new IllegalArgumentException();
}
String[] equationParts = s.split("[>|<]?=");
double rhs = Double.parseDouble(equationParts[1].trim());
double[] lhs = new double[numCoefficients];
String left = equationParts[0].replaceAll(" ?x", "");
String[] coefficients = left.split(" ");
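// each term is assumed to have a unit coefficient: only its sign ('+' or '-') is parsed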
for (String coefficient : coefficients) {
double value = coefficient.charAt(0) == '-' ? -1 : 1;
int index = Integer.parseInt(coefficient.replaceFirst("[+|-]", "").trim());
lhs[index] = value;
}
return new LinearConstraint(lhs, relationship, rhs);
}
private static boolean validSolution(PointValuePair solution, List<LinearConstraint> constraints, double epsilon) {
double[] vals = solution.getPoint();
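// evaluate each constraint's left-hand side at the solution point and check the relationship within epsilon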
for (LinearConstraint c : constraints) {
double[] coeffs = c.getCoefficients().toArray();
double result = 0.0d;
for (int i = 0; i < vals.length; i++) {
result += vals[i] * coeffs[i];
}
switch (c.getRelationship()) {
case EQ:
if (!Precision.equals(result, c.getValue(), epsilon)) {
return false;
}
break;
case GEQ:
if (Precision.compareTo(result, c.getValue(), epsilon) < 0) {
return false;
}
break;
case LEQ:
if (Precision.compareTo(result, c.getValue(), epsilon) > 0) {
return false;
}
break;
}
}
return true;
}
}


@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.linear;
import java.util.ArrayList;
import java.util.Collection;
import org.apache.commons.math3.TestUtils;
import org.apache.commons.math3.optim.GoalType;
import org.junit.Assert;
import org.junit.Test;
public class SimplexTableauTest {
@Test
public void testInitialization() {
LinearObjectiveFunction f = createFunction();
Collection<LinearConstraint> constraints = createConstraints();
SimplexTableau tableau =
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
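// expected column layout: W (phase-1 objective), Z, x0, x1, shared negative variable (unrestricted case), 2 slack, 1 artificial, RHS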
double[][] expectedInitialTableau = {
{-1, 0, -1, -1, 2, 0, 0, 0, -4},
{ 0, 1, -15, -10, 25, 0, 0, 0, 0},
{ 0, 0, 1, 0, -1, 1, 0, 0, 2},
{ 0, 0, 0, 1, -1, 0, 1, 0, 3},
{ 0, 0, 1, 1, -2, 0, 0, 1, 4}
};
assertMatrixEquals(expectedInitialTableau, tableau.getData());
}
@Test
public void testDropPhase1Objective() {
LinearObjectiveFunction f = createFunction();
Collection<LinearConstraint> constraints = createConstraints();
SimplexTableau tableau =
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
double[][] expectedTableau = {
{ 1, -15, -10, 0, 0, 0, 0},
{ 0, 1, 0, 1, 0, 0, 2},
{ 0, 0, 1, 0, 1, 0, 3},
{ 0, 1, 1, 0, 0, 1, 4}
};
tableau.dropPhase1Objective();
assertMatrixEquals(expectedTableau, tableau.getData());
}
@Test
public void testTableauWithNoArtificialVars() {
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {15, 10}, 0);
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {1, 0}, Relationship.LEQ, 2));
constraints.add(new LinearConstraint(new double[] {0, 1}, Relationship.LEQ, 3));
constraints.add(new LinearConstraint(new double[] {1, 1}, Relationship.LEQ, 4));
SimplexTableau tableau =
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
double[][] initialTableau = {
{1, -15, -10, 25, 0, 0, 0, 0},
{0, 1, 0, -1, 1, 0, 0, 2},
{0, 0, 1, -1, 0, 1, 0, 3},
{0, 1, 1, -2, 0, 0, 1, 4}
};
assertMatrixEquals(initialTableau, tableau.getData());
}
@Test
public void testSerial() {
LinearObjectiveFunction f = createFunction();
Collection<LinearConstraint> constraints = createConstraints();
SimplexTableau tableau =
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
Assert.assertEquals(tableau, TestUtils.serializeAndRecover(tableau));
}
private LinearObjectiveFunction createFunction() {
return new LinearObjectiveFunction(new double[] {15, 10}, 0);
}
private Collection<LinearConstraint> createConstraints() {
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
constraints.add(new LinearConstraint(new double[] {1, 0}, Relationship.LEQ, 2));
constraints.add(new LinearConstraint(new double[] {0, 1}, Relationship.LEQ, 3));
constraints.add(new LinearConstraint(new double[] {1, 1}, Relationship.EQ, 4));
return constraints;
}
private void assertMatrixEquals(double[][] expected, double[][] result) {
Assert.assertEquals("Wrong number of rows.", expected.length, result.length);
for (int i = 0; i < expected.length; i++) {
Assert.assertEquals("Wrong number of columns.", expected[i].length, result[i].length);
for (int j = 0; j < expected[i].length; j++) {
Assert.assertEquals("Wrong value at position [" + i + "," + j + "]", expected[i][j], result[i][j], 1.0e-15);
}
}
}
}


@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.CircleScalar;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.random.GaussianRandomGenerator;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomVectorGenerator;
import org.apache.commons.math3.random.UncorrelatedRandomVectorGenerator;
import org.junit.Assert;
import org.junit.Test;
public class MultiStartMultivariateOptimizerTest {
@Test
public void testCircleFitting() {
CircleScalar circle = new CircleScalar();
circle.addPoint( 30.0, 68.0);
circle.addPoint( 50.0, -6.0);
circle.addPoint(110.0, -20.0);
circle.addPoint( 35.0, 15.0);
circle.addPoint( 45.0, 97.0);
// TODO: the wrapper around NonLinearConjugateGradientOptimizer is a temporary hack for
// version 3.1 of the library. It should be removed once NonLinearConjugateGradientOptimizer
// is officially declared to implement MultivariateDifferentiableOptimizer.
GradientMultivariateOptimizer underlying
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-10, 1e-10));
JDKRandomGenerator g = new JDKRandomGenerator();
g.setSeed(753289573253L);
RandomVectorGenerator generator
= new UncorrelatedRandomVectorGenerator(new double[] { 50, 50 },
new double[] { 10, 10 },
new GaussianRandomGenerator(g));
MultiStartMultivariateOptimizer optimizer
= new MultiStartMultivariateOptimizer(underlying, 10, generator);
PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
circle.getObjectiveFunction(),
circle.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 98.680, 47.345 }));
Assert.assertEquals(200, optimizer.getMaxEvaluations());
PointValuePair[] optima = optimizer.getOptima();
for (PointValuePair o : optima) {
Vector2D center = new Vector2D(o.getPointRef()[0], o.getPointRef()[1]);
Assert.assertEquals(69.960161753, circle.getRadius(center), 1e-8);
Assert.assertEquals(96.075902096, center.getX(), 1e-8);
Assert.assertEquals(48.135167894, center.getY(), 1e-8);
}
Assert.assertTrue(optimizer.getEvaluations() > 70);
Assert.assertTrue(optimizer.getEvaluations() < 90);
Assert.assertEquals(3.1267527, optimum.getValue(), 1e-8);
}
@Test
public void testRosenbrock() {
Rosenbrock rosenbrock = new Rosenbrock();
SimplexOptimizer underlying
= new SimplexOptimizer(new SimpleValueChecker(-1, 1e-3));
NelderMeadSimplex simplex = new NelderMeadSimplex(new double[][] {
{ -1.2, 1.0 },
{ 0.9, 1.2 } ,
{ 3.5, -2.3 }
});
JDKRandomGenerator g = new JDKRandomGenerator();
g.setSeed(16069223052L);
RandomVectorGenerator generator
= new UncorrelatedRandomVectorGenerator(2, new GaussianRandomGenerator(g));
MultiStartMultivariateOptimizer optimizer
= new MultiStartMultivariateOptimizer(underlying, 10, generator);
PointValuePair optimum
= optimizer.optimize(new MaxEval(1100),
new ObjectiveFunction(rosenbrock),
GoalType.MINIMIZE,
simplex,
new InitialGuess(new double[] { -1.2, 1.0 }));
Assert.assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 900);
Assert.assertTrue(optimizer.getEvaluations() < 1200);
Assert.assertTrue(optimum.getValue() < 8e-4);
}
private static class Rosenbrock implements MultivariateFunction {
private int count;
public Rosenbrock() {
count = 0;
}
public double value(double[] x) {
++count;
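// Rosenbrock function: f(x0, x1) = 100 * (x1 - x0^2)^2 + (1 - x0)^2, minimum 0 at (1, 1)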
double a = x[1] - x[0] * x[0];
double b = 1 - x[0];
return 100 * a * a + b * b;
}
public int getCount() {
return count;
}
}
}


@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.SimplePointChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.junit.Assert;
import org.junit.Test;
public class MultivariateFunctionMappingAdapterTest {
@Test
public void testStartSimplexInsideRange() {
final BiQuadratic biQuadratic = new BiQuadratic(2.0, 2.5, 1.0, 3.0, 2.0, 3.0);
final MultivariateFunctionMappingAdapter wrapped
= new MultivariateFunctionMappingAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper());
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
});
final PointValuePair optimum
= optimizer.optimize(new MaxEval(300),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 2e-7);
}
@Test
public void testOptimumOutsideRange() {
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0, 1.0, 3.0, 2.0, 3.0);
final MultivariateFunctionMappingAdapter wrapped
= new MultivariateFunctionMappingAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper());
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
});
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 2e-7);
}
@Test
public void testUnbounded() {
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0,
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
final MultivariateFunctionMappingAdapter wrapped
= new MultivariateFunctionMappingAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper());
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
});
final PointValuePair optimum
= optimizer.optimize(new MaxEval(300),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 2e-7);
}
@Test
public void testHalfBounded() {
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 4.0,
1.0, Double.POSITIVE_INFINITY,
Double.NEGATIVE_INFINITY, 3.0);
final MultivariateFunctionMappingAdapter wrapped
= new MultivariateFunctionMappingAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper());
SimplexOptimizer optimizer = new SimplexOptimizer(1e-13, 1e-30);
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
});
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 1e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 1e-7);
}
private static class BiQuadratic implements MultivariateFunction {
private final double xOptimum;
private final double yOptimum;
private final double xMin;
private final double xMax;
private final double yMin;
private final double yMax;
public BiQuadratic(final double xOptimum, final double yOptimum,
final double xMin, final double xMax,
final double yMin, final double yMax) {
this.xOptimum = xOptimum;
this.yOptimum = yOptimum;
this.xMin = xMin;
this.xMax = xMax;
this.yMin = yMin;
this.yMax = yMax;
}
public double value(double[] point) {
// the function should never be called with out-of-range points
Assert.assertTrue(point[0] >= xMin);
Assert.assertTrue(point[0] <= xMax);
Assert.assertTrue(point[1] >= yMin);
Assert.assertTrue(point[1] <= yMax);
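// squared distance to the (possibly out-of-range) optimum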
final double dx = point[0] - xOptimum;
final double dy = point[1] - yOptimum;
return dx * dx + dy * dy;
}
public double[] getLower() {
return new double[] { xMin, yMin };
}
public double[] getUpper() {
return new double[] { xMax, yMax };
}
public double getBoundedXOptimum() {
return (xOptimum < xMin) ? xMin : ((xOptimum > xMax) ? xMax : xOptimum);
}
public double getBoundedYOptimum() {
return (yOptimum < yMin) ? yMin : ((yOptimum > yMax) ? yMax : yOptimum);
}
}
}


@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.SimplePointChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.junit.Assert;
import org.junit.Test;
public class MultivariateFunctionPenaltyAdapterTest {
@Test
public void testStartSimplexInsideRange() {
final BiQuadratic biQuadratic = new BiQuadratic(2.0, 2.5, 1.0, 3.0, 2.0, 3.0);
final MultivariateFunctionPenaltyAdapter wrapped
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper(),
1000.0, new double[] { 100.0, 100.0 });
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
final PointValuePair optimum
= optimizer.optimize(new MaxEval(300),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(new double[] { 1.5, 2.25 }));
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
}
@Test
public void testStartSimplexOutsideRange() {
final BiQuadratic biQuadratic = new BiQuadratic(2.0, 2.5, 1.0, 3.0, 2.0, 3.0);
final MultivariateFunctionPenaltyAdapter wrapped
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper(),
1000.0, new double[] { 100.0, 100.0 });
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
final PointValuePair optimum
= optimizer.optimize(new MaxEval(300),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(new double[] { -1.5, 4.0 }));
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
}
@Test
public void testOptimumOutsideRange() {
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0, 1.0, 3.0, 2.0, 3.0);
final MultivariateFunctionPenaltyAdapter wrapped
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper(),
1000.0, new double[] { 100.0, 100.0 });
SimplexOptimizer optimizer = new SimplexOptimizer(new SimplePointChecker<PointValuePair>(1.0e-11, 1.0e-20));
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
final PointValuePair optimum
= optimizer.optimize(new MaxEval(600),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(new double[] { -1.5, 4.0 }));
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
}
@Test
public void testUnbounded() {
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0,
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
final MultivariateFunctionPenaltyAdapter wrapped
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper(),
1000.0, new double[] { 100.0, 100.0 });
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
final PointValuePair optimum
= optimizer.optimize(new MaxEval(300),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(new double[] { -1.5, 4.0 }));
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
}
@Test
public void testHalfBounded() {
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 4.0,
1.0, Double.POSITIVE_INFINITY,
Double.NEGATIVE_INFINITY, 3.0);
final MultivariateFunctionPenaltyAdapter wrapped
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
biQuadratic.getLower(),
biQuadratic.getUpper(),
1000.0, new double[] { 100.0, 100.0 });
SimplexOptimizer optimizer = new SimplexOptimizer(new SimplePointChecker<PointValuePair>(1.0e-10, 1.0e-20));
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
final PointValuePair optimum
= optimizer.optimize(new MaxEval(400),
new ObjectiveFunction(wrapped),
simplex,
GoalType.MINIMIZE,
new InitialGuess(new double[] { -1.5, 4.0 }));
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
}
private static class BiQuadratic implements MultivariateFunction {
private final double xOptimum;
private final double yOptimum;
private final double xMin;
private final double xMax;
private final double yMin;
private final double yMax;
public BiQuadratic(final double xOptimum, final double yOptimum,
final double xMin, final double xMax,
final double yMin, final double yMax) {
this.xOptimum = xOptimum;
this.yOptimum = yOptimum;
this.xMin = xMin;
this.xMax = xMax;
this.yMin = yMin;
this.yMax = yMax;
}
public double value(double[] point) {
// the function should never be called with out-of-range points
Assert.assertTrue(point[0] >= xMin);
Assert.assertTrue(point[0] <= xMax);
Assert.assertTrue(point[1] >= yMin);
Assert.assertTrue(point[1] <= yMax);
final double dx = point[0] - xOptimum;
final double dy = point[1] - yOptimum;
return dx * dx + dy * dy;
}
public double[] getLower() {
return new double[] { xMin, yMin };
}
public double[] getUpper() {
return new double[] { xMax, yMax };
}
public double getBoundedXOptimum() {
return (xOptimum < xMin) ? xMin : ((xOptimum > xMax) ? xMax : xOptimum);
}
public double getBoundedYOptimum() {
return (yOptimum < yMin) ? yMin : ((yOptimum > yMax) ? yMax : yOptimum);
}
}
}


@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
import java.util.ArrayList;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
/**
* Class used in the tests.
*/
public class CircleScalar {
private ArrayList<Vector2D> points;
public CircleScalar() {
points = new ArrayList<Vector2D>();
}
public void addPoint(double px, double py) {
points.add(new Vector2D(px, py));
}
public double getRadius(Vector2D center) {
double r = 0;
for (Vector2D point : points) {
r += point.distance(center);
}
return r / points.size();
}
public ObjectiveFunction getObjectiveFunction() {
return new ObjectiveFunction(new MultivariateFunction() {
public double value(double[] params) {
Vector2D center = new Vector2D(params[0], params[1]);
double radius = getRadius(center);
double sum = 0;
for (Vector2D point : points) {
double di = point.distance(center) - radius;
sum += di * di;
}
return sum;
}
});
}
public ObjectiveFunctionGradient getObjectiveFunctionGradient() {
return new ObjectiveFunctionGradient(new MultivariateVectorFunction() {
public double[] value(double[] params) {
Vector2D center = new Vector2D(params[0], params[1]);
double radius = getRadius(center);
// gradient of the sum of squared residuals
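// J(c) = sum_i (d_i - r)^2 with d_i = |p_i - c| and r the mean of the d_i;
// since sum_i (d_i - r) = 0, the dr/dc term vanishes, leaving
// dJ/dc = 2 * sum_i (d_i - r) * (c - p_i) / d_i, computed component-wise below.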
double dJdX = 0;
double dJdY = 0;
for (Vector2D pk : points) {
double dk = pk.distance(center);
dJdX += (center.getX() - pk.getX()) * (dk - radius) / dk;
dJdY += (center.getY() - pk.getY()) * (dk - radius) / dk;
}
dJdX *= 2;
dJdY *= 2;
return new double[] { dJdX, dJdY };
}
});
}
}

View File

@ -0,0 +1,447 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
import java.io.Serializable;
import org.apache.commons.math3.analysis.DifferentiableMultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableFunction;
import org.apache.commons.math3.analysis.solvers.BrentSolver;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
import org.junit.Assert;
import org.junit.Test;
/**
* <p>Some of the unit tests are re-implementations of the MINPACK <a
* href="http://www.netlib.org/minpack/ex/file17">file17</a> and <a
* href="http://www.netlib.org/minpack/ex/file22">file22</a> test files.
* The redistribution policy for MINPACK is available <a
* href="http://www.netlib.org/minpack/disclaimer">here</a>, for
* convenience, it is reproduced below.</p>
*
* <table border="0" width="80%" cellpadding="10" align="center" bgcolor="#E0E0E0">
* <tr><td>
* Minpack Copyright Notice (1999) University of Chicago.
* All rights reserved
* </td></tr>
* <tr><td>
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* <ol>
* <li>Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.</li>
* <li>Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.</li>
* <li>The end-user documentation included with the redistribution, if any,
* must include the following acknowledgment:
* <code>This product includes software developed by the University of
* Chicago, as Operator of Argonne National Laboratory.</code>
* Alternately, this acknowledgment may appear in the software itself,
* if and wherever such third-party acknowledgments normally appear.</li>
* <li><strong>WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS"
* WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE
* UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND
* THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE
* OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY
* OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR
* USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF
* THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4)
* DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION
* UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL
* BE CORRECTED.</strong></li>
* <li><strong>LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT
* HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF
* ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT,
* INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF
* ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF
* PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER
* SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT
* (INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE,
* EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE
* POSSIBILITY OF SUCH LOSS OR DAMAGES.</strong></li>
 * </ol></td></tr>
* </table>
*
* @author Argonne National Laboratory. MINPACK project. March 1980 (original fortran minpack tests)
* @author Burton S. Garbow (original fortran minpack tests)
* @author Kenneth E. Hillstrom (original fortran minpack tests)
* @author Jorge J. More (original fortran minpack tests)
* @author Luc Maisonobe (non-minpack tests and minpack tests Java translation)
*/
public class NonLinearConjugateGradientOptimizerTest {
@Test
public void testTrivial() {
LinearProblem problem
= new LinearProblem(new double[][] { { 2 } }, new double[] { 3 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0 }));
Assert.assertEquals(1.5, optimum.getPoint()[0], 1.0e-10);
Assert.assertEquals(0.0, optimum.getValue(), 1.0e-10);
}
@Test
public void testColumnsPermutation() {
LinearProblem problem
= new LinearProblem(new double[][] { { 1.0, -1.0 }, { 0.0, 2.0 }, { 1.0, -2.0 } },
new double[] { 4.0, 6.0, 1.0 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0, 0 }));
Assert.assertEquals(7.0, optimum.getPoint()[0], 1.0e-10);
Assert.assertEquals(3.0, optimum.getPoint()[1], 1.0e-10);
Assert.assertEquals(0.0, optimum.getValue(), 1.0e-10);
}
@Test
public void testNoDependency() {
LinearProblem problem = new LinearProblem(new double[][] {
{ 2, 0, 0, 0, 0, 0 },
{ 0, 2, 0, 0, 0, 0 },
{ 0, 0, 2, 0, 0, 0 },
{ 0, 0, 0, 2, 0, 0 },
{ 0, 0, 0, 0, 2, 0 },
{ 0, 0, 0, 0, 0, 2 }
}, new double[] { 0.0, 1.1, 2.2, 3.3, 4.4, 5.5 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0, 0, 0, 0, 0, 0 }));
for (int i = 0; i < problem.target.length; ++i) {
Assert.assertEquals(0.55 * i, optimum.getPoint()[i], 1.0e-10);
}
}
@Test
public void testOneSet() {
LinearProblem problem = new LinearProblem(new double[][] {
{ 1, 0, 0 },
{ -1, 1, 0 },
{ 0, -1, 1 }
}, new double[] { 1, 1, 1});
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0, 0, 0 }));
Assert.assertEquals(1.0, optimum.getPoint()[0], 1.0e-10);
Assert.assertEquals(2.0, optimum.getPoint()[1], 1.0e-10);
Assert.assertEquals(3.0, optimum.getPoint()[2], 1.0e-10);
}
@Test
public void testTwoSets() {
final double epsilon = 1.0e-7;
LinearProblem problem = new LinearProblem(new double[][] {
{ 2, 1, 0, 4, 0, 0 },
{ -4, -2, 3, -7, 0, 0 },
{ 4, 1, -2, 8, 0, 0 },
{ 0, -3, -12, -1, 0, 0 },
{ 0, 0, 0, 0, epsilon, 1 },
{ 0, 0, 0, 0, 1, 1 }
}, new double[] { 2, -9, 2, 2, 1 + epsilon * epsilon, 2});
final Preconditioner preconditioner
= new Preconditioner() {
public double[] precondition(double[] point, double[] r) {
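// Divide each component by the corresponding diagonal entry of the Hessian
// 2 A^T A of f(x) = ||A x - b||^2 (e.g. column 0: 2 * (2^2 + 4^2 + 4^2) = 72),
// i.e. a simple diagonal (Jacobi-style) preconditioner for this quadratic problem.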
double[] d = r.clone();
d[0] /= 72.0;
d[1] /= 30.0;
d[2] /= 314.0;
d[3] /= 260.0;
d[4] /= 2 * (1 + epsilon * epsilon);
d[5] /= 4.0;
return d;
}
};
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-13, 1e-13),
new BrentSolver(),
preconditioner);
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0, 0, 0, 0, 0, 0 }));
Assert.assertEquals( 3.0, optimum.getPoint()[0], 1.0e-10);
Assert.assertEquals( 4.0, optimum.getPoint()[1], 1.0e-10);
Assert.assertEquals(-1.0, optimum.getPoint()[2], 1.0e-10);
Assert.assertEquals(-2.0, optimum.getPoint()[3], 1.0e-10);
Assert.assertEquals( 1.0 + epsilon, optimum.getPoint()[4], 1.0e-10);
Assert.assertEquals( 1.0 - epsilon, optimum.getPoint()[5], 1.0e-10);
}
@Test
public void testNonInversible() {
LinearProblem problem = new LinearProblem(new double[][] {
{ 1, 2, -3 },
{ 2, 1, 3 },
{ -3, 0, -9 }
}, new double[] { 1, 1, 1 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0, 0, 0 }));
Assert.assertTrue(optimum.getValue() > 0.5);
}
@Test
public void testIllConditioned() {
LinearProblem problem1 = new LinearProblem(new double[][] {
{ 10.0, 7.0, 8.0, 7.0 },
{ 7.0, 5.0, 6.0, 5.0 },
{ 8.0, 6.0, 10.0, 9.0 },
{ 7.0, 5.0, 9.0, 10.0 }
}, new double[] { 32, 23, 33, 31 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-13, 1e-13),
new BrentSolver(1e-15, 1e-15));
PointValuePair optimum1
= optimizer.optimize(new MaxEval(200),
problem1.getObjectiveFunction(),
problem1.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0, 1, 2, 3 }));
Assert.assertEquals(1.0, optimum1.getPoint()[0], 1.0e-4);
Assert.assertEquals(1.0, optimum1.getPoint()[1], 1.0e-4);
Assert.assertEquals(1.0, optimum1.getPoint()[2], 1.0e-4);
Assert.assertEquals(1.0, optimum1.getPoint()[3], 1.0e-4);
LinearProblem problem2 = new LinearProblem(new double[][] {
{ 10.00, 7.00, 8.10, 7.20 },
{ 7.08, 5.04, 6.00, 5.00 },
{ 8.00, 5.98, 9.89, 9.00 },
{ 6.99, 4.99, 9.00, 9.98 }
}, new double[] { 32, 23, 33, 31 });
PointValuePair optimum2
= optimizer.optimize(new MaxEval(200),
problem2.getObjectiveFunction(),
problem2.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 0, 1, 2, 3 }));
Assert.assertEquals(-81.0, optimum2.getPoint()[0], 1.0e-1);
Assert.assertEquals(137.0, optimum2.getPoint()[1], 1.0e-1);
Assert.assertEquals(-34.0, optimum2.getPoint()[2], 1.0e-1);
Assert.assertEquals( 22.0, optimum2.getPoint()[3], 1.0e-1);
}
@Test
public void testMoreEstimatedParametersSimple() {
LinearProblem problem = new LinearProblem(new double[][] {
{ 3.0, 2.0, 0.0, 0.0 },
{ 0.0, 1.0, -1.0, 1.0 },
{ 2.0, 0.0, 1.0, 0.0 }
}, new double[] { 7.0, 3.0, 5.0 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 7, 6, 5, 4 }));
Assert.assertEquals(0, optimum.getValue(), 1.0e-10);
}
@Test
public void testMoreEstimatedParametersUnsorted() {
LinearProblem problem = new LinearProblem(new double[][] {
{ 1.0, 1.0, 0.0, 0.0, 0.0, 0.0 },
{ 0.0, 0.0, 1.0, 1.0, 1.0, 0.0 },
{ 0.0, 0.0, 0.0, 0.0, 1.0, -1.0 },
{ 0.0, 0.0, -1.0, 1.0, 0.0, 1.0 },
{ 0.0, 0.0, 0.0, -1.0, 1.0, 0.0 }
}, new double[] { 3.0, 12.0, -1.0, 7.0, 1.0 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 2, 2, 2, 2, 2, 2 }));
Assert.assertEquals(0, optimum.getValue(), 1.0e-10);
}
@Test
public void testRedundantEquations() {
LinearProblem problem = new LinearProblem(new double[][] {
{ 1.0, 1.0 },
{ 1.0, -1.0 },
{ 1.0, 3.0 }
}, new double[] { 3.0, 1.0, 5.0 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 1, 1 }));
Assert.assertEquals(2.0, optimum.getPoint()[0], 1.0e-8);
Assert.assertEquals(1.0, optimum.getPoint()[1], 1.0e-8);
}
@Test
public void testInconsistentEquations() {
LinearProblem problem = new LinearProblem(new double[][] {
{ 1.0, 1.0 },
{ 1.0, -1.0 },
{ 1.0, 3.0 }
}, new double[] { 3.0, 1.0, 4.0 });
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-6, 1e-6));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 1, 1 }));
Assert.assertTrue(optimum.getValue() > 0.1);
}
@Test
public void testCircleFitting() {
CircleScalar problem = new CircleScalar();
problem.addPoint( 30.0, 68.0);
problem.addPoint( 50.0, -6.0);
problem.addPoint(110.0, -20.0);
problem.addPoint( 35.0, 15.0);
problem.addPoint( 45.0, 97.0);
NonLinearConjugateGradientOptimizer optimizer
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-30, 1e-30),
new BrentSolver(1e-15, 1e-13));
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
problem.getObjectiveFunction(),
problem.getObjectiveFunctionGradient(),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 98.680, 47.345 }));
Vector2D center = new Vector2D(optimum.getPointRef()[0], optimum.getPointRef()[1]);
Assert.assertEquals(69.960161753, problem.getRadius(center), 1.0e-8);
Assert.assertEquals(96.075902096, center.getX(), 1.0e-8);
Assert.assertEquals(48.135167894, center.getY(), 1.0e-8);
}
private static class LinearProblem {
final RealMatrix factors;
final double[] target;
public LinearProblem(double[][] factors,
double[] target) {
this.factors = new BlockRealMatrix(factors);
this.target = target;
}
public ObjectiveFunction getObjectiveFunction() {
return new ObjectiveFunction(new MultivariateFunction() {
public double value(double[] point) {
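// f(x) = ||A x - b||^2, where A = factors and b = target (sum of squared residuals).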
double[] y = factors.operate(point);
double sum = 0;
for (int i = 0; i < y.length; ++i) {
double ri = y[i] - target[i];
sum += ri * ri;
}
return sum;
}
});
}
public ObjectiveFunctionGradient getObjectiveFunctionGradient() {
return new ObjectiveFunctionGradient(new MultivariateVectorFunction() {
public double[] value(double[] point) {
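// grad f(x) = 2 A^T (A x - b), computed below as p = 2 * factors^T * (factors * x - target).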
double[] r = factors.operate(point);
for (int i = 0; i < r.length; ++i) {
r[i] -= target[i];
}
double[] p = factors.transpose().operate(r);
for (int i = 0; i < p.length; ++i) {
p[i] *= 2;
}
return p;
}
});
}
}
}

View File

@ -0,0 +1,627 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.SimpleBounds;
import org.junit.Assert;
import org.junit.Test;
/**
* Test for {@link BOBYQAOptimizer}.
*/
public class BOBYQAOptimizerTest {
static final int DIM = 13;
@Test(expected=NumberIsTooLargeException.class)
public void testInitOutOfBounds() {
double[] startPoint = point(DIM, 3);
double[][] boundaries = boundaries(DIM, -1, 2);
doTest(new Rosen(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 2000, null);
}
@Test(expected=DimensionMismatchException.class)
public void testBoundariesDimensionMismatch() {
double[] startPoint = point(DIM, 0.5);
double[][] boundaries = boundaries(DIM + 1, -1, 2);
doTest(new Rosen(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 2000, null);
}
@Test(expected=NumberIsTooSmallException.class)
public void testProblemDimensionTooSmall() {
double[] startPoint = point(1, 0.5);
doTest(new Rosen(), startPoint, null,
GoalType.MINIMIZE,
1e-13, 1e-6, 2000, null);
}
@Test(expected=TooManyEvaluationsException.class)
public void testMaxEvaluations() {
final int lowMaxEval = 2;
double[] startPoint = point(DIM, 0.1);
double[][] boundaries = null;
doTest(new Rosen(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, lowMaxEval, null);
}
@Test
public void testRosen() {
double[] startPoint = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected = new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 2000, expected);
}
@Test
public void testMaximize() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected = new PointValuePair(point(DIM,0.0),1.0);
doTest(new MinusElli(), startPoint, boundaries,
GoalType.MAXIMIZE,
2e-10, 5e-6, 1000, expected);
boundaries = boundaries(DIM,-0.3,0.3);
startPoint = point(DIM,0.1);
doTest(new MinusElli(), startPoint, boundaries,
GoalType.MAXIMIZE,
2e-10, 5e-6, 1000, expected);
}
@Test
public void testEllipse() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Elli(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 1000, expected);
}
@Test
public void testElliRotated() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new ElliRotated(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-12, 1e-6, 10000, expected);
}
@Test
public void testCigar() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Cigar(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 100, expected);
}
@Test
public void testTwoAxes() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new TwoAxes(), startPoint, boundaries,
GoalType.MINIMIZE, 2 * 1e-13, 1e-6, 100, expected);
}
@Test
public void testCigTab() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new CigTab(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 5e-5, 100, expected);
}
@Test
public void testSphere() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Sphere(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 100, expected);
}
@Test
public void testTablet() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Tablet(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 100, expected);
}
@Test
public void testDiffPow() {
double[] startPoint = point(DIM/2,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM/2,0.0),0.0);
doTest(new DiffPow(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-8, 1e-1, 12000, expected);
}
@Test
public void testSsDiffPow() {
double[] startPoint = point(DIM/2,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM/2,0.0),0.0);
doTest(new SsDiffPow(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-2, 1.3e-1, 50000, expected);
}
@Test
public void testAckley() {
double[] startPoint = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Ackley(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-8, 1e-5, 1000, expected);
}
@Test
public void testRastrigin() {
double[] startPoint = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Rastrigin(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 1000, expected);
}
@Test
public void testConstrainedRosen() {
double[] startPoint = point(DIM,0.1);
double[][] boundaries = boundaries(DIM,-1,2);
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-13, 1e-6, 2000, expected);
}
// See MATH-728
@Test
public void testConstrainedRosenWithMoreInterpolationPoints() {
final double[] startPoint = point(DIM, 0.1);
final double[][] boundaries = boundaries(DIM, -1, 2);
final PointValuePair expected = new PointValuePair(point(DIM, 1.0), 0.0);
// This should have been 78 because the hard limit stated in the code is
// ((DIM + 1) * (DIM + 2)) / 2 - (2 * DIM + 1)
// i.e. 78 in this case, but the test fails for 48, 59, 62, 63, 64,
// 65, 66, ...
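// For DIM = 13: ((13 + 1) * (13 + 2)) / 2 - (2 * 13 + 1) = 105 - 27 = 78.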
final int maxAdditionalPoints = 47;
for (int num = 1; num <= maxAdditionalPoints; num++) {
doTest(new Rosen(), startPoint, boundaries,
GoalType.MINIMIZE,
1e-12, 1e-6, 2000,
num,
expected,
"num=" + num);
}
}
/**
* @param func Function to optimize.
* @param startPoint Starting point.
* @param boundaries Upper / lower point limit.
* @param goal Minimization or maximization.
* @param fTol Relative error tolerance on the objective function.
* @param pointTol Tolerance for checking that the optimum is correct.
* @param maxEvaluations Maximum number of evaluations.
* @param expected Expected point / value.
*/
private void doTest(MultivariateFunction func,
double[] startPoint,
double[][] boundaries,
GoalType goal,
double fTol,
double pointTol,
int maxEvaluations,
PointValuePair expected) {
doTest(func,
startPoint,
boundaries,
goal,
fTol,
pointTol,
maxEvaluations,
0,
expected,
"");
}
/**
* @param func Function to optimize.
* @param startPoint Starting point.
* @param boundaries Upper / lower point limit.
* @param goal Minimization or maximization.
* @param fTol Relative error tolerance on the objective function.
* @param pointTol Tolerance for checking that the optimum is correct.
* @param maxEvaluations Maximum number of evaluations.
* @param additionalInterpolationPoints Number of interpolation points used
* in addition to the default (2 * dim + 1).
* @param expected Expected point / value.
*/
private void doTest(MultivariateFunction func,
double[] startPoint,
double[][] boundaries,
GoalType goal,
double fTol,
double pointTol,
int maxEvaluations,
int additionalInterpolationPoints,
PointValuePair expected,
String assertMsg) {
// System.out.println(func.getClass().getName() + " BEGIN"); // XXX
int dim = startPoint.length;
final int numInterpolationPoints = 2 * dim + 1 + additionalInterpolationPoints;
BOBYQAOptimizer optim = new BOBYQAOptimizer(numInterpolationPoints);
PointValuePair result = boundaries == null ?
optim.optimize(new MaxEval(maxEvaluations),
new ObjectiveFunction(func),
goal,
SimpleBounds.unbounded(dim),
new InitialGuess(startPoint)) :
optim.optimize(new MaxEval(maxEvaluations),
new ObjectiveFunction(func),
goal,
new InitialGuess(startPoint),
new SimpleBounds(boundaries[0],
boundaries[1]));
// System.out.println(func.getClass().getName() + " = "
// + optim.getEvaluations() + " f(");
// for (double x: result.getPoint()) System.out.print(x + " ");
// System.out.println(") = " + result.getValue());
Assert.assertEquals(assertMsg, expected.getValue(), result.getValue(), fTol);
for (int i = 0; i < dim; i++) {
Assert.assertEquals(expected.getPoint()[i],
result.getPoint()[i], pointTol);
}
// System.out.println(func.getClass().getName() + " END"); // XXX
}
private static double[] point(int n, double value) {
double[] ds = new double[n];
Arrays.fill(ds, value);
return ds;
}
private static double[][] boundaries(int dim,
double lower, double upper) {
double[][] boundaries = new double[2][dim];
for (int i = 0; i < dim; i++)
boundaries[0][i] = lower;
for (int i = 0; i < dim; i++)
boundaries[1][i] = upper;
return boundaries;
}
private static class Sphere implements MultivariateFunction {
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += x[i] * x[i];
return f;
}
}
private static class Cigar implements MultivariateFunction {
private double factor;
Cigar() {
this(1e3);
}
Cigar(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = x[0] * x[0];
for (int i = 1; i < x.length; ++i)
f += factor * x[i] * x[i];
return f;
}
}
private static class Tablet implements MultivariateFunction {
private double factor;
Tablet() {
this(1e3);
}
Tablet(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = factor * x[0] * x[0];
for (int i = 1; i < x.length; ++i)
f += x[i] * x[i];
return f;
}
}
private static class CigTab implements MultivariateFunction {
private double factor;
CigTab() {
this(1e4);
}
CigTab(double axisratio) {
factor = axisratio;
}
public double value(double[] x) {
int end = x.length - 1;
double f = x[0] * x[0] / factor + factor * x[end] * x[end];
for (int i = 1; i < end; ++i)
f += x[i] * x[i];
return f;
}
}
private static class TwoAxes implements MultivariateFunction {
private double factor;
TwoAxes() {
this(1e6);
}
TwoAxes(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += (i < x.length / 2 ? factor : 1) * x[i] * x[i];
return f;
}
}
private static class ElliRotated implements MultivariateFunction {
private Basis B = new Basis();
private double factor;
ElliRotated() {
this(1e3);
}
ElliRotated(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = 0;
x = B.Rotate(x);
for (int i = 0; i < x.length; ++i)
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
return f;
}
}
private static class Elli implements MultivariateFunction {
private double factor;
Elli() {
this(1e3);
}
Elli(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
return f;
}
}
private static class MinusElli implements MultivariateFunction {
private final Elli elli = new Elli();
public double value(double[] x) {
return 1.0 - elli.value(x);
}
}
private static class DiffPow implements MultivariateFunction {
// private int fcount = 0;
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += Math.pow(Math.abs(x[i]), 2. + 10 * (double) i
/ (x.length - 1.));
// System.out.print("" + (fcount++) + ") ");
// for (int i = 0; i < x.length; i++)
// System.out.print(x[i] + " ");
// System.out.println(" = " + f);
return f;
}
}
private static class SsDiffPow implements MultivariateFunction {
public double value(double[] x) {
double f = Math.pow(new DiffPow().value(x), 0.25);
return f;
}
}
private static class Rosen implements MultivariateFunction {
public double value(double[] x) {
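// Rosenbrock function: sum_i [ 100 (x_i^2 - x_{i+1})^2 + (x_i - 1)^2 ], minimum 0 at (1, ..., 1).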
double f = 0;
for (int i = 0; i < x.length - 1; ++i)
f += 1e2 * (x[i] * x[i] - x[i + 1]) * (x[i] * x[i] - x[i + 1])
+ (x[i] - 1.) * (x[i] - 1.);
return f;
}
}
private static class Ackley implements MultivariateFunction {
private double axisratio;
Ackley(double axra) {
axisratio = axra;
}
public Ackley() {
this(1);
}
public double value(double[] x) {
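// Ackley function with optional axis scaling fac_i = axisratio^((i-1)/(n-1)):
// f(x) = 20 - 20 exp(-0.2 sqrt(sum fac_i^2 x_i^2 / n)) + e - exp(sum cos(2 pi fac_i x_i) / n),
// with value 0 at the origin.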
double f = 0;
double res2 = 0;
double fac = 0;
for (int i = 0; i < x.length; ++i) {
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
f += fac * fac * x[i] * x[i];
res2 += Math.cos(2. * Math.PI * fac * x[i]);
}
f = (20. - 20. * Math.exp(-0.2 * Math.sqrt(f / x.length))
+ Math.exp(1.) - Math.exp(res2 / x.length));
return f;
}
}
private static class Rastrigin implements MultivariateFunction {
private double axisratio;
private double amplitude;
Rastrigin() {
this(1, 10);
}
Rastrigin(double axisratio, double amplitude) {
this.axisratio = axisratio;
this.amplitude = amplitude;
}
public double value(double[] x) {
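// Rastrigin function with axis scaling:
// f(x) = sum_i [ fac_i^2 x_i^2 + amplitude (1 - cos(2 pi fac_i x_i)) ], minimum 0 at the origin.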
double f = 0;
double fac;
for (int i = 0; i < x.length; ++i) {
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
if (i == 0 && x[i] < 0)
fac *= 1.;
f += fac * fac * x[i] * x[i] + amplitude
* (1. - Math.cos(2. * Math.PI * fac * x[i]));
}
return f;
}
}
private static class Basis {
double[][] basis;
Random rand = new Random(2); // do not always use the same basis
double[] Rotate(double[] x) {
GenBasis(x.length);
double[] y = new double[x.length];
for (int i = 0; i < x.length; ++i) {
y[i] = 0;
for (int j = 0; j < x.length; ++j)
y[i] += basis[i][j] * x[j];
}
return y;
}
void GenBasis(int DIM) {
if (basis != null && basis.length == DIM)
return;
double sp;
int i, j, k;
/* generate orthogonal basis */
basis = new double[DIM][DIM];
for (i = 0; i < DIM; ++i) {
/* sample components gaussian */
for (j = 0; j < DIM; ++j)
basis[i][j] = rand.nextGaussian();
/* subtract projection of previous vectors */
for (j = i - 1; j >= 0; --j) {
for (sp = 0., k = 0; k < DIM; ++k)
sp += basis[i][k] * basis[j][k]; /* scalar product */
for (k = 0; k < DIM; ++k)
basis[i][k] -= sp * basis[j][k]; /* subtract */
}
/* normalize */
for (sp = 0., k = 0; k < DIM; ++k)
sp += basis[i][k] * basis[i][k]; /* squared norm */
for (k = 0; k < DIM; ++k)
basis[i][k] /= Math.sqrt(sp);
}
}
}
}

View File

@ -0,0 +1,794 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.math3.Retry;
import org.apache.commons.math3.RetryRunner;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
import org.junit.Ignore;
import org.junit.runner.RunWith;
/**
* Test for {@link CMAESOptimizer}.
*/
@RunWith(RetryRunner.class)
public class CMAESOptimizerTest {
static final int DIM = 13;
static final int LAMBDA = 4 + (int)(3.*Math.log(DIM));
@Test(expected = NumberIsTooLargeException.class)
public void testInitOutofbounds1() {
double[] startPoint = point(DIM,3);
double[] insigma = point(DIM, 0.3);
double[][] boundaries = boundaries(DIM,-1,2);
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test(expected = NumberIsTooSmallException.class)
public void testInitOutofbounds2() {
double[] startPoint = point(DIM, -2);
double[] insigma = point(DIM, 0.3);
double[][] boundaries = boundaries(DIM,-1,2);
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test(expected = DimensionMismatchException.class)
public void testBoundariesDimensionMismatch() {
double[] startPoint = point(DIM,0.5);
double[] insigma = point(DIM, 0.3);
double[][] boundaries = boundaries(DIM+1,-1,2);
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test(expected = NotPositiveException.class)
public void testInputSigmaNegative() {
double[] startPoint = point(DIM,0.5);
double[] insigma = point(DIM,-0.5);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test(expected = OutOfRangeException.class)
public void testInputSigmaOutOfRange() {
double[] startPoint = point(DIM,0.5);
double[] insigma = point(DIM, 1.1);
double[][] boundaries = boundaries(DIM,-0.5,0.5);
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test(expected = DimensionMismatchException.class)
public void testInputSigmaDimensionMismatch() {
double[] startPoint = point(DIM,0.5);
double[] insigma = point(DIM + 1, 0.5);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
@Retry(3)
public void testRosen() {
double[] startPoint = point(DIM,0.1);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
@Retry(3)
public void testMaximize() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),1.0);
doTest(new MinusElli(), startPoint, insigma, boundaries,
GoalType.MAXIMIZE, LAMBDA, true, 0, 1.0-1e-13,
2e-10, 5e-6, 100000, expected);
doTest(new MinusElli(), startPoint, insigma, boundaries,
GoalType.MAXIMIZE, LAMBDA, false, 0, 1.0-1e-13,
2e-10, 5e-6, 100000, expected);
boundaries = boundaries(DIM,-0.3,0.3);
startPoint = point(DIM,0.1);
doTest(new MinusElli(), startPoint, insigma, boundaries,
GoalType.MAXIMIZE, LAMBDA, true, 0, 1.0-1e-13,
2e-10, 5e-6, 100000, expected);
}
@Test
public void testEllipse() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Elli(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
doTest(new Elli(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
public void testElliRotated() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new ElliRotated(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
doTest(new ElliRotated(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
public void testCigar() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Cigar(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 200000, expected);
doTest(new Cigar(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
public void testCigarWithBoundaries() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = boundaries(DIM, -1e100, Double.POSITIVE_INFINITY);
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Cigar(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 200000, expected);
doTest(new Cigar(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
public void testTwoAxes() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new TwoAxes(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 2*LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 200000, expected);
doTest(new TwoAxes(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 2*LAMBDA, false, 0, 1e-13,
1e-8, 1e-3, 200000, expected);
}
@Test
public void testCigTab() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.3);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new CigTab(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 5e-5, 100000, expected);
doTest(new CigTab(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 5e-5, 100000, expected);
}
@Test
public void testSphere() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Sphere(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
doTest(new Sphere(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
public void testTablet() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Tablet(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
doTest(new Tablet(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
public void testDiffPow() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new DiffPow(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 10, true, 0, 1e-13,
1e-8, 1e-1, 100000, expected);
doTest(new DiffPow(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 10, false, 0, 1e-13,
1e-8, 2e-1, 100000, expected);
}
@Test
public void testSsDiffPow() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new SsDiffPow(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 10, true, 0, 1e-13,
1e-4, 1e-1, 200000, expected);
doTest(new SsDiffPow(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 10, false, 0, 1e-13,
1e-4, 1e-1, 200000, expected);
}
@Test
public void testAckley() {
double[] startPoint = point(DIM,1.0);
double[] insigma = point(DIM,1.0);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Ackley(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 2*LAMBDA, true, 0, 1e-13,
1e-9, 1e-5, 100000, expected);
doTest(new Ackley(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 2*LAMBDA, false, 0, 1e-13,
1e-9, 1e-5, 100000, expected);
}
@Test
public void testRastrigin() {
double[] startPoint = point(DIM,0.1);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,0.0),0.0);
doTest(new Rastrigin(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, (int)(200*Math.sqrt(DIM)), true, 0, 1e-13,
1e-13, 1e-6, 200000, expected);
doTest(new Rastrigin(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, (int)(200*Math.sqrt(DIM)), false, 0, 1e-13,
1e-13, 1e-6, 200000, expected);
}
@Test
public void testConstrainedRosen() {
double[] startPoint = point(DIM, 0.1);
double[] insigma = point(DIM, 0.1);
double[][] boundaries = boundaries(DIM, -1, 2);
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 2*LAMBDA, true, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, 2*LAMBDA, false, 0, 1e-13,
1e-13, 1e-6, 100000, expected);
}
@Test
public void testDiagonalRosen() {
double[] startPoint = point(DIM,0.1);
double[] insigma = point(DIM,0.1);
double[][] boundaries = null;
PointValuePair expected =
new PointValuePair(point(DIM,1.0),0.0);
doTest(new Rosen(), startPoint, insigma, boundaries,
GoalType.MINIMIZE, LAMBDA, false, 1, 1e-13,
1e-10, 1e-4, 1000000, expected);
}
@Test
public void testMath864() {
final CMAESOptimizer optimizer
= new CMAESOptimizer(30000, 0, true, 10,
0, new MersenneTwister(), false, null);
final MultivariateFunction fitnessFunction = new MultivariateFunction() {
public double value(double[] parameters) {
final double target = 1;
final double error = target - parameters[0];
return error * error;
}
};
final double[] start = { 0 };
final double[] lower = { -1e6 };
final double[] upper = { 1.5 };
final double[] sigma = { 1e-1 };
final double[] result = optimizer.optimize(new MaxEval(10000),
new ObjectiveFunction(fitnessFunction),
GoalType.MINIMIZE,
new CMAESOptimizer.PopulationSize(5),
new CMAESOptimizer.Sigma(sigma),
new InitialGuess(start),
new SimpleBounds(lower, upper)).getPoint();
Assert.assertTrue("Out of bounds (" + result[0] + " > " + upper[0] + ")",
result[0] <= upper[0]);
}
/**
* Cf. MATH-867
*/
@Test
public void testFitAccuracyDependsOnBoundary() {
final CMAESOptimizer optimizer
= new CMAESOptimizer(30000, 0, true, 10,
0, new MersenneTwister(), false, null);
final MultivariateFunction fitnessFunction = new MultivariateFunction() {
public double value(double[] parameters) {
final double target = 11.1;
final double error = target - parameters[0];
return error * error;
}
};
final double[] start = { 1 };
// No bounds.
PointValuePair result = optimizer.optimize(new MaxEval(100000),
new ObjectiveFunction(fitnessFunction),
GoalType.MINIMIZE,
SimpleBounds.unbounded(1),
new CMAESOptimizer.PopulationSize(5),
new CMAESOptimizer.Sigma(new double[] { 1e-1 }),
new InitialGuess(start));
final double resNoBound = result.getPoint()[0];
// Optimum is near the lower bound.
final double[] lower = { -20 };
final double[] upper = { 5e16 };
final double[] sigma = { 10 };
result = optimizer.optimize(new MaxEval(100000),
new ObjectiveFunction(fitnessFunction),
GoalType.MINIMIZE,
new CMAESOptimizer.PopulationSize(5),
new CMAESOptimizer.Sigma(sigma),
new InitialGuess(start),
new SimpleBounds(lower, upper));
final double resNearLo = result.getPoint()[0];
// Optimum is near the upper bound.
lower[0] = -5e16;
upper[0] = 20;
result = optimizer.optimize(new MaxEval(100000),
new ObjectiveFunction(fitnessFunction),
GoalType.MINIMIZE,
new CMAESOptimizer.PopulationSize(5),
new CMAESOptimizer.Sigma(sigma),
new InitialGuess(start),
new SimpleBounds(lower, upper));
final double resNearHi = result.getPoint()[0];
// System.out.println("resNoBound=" + resNoBound +
// " resNearLo=" + resNearLo +
// " resNearHi=" + resNearHi);
// The two values currently differ by a substantial amount, indicating that
// the bounds definition can prevent reaching the optimum.
Assert.assertEquals(resNoBound, resNearLo, 1e-3);
Assert.assertEquals(resNoBound, resNearHi, 1e-3);
}
/**
* @param func Function to optimize.
* @param startPoint Starting point.
* @param inSigma Individual input sigma.
* @param boundaries Upper / lower point limit.
* @param goal Minimization or maximization.
* @param lambda Population size used for offspring.
* @param isActive Covariance update mechanism.
* @param diagonalOnly Simplified covariance update.
* @param stopValue Termination criteria for optimization.
* @param fTol Relative error tolerance on the objective function.
* @param pointTol Tolerance for checking that the optimum is correct.
* @param maxEvaluations Maximum number of evaluations.
* @param expected Expected point / value.
*/
private void doTest(MultivariateFunction func,
double[] startPoint,
double[] inSigma,
double[][] boundaries,
GoalType goal,
int lambda,
boolean isActive,
int diagonalOnly,
double stopValue,
double fTol,
double pointTol,
int maxEvaluations,
PointValuePair expected) {
int dim = startPoint.length;
// test diagonalOnly = 0 - slow but normally needs fewer function evaluations
CMAESOptimizer optim = new CMAESOptimizer(30000, stopValue, isActive, diagonalOnly,
0, new MersenneTwister(), false, null);
PointValuePair result = boundaries == null ?
optim.optimize(new MaxEval(maxEvaluations),
new ObjectiveFunction(func),
goal,
new InitialGuess(startPoint),
SimpleBounds.unbounded(dim),
new CMAESOptimizer.Sigma(inSigma),
new CMAESOptimizer.PopulationSize(lambda)) :
optim.optimize(new MaxEval(maxEvaluations),
new ObjectiveFunction(func),
goal,
new SimpleBounds(boundaries[0],
boundaries[1]),
new InitialGuess(startPoint),
new CMAESOptimizer.Sigma(inSigma),
new CMAESOptimizer.PopulationSize(lambda));
// System.out.println("sol=" + Arrays.toString(result.getPoint()));
Assert.assertEquals(expected.getValue(), result.getValue(), fTol);
for (int i = 0; i < dim; i++) {
Assert.assertEquals(expected.getPoint()[i], result.getPoint()[i], pointTol);
}
}
private static double[] point(int n, double value) {
double[] ds = new double[n];
Arrays.fill(ds, value);
return ds;
}
private static double[][] boundaries(int dim,
double lower, double upper) {
double[][] boundaries = new double[2][dim];
for (int i = 0; i < dim; i++)
boundaries[0][i] = lower;
for (int i = 0; i < dim; i++)
boundaries[1][i] = upper;
return boundaries;
}
private static class Sphere implements MultivariateFunction {
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += x[i] * x[i];
return f;
}
}
private static class Cigar implements MultivariateFunction {
private double factor;
Cigar() {
this(1e3);
}
Cigar(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = x[0] * x[0];
for (int i = 1; i < x.length; ++i)
f += factor * x[i] * x[i];
return f;
}
}
private static class Tablet implements MultivariateFunction {
private double factor;
Tablet() {
this(1e3);
}
Tablet(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = factor * x[0] * x[0];
for (int i = 1; i < x.length; ++i)
f += x[i] * x[i];
return f;
}
}
private static class CigTab implements MultivariateFunction {
private double factor;
CigTab() {
this(1e4);
}
CigTab(double axisratio) {
factor = axisratio;
}
public double value(double[] x) {
int end = x.length - 1;
double f = x[0] * x[0] / factor + factor * x[end] * x[end];
for (int i = 1; i < end; ++i)
f += x[i] * x[i];
return f;
}
}
private static class TwoAxes implements MultivariateFunction {
private double factor;
TwoAxes() {
this(1e6);
}
TwoAxes(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += (i < x.length / 2 ? factor : 1) * x[i] * x[i];
return f;
}
}
private static class ElliRotated implements MultivariateFunction {
private Basis B = new Basis();
private double factor;
ElliRotated() {
this(1e3);
}
ElliRotated(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = 0;
x = B.Rotate(x);
for (int i = 0; i < x.length; ++i)
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
return f;
}
}
private static class Elli implements MultivariateFunction {
private double factor;
Elli() {
this(1e3);
}
Elli(double axisratio) {
factor = axisratio * axisratio;
}
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
return f;
}
}
private static class MinusElli implements MultivariateFunction {
public double value(double[] x) {
return 1.0 - new Elli().value(x);
}
}
private static class DiffPow implements MultivariateFunction {
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length; ++i)
f += Math.pow(Math.abs(x[i]), 2. + 10 * (double) i
/ (x.length - 1.));
return f;
}
}
private static class SsDiffPow implements MultivariateFunction {
public double value(double[] x) {
double f = Math.pow(new DiffPow().value(x), 0.25);
return f;
}
}
private static class Rosen implements MultivariateFunction {
public double value(double[] x) {
double f = 0;
for (int i = 0; i < x.length - 1; ++i)
f += 1e2 * (x[i] * x[i] - x[i + 1]) * (x[i] * x[i] - x[i + 1])
+ (x[i] - 1.) * (x[i] - 1.);
return f;
}
}
private static class Ackley implements MultivariateFunction {
private double axisratio;
Ackley(double axra) {
axisratio = axra;
}
public Ackley() {
this(1);
}
public double value(double[] x) {
double f = 0;
double res2 = 0;
double fac = 0;
for (int i = 0; i < x.length; ++i) {
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
f += fac * fac * x[i] * x[i];
res2 += Math.cos(2. * Math.PI * fac * x[i]);
}
f = (20. - 20. * Math.exp(-0.2 * Math.sqrt(f / x.length))
+ Math.exp(1.) - Math.exp(res2 / x.length));
return f;
}
}
private static class Rastrigin implements MultivariateFunction {
private double axisratio;
private double amplitude;
Rastrigin() {
this(1, 10);
}
Rastrigin(double axisratio, double amplitude) {
this.axisratio = axisratio;
this.amplitude = amplitude;
}
public double value(double[] x) {
double f = 0;
double fac;
for (int i = 0; i < x.length; ++i) {
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
if (i == 0 && x[i] < 0)
fac *= 1.;
f += fac * fac * x[i] * x[i] + amplitude
* (1. - Math.cos(2. * Math.PI * fac * x[i]));
}
return f;
}
}
private static class Basis {
double[][] basis;
Random rand = new Random(2); // do not always use the same basis
double[] Rotate(double[] x) {
GenBasis(x.length);
double[] y = new double[x.length];
for (int i = 0; i < x.length; ++i) {
y[i] = 0;
for (int j = 0; j < x.length; ++j)
y[i] += basis[i][j] * x[j];
}
return y;
}
void GenBasis(int DIM) {
if (basis != null && basis.length == DIM)
return;
double sp;
int i, j, k;
/* generate orthogonal basis */
basis = new double[DIM][DIM];
for (i = 0; i < DIM; ++i) {
/* sample components gaussian */
for (j = 0; j < DIM; ++j)
basis[i][j] = rand.nextGaussian();
/* subtract projection of previous vectors */
for (j = i - 1; j >= 0; --j) {
for (sp = 0., k = 0; k < DIM; ++k)
sp += basis[i][k] * basis[j][k]; /* scalar product */
for (k = 0; k < DIM; ++k)
basis[i][k] -= sp * basis[j][k]; /* subtract */
}
/* normalize */
for (sp = 0., k = 0; k < DIM; ++k)
sp += basis[i][k] * basis[i][k]; /* squared norm */
for (k = 0; k < DIM; ++k)
basis[i][k] /= Math.sqrt(sp);
}
}
}
}

View File

@ -0,0 +1,251 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.SumSincFunction;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
/**
* Test for {@link PowellOptimizer}.
*/
public class PowellOptimizerTest {
@Test
public void testSumSinc() {
final MultivariateFunction func = new SumSincFunction(-1);
int dim = 2;
final double[] minPoint = new double[dim];
for (int i = 0; i < dim; i++) {
minPoint[i] = 0;
}
double[] init = new double[dim];
// Initial is minimum.
for (int i = 0; i < dim; i++) {
init[i] = minPoint[i];
}
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-9);
// Initial is far from minimum.
for (int i = 0; i < dim; i++) {
init[i] = minPoint[i] + 3;
}
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-5);
// More stringent line search tolerance enhances the precision
// of the result.
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-9, 1e-7);
}
@Test
public void testQuadratic() {
final MultivariateFunction func = new MultivariateFunction() {
public double value(double[] x) {
final double a = x[0] - 1;
final double b = x[1] - 1;
return a * a + b * b + 1;
}
};
int dim = 2;
final double[] minPoint = new double[dim];
for (int i = 0; i < dim; i++) {
minPoint[i] = 1;
}
double[] init = new double[dim];
// Initial is minimum.
for (int i = 0; i < dim; i++) {
init[i] = minPoint[i];
}
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-8);
// Initial is far from minimum.
for (int i = 0; i < dim; i++) {
init[i] = minPoint[i] - 20;
}
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-8);
}
@Test
public void testMaximizeQuadratic() {
final MultivariateFunction func = new MultivariateFunction() {
public double value(double[] x) {
final double a = x[0] - 1;
final double b = x[1] - 1;
return -a * a - b * b + 1;
}
};
int dim = 2;
final double[] maxPoint = new double[dim];
for (int i = 0; i < dim; i++) {
maxPoint[i] = 1;
}
double[] init = new double[dim];
// Initial is maximum.
for (int i = 0; i < dim; i++) {
init[i] = maxPoint[i];
}
doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-9, 1e-8);
// Initial is far from maximum.
for (int i = 0; i < dim; i++) {
init[i] = maxPoint[i] - 20;
}
doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-9, 1e-8);
}
/**
* Ensure that we do not increase the number of function evaluations when
* the function values are scaled up.
* Note that the tolerance parameters passed to the constructor must
* still hold sensible values because they are used to set the line search
* tolerances.
*/
@Test
public void testRelativeToleranceOnScaledValues() {
final MultivariateFunction func = new MultivariateFunction() {
public double value(double[] x) {
final double a = x[0] - 1;
final double b = x[1] - 1;
return a * a * FastMath.sqrt(FastMath.abs(a)) + b * b + 1;
}
};
int dim = 2;
final double[] minPoint = new double[dim];
for (int i = 0; i < dim; i++) {
minPoint[i] = 1;
}
double[] init = new double[dim];
// Initial is far from minimum.
for (int i = 0; i < dim; i++) {
init[i] = minPoint[i] - 20;
}
final double relTol = 1e-10;
final int maxEval = 1000;
// Very small absolute tolerance to rely solely on the relative
// tolerance as a stopping criterion
final PowellOptimizer optim = new PowellOptimizer(relTol, 1e-100);
final PointValuePair funcResult = optim.optimize(new MaxEval(maxEval),
new ObjectiveFunction(func),
GoalType.MINIMIZE,
new InitialGuess(init));
final double funcValue = func.value(funcResult.getPoint());
final int funcEvaluations = optim.getEvaluations();
final double scale = 1e10;
final MultivariateFunction funcScaled = new MultivariateFunction() {
public double value(double[] x) {
return scale * func.value(x);
}
};
final PointValuePair funcScaledResult = optim.optimize(new MaxEval(maxEval),
new ObjectiveFunction(funcScaled),
GoalType.MINIMIZE,
new InitialGuess(init));
final double funcScaledValue = funcScaled.value(funcScaledResult.getPoint());
final int funcScaledEvaluations = optim.getEvaluations();
// Check that both minima provide the same objective function values,
// within the relative function tolerance.
Assert.assertEquals(1, funcScaledValue / (scale * funcValue), relTol);
// Check that the numbers of evaluations are the same.
Assert.assertEquals(funcEvaluations, funcScaledEvaluations);
}
/**
* @param func Function to optimize.
* @param optimum Expected optimum.
* @param init Starting point.
* @param goal Minimization or maximization.
* @param fTol Tolerance (relative error on the objective function) for
* "Powell" algorithm.
* @param pointTol Tolerance for checking that the optimum is correct.
*/
private void doTest(MultivariateFunction func,
double[] optimum,
double[] init,
GoalType goal,
double fTol,
double pointTol) {
final PowellOptimizer optim = new PowellOptimizer(fTol, Math.ulp(1d));
final PointValuePair result = optim.optimize(new MaxEval(1000),
new ObjectiveFunction(func),
goal,
new InitialGuess(init));
final double[] point = result.getPoint();
for (int i = 0, dim = optimum.length; i < dim; i++) {
Assert.assertEquals("found[" + i + "]=" + point[i] + " value=" + result.getValue(),
optimum[i], point[i], pointTol);
}
}
/**
* @param func Function to optimize.
* @param optimum Expected optimum.
* @param init Starting point.
* @param goal Minimization or maximization.
* @param fTol Tolerance (relative error on the objective function) for
* "Powell" algorithm.
* @param fLineTol Tolerance (relative error on the objective function)
* for the internal line search algorithm.
* @param pointTol Tolerance for checking that the optimum is correct.
*/
private void doTest(MultivariateFunction func,
double[] optimum,
double[] init,
GoalType goal,
double fTol,
double fLineTol,
double pointTol) {
final PowellOptimizer optim = new PowellOptimizer(fTol, Math.ulp(1d),
fLineTol, Math.ulp(1d));
final PointValuePair result = optim.optimize(new MaxEval(1000),
new ObjectiveFunction(func),
goal,
new InitialGuess(init));
final double[] point = result.getPoint();
for (int i = 0, dim = optimum.length; i < dim; i++) {
Assert.assertEquals("found[" + i + "]=" + point[i] + " value=" + result.getValue(),
optimum[i], point[i], pointTol);
}
}
}
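
For reference, the pattern exercised by the doTest helpers above reduces to the sketch below. The quadratic objective, class name and starting point are illustrative; the imports mirror those used in this test file.

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer;

public class PowellUsageSketch {
    public static void main(String[] args) {
        // Simple convex objective with its minimum at (1, 1).
        final MultivariateFunction func = new MultivariateFunction() {
            public double value(double[] x) {
                final double a = x[0] - 1;
                final double b = x[1] - 1;
                return a * a + b * b;
            }
        };
        // Relative/absolute tolerances on the objective value, as in doTest.
        final PowellOptimizer optim = new PowellOptimizer(1e-9, Math.ulp(1d));
        final PointValuePair result = optim.optimize(new MaxEval(1000),
                                                     new ObjectiveFunction(func),
                                                     GoalType.MINIMIZE,
                                                     new InitialGuess(new double[] { -5, 8 }));
        System.out.println("minimum at (" + result.getPoint()[0] + ", " + result.getPoint()[1]
                           + "), value = " + result.getValue()
                           + ", evaluations = " + optim.getEvaluations());
    }
}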

View File

@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
public class SimplexOptimizerMultiDirectionalTest {
@Test
public void testMinimize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(fourExtrema),
GoalType.MINIMIZE,
new InitialGuess(new double[] { -3, 0 }),
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 4e-6);
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 3e-6);
Assert.assertEquals(fourExtrema.valueXmYp, optimum.getValue(), 8e-13);
Assert.assertTrue(optimizer.getEvaluations() > 120);
Assert.assertTrue(optimizer.getEvaluations() < 150);
}
@Test
public void testMinimize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(fourExtrema),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 1, 0 }),
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 2e-8);
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 3e-6);
Assert.assertEquals(fourExtrema.valueXpYm, optimum.getValue(), 2e-12);
Assert.assertTrue(optimizer.getEvaluations() > 120);
Assert.assertTrue(optimizer.getEvaluations() < 150);
}
@Test
public void testMaximize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(fourExtrema),
GoalType.MAXIMIZE,
new InitialGuess(new double[] { -3.0, 0.0 }),
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 7e-7);
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 3e-7);
Assert.assertEquals(fourExtrema.valueXmYm, optimum.getValue(), 2e-14);
Assert.assertTrue(optimizer.getEvaluations() > 120);
Assert.assertTrue(optimizer.getEvaluations() < 150);
}
@Test
public void testMaximize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(new SimpleValueChecker(1e-15, 1e-30));
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(fourExtrema),
GoalType.MAXIMIZE,
new InitialGuess(new double[] { 1, 0 }),
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 2e-8);
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 3e-6);
Assert.assertEquals(fourExtrema.valueXpYp, optimum.getValue(), 2e-12);
Assert.assertTrue(optimizer.getEvaluations() > 180);
Assert.assertTrue(optimizer.getEvaluations() < 220);
}
@Test
public void testRosenbrock() {
MultivariateFunction rosenbrock
= new MultivariateFunction() {
public double value(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
};
count = 0;
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(rosenbrock),
GoalType.MINIMIZE,
new InitialGuess(new double[] { -1.2, 1 }),
new MultiDirectionalSimplex(new double[][] {
{ -1.2, 1.0 },
{ 0.9, 1.2 },
{ 3.5, -2.3 } }));
Assert.assertEquals(count, optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 50);
Assert.assertTrue(optimizer.getEvaluations() < 100);
Assert.assertTrue(optimum.getValue() > 1e-2);
}
@Test
public void testPowell() {
MultivariateFunction powell
= new MultivariateFunction() {
public double value(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
};
count = 0;
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum
= optimizer.optimize(new MaxEval(1000),
new ObjectiveFunction(powell),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 3, -1, 0, 1 }),
new MultiDirectionalSimplex(4));
Assert.assertEquals(count, optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 800);
Assert.assertTrue(optimizer.getEvaluations() < 900);
Assert.assertTrue(optimum.getValue() > 1e-2);
}
@Test
public void testMath283() {
// Regression test for MATH-283: MultiDirectional.iterateSimplex used to loop
// forever because its while(true) loop lacked a convergence check.
SimplexOptimizer optimizer = new SimplexOptimizer(1e-14, 1e-14);
final Gaussian2D function = new Gaussian2D(0, 0, 1);
PointValuePair estimate = optimizer.optimize(new MaxEval(1000),
new ObjectiveFunction(function),
GoalType.MAXIMIZE,
new InitialGuess(function.getMaximumPosition()),
new MultiDirectionalSimplex(2));
final double EPSILON = 1e-5;
final double expectedMaximum = function.getMaximum();
final double actualMaximum = estimate.getValue();
Assert.assertEquals(expectedMaximum, actualMaximum, EPSILON);
final double[] expectedPosition = function.getMaximumPosition();
final double[] actualPosition = estimate.getPoint();
Assert.assertEquals(expectedPosition[0], actualPosition[0], EPSILON );
Assert.assertEquals(expectedPosition[1], actualPosition[1], EPSILON );
}
private static class FourExtrema implements MultivariateFunction {
// The following function has 4 local extrema.
final double xM = -3.841947088256863675365;
final double yM = -1.391745200270734924416;
final double xP = 0.2286682237349059125691;
final double yP = -yM;
final double valueXmYm = 0.2373295333134216789769; // Local maximum.
final double valueXmYp = -valueXmYm; // Local minimum.
final double valueXpYm = -0.7290400707055187115322; // Global minimum.
final double valueXpYp = -valueXpYm; // Global maximum.
public double value(double[] variables) {
final double x = variables[0];
final double y = variables[1];
return (x == 0 || y == 0) ? 0 :
FastMath.atan(x) * FastMath.atan(x + 2) * FastMath.atan(y) * FastMath.atan(y) / (x * y);
}
}
private static class Gaussian2D implements MultivariateFunction {
private final double[] maximumPosition;
private final double std;
public Gaussian2D(double xOpt, double yOpt, double std) {
maximumPosition = new double[] { xOpt, yOpt };
this.std = std;
}
public double getMaximum() {
return value(maximumPosition);
}
public double[] getMaximumPosition() {
return maximumPosition.clone();
}
public double value(double[] point) {
final double x = point[0], y = point[1];
final double twoS2 = 2.0 * std * std;
return 1.0 / (twoS2 * FastMath.PI) * FastMath.exp(-(x * x + y * y) / twoS2);
}
}
private int count;
}
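
The tests above construct the starting simplex in three different ways. The compact sketch below shows those forms side by side; the class name is illustrative, and getDimension() is assumed to be inherited from the AbstractSimplex base class as in the previous optimization package.

import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.MultiDirectionalSimplex;

public class SimplexConstructionSketch {
    public static void main(String[] args) {
        // 1. By dimension: a default simplex with unit-length steps.
        final MultiDirectionalSimplex byDimension = new MultiDirectionalSimplex(4);
        // 2. By step sizes: one step per coordinate, applied to the initial guess.
        final MultiDirectionalSimplex bySteps
            = new MultiDirectionalSimplex(new double[] { 0.2, 0.2 });
        // 3. By explicit vertices: n + 1 points for an n-dimensional problem.
        final MultiDirectionalSimplex byVertices
            = new MultiDirectionalSimplex(new double[][] {
                { -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 } });
        System.out.println("dimensions: " + byDimension.getDimension() + ", "
                           + bySteps.getDimension() + ", " + byVertices.getDimension());
    }
}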

View File

@ -0,0 +1,295 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.LeastSquaresConverter;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
public class SimplexOptimizerNelderMeadTest {
@Test
public void testMinimize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(fourExtrema),
GoalType.MINIMIZE,
new InitialGuess(new double[] { -3, 0 }),
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 2e-7);
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 2e-5);
Assert.assertEquals(fourExtrema.valueXmYp, optimum.getValue(), 6e-12);
Assert.assertTrue(optimizer.getEvaluations() > 60);
Assert.assertTrue(optimizer.getEvaluations() < 90);
}
@Test
public void testMinimize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(fourExtrema),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 1, 0 }),
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 5e-6);
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 6e-6);
Assert.assertEquals(fourExtrema.valueXpYm, optimum.getValue(), 1e-11);
Assert.assertTrue(optimizer.getEvaluations() > 60);
Assert.assertTrue(optimizer.getEvaluations() < 90);
}
@Test
public void testMaximize1() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(fourExtrema),
GoalType.MAXIMIZE,
new InitialGuess(new double[] { -3, 0 }),
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 1e-5);
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 3e-6);
Assert.assertEquals(fourExtrema.valueXmYm, optimum.getValue(), 3e-12);
Assert.assertTrue(optimizer.getEvaluations() > 60);
Assert.assertTrue(optimizer.getEvaluations() < 90);
}
@Test
public void testMaximize2() {
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
final FourExtrema fourExtrema = new FourExtrema();
final PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(fourExtrema),
GoalType.MAXIMIZE,
new InitialGuess(new double[] { 1, 0 }),
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 4e-6);
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 5e-6);
Assert.assertEquals(fourExtrema.valueXpYp, optimum.getValue(), 7e-12);
Assert.assertTrue(optimizer.getEvaluations() > 60);
Assert.assertTrue(optimizer.getEvaluations() < 90);
}
@Test
public void testRosenbrock() {
Rosenbrock rosenbrock = new Rosenbrock();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum
= optimizer.optimize(new MaxEval(100),
new ObjectiveFunction(rosenbrock),
GoalType.MINIMIZE,
new InitialGuess(new double[] { -1.2, 1 }),
new NelderMeadSimplex(new double[][] {
{ -1.2, 1 },
{ 0.9, 1.2 },
{ 3.5, -2.3 } }));
Assert.assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 40);
Assert.assertTrue(optimizer.getEvaluations() < 50);
Assert.assertTrue(optimum.getValue() < 8e-4);
}
@Test
public void testPowell() {
Powell powell = new Powell();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
PointValuePair optimum =
optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(powell),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 3, -1, 0, 1 }),
new NelderMeadSimplex(4));
Assert.assertEquals(powell.getCount(), optimizer.getEvaluations());
Assert.assertTrue(optimizer.getEvaluations() > 110);
Assert.assertTrue(optimizer.getEvaluations() < 130);
Assert.assertTrue(optimum.getValue() < 2e-3);
}
@Test
public void testLeastSquares1() {
final RealMatrix factors
= new Array2DRowRealMatrix(new double[][] {
{ 1, 0 },
{ 0, 1 }
}, false);
LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorFunction() {
public double[] value(double[] variables) {
return factors.operate(variables);
}
}, new double[] { 2.0, -3.0 });
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6);
PointValuePair optimum =
optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(ls),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 10, 10 }),
new NelderMeadSimplex(2));
Assert.assertEquals( 2, optimum.getPointRef()[0], 3e-5);
Assert.assertEquals(-3, optimum.getPointRef()[1], 4e-4);
Assert.assertTrue(optimizer.getEvaluations() > 60);
Assert.assertTrue(optimizer.getEvaluations() < 80);
Assert.assertTrue(optimum.getValue() < 1.0e-6);
}
@Test
public void testLeastSquares2() {
final RealMatrix factors
= new Array2DRowRealMatrix(new double[][] {
{ 1, 0 },
{ 0, 1 }
}, false);
LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorFunction() {
public double[] value(double[] variables) {
return factors.operate(variables);
}
}, new double[] { 2, -3 }, new double[] { 10, 0.1 });
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6);
PointValuePair optimum =
optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(ls),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 10, 10 }),
new NelderMeadSimplex(2));
Assert.assertEquals( 2, optimum.getPointRef()[0], 5e-5);
Assert.assertEquals(-3, optimum.getPointRef()[1], 8e-4);
Assert.assertTrue(optimizer.getEvaluations() > 60);
Assert.assertTrue(optimizer.getEvaluations() < 80);
Assert.assertTrue(optimum.getValue() < 1e-6);
}
@Test
public void testLeastSquares3() {
final RealMatrix factors =
new Array2DRowRealMatrix(new double[][] {
{ 1, 0 },
{ 0, 1 }
}, false);
LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorFunction() {
public double[] value(double[] variables) {
return factors.operate(variables);
}
}, new double[] { 2, -3 }, new Array2DRowRealMatrix(new double [][] {
{ 1, 1.2 }, { 1.2, 2 }
}));
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6);
PointValuePair optimum
= optimizer.optimize(new MaxEval(200),
new ObjectiveFunction(ls),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 10, 10 }),
new NelderMeadSimplex(2));
Assert.assertEquals( 2, optimum.getPointRef()[0], 2e-3);
Assert.assertEquals(-3, optimum.getPointRef()[1], 8e-4);
Assert.assertTrue(optimizer.getEvaluations() > 60);
Assert.assertTrue(optimizer.getEvaluations() < 80);
Assert.assertTrue(optimum.getValue() < 1e-6);
}
@Test(expected=TooManyEvaluationsException.class)
public void testMaxIterations() {
Powell powell = new Powell();
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
optimizer.optimize(new MaxEval(20),
new ObjectiveFunction(powell),
GoalType.MINIMIZE,
new InitialGuess(new double[] { 3, -1, 0, 1 }),
new NelderMeadSimplex(4));
}
private static class FourExtrema implements MultivariateFunction {
// The following function has 4 local extrema.
final double xM = -3.841947088256863675365;
final double yM = -1.391745200270734924416;
final double xP = 0.2286682237349059125691;
final double yP = -yM;
final double valueXmYm = 0.2373295333134216789769; // Local maximum.
final double valueXmYp = -valueXmYm; // Local minimum.
final double valueXpYm = -0.7290400707055187115322; // Global minimum.
final double valueXpYp = -valueXpYm; // Global maximum.
public double value(double[] variables) {
final double x = variables[0];
final double y = variables[1];
return (x == 0 || y == 0) ? 0 :
FastMath.atan(x) * FastMath.atan(x + 2) * FastMath.atan(y) * FastMath.atan(y) / (x * y);
}
}
private static class Rosenbrock implements MultivariateFunction {
private int count;
public Rosenbrock() {
count = 0;
}
public double value(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
public int getCount() {
return count;
}
}
private static class Powell implements MultivariateFunction {
private int count;
public Powell() {
count = 0;
}
public double value(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
public int getCount() {
return count;
}
}
}
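
The three least-squares tests above share one pattern: LeastSquaresConverter wraps a vector-valued model and the observed target values into a scalar sum-of-squared-residuals objective, which the derivative-free simplex optimizer then minimizes. A condensed sketch of that pattern follows; the identity model, class name and numbers are illustrative only.

import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optim.GoalType;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.ObjectiveFunction;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.LeastSquaresConverter;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;

public class LeastSquaresSimplexSketch {
    public static void main(String[] args) {
        // Identity model: the "prediction" is the parameter vector itself.
        final MultivariateVectorFunction model = new MultivariateVectorFunction() {
            public double[] value(double[] params) {
                return params.clone();
            }
        };
        // The converter computes the sum of squared residuals against these targets.
        final LeastSquaresConverter objective
            = new LeastSquaresConverter(model, new double[] { 2, -3 });
        // As in the tests above: relative threshold of -1, absolute tolerance of 1e-6.
        final SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6);
        final PointValuePair optimum
            = optimizer.optimize(new MaxEval(200),
                                 new ObjectiveFunction(objective),
                                 GoalType.MINIMIZE,
                                 new InitialGuess(new double[] { 10, 10 }),
                                 new NelderMeadSimplex(2));
        System.out.println("fitted parameters: " + optimum.getPoint()[0] + ", "
                           + optimum.getPoint()[1] + " (residual " + optimum.getValue() + ")");
    }
}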

Some files were not shown because too many files have changed in this diff.