MATH-874
Refactored of the contents of package "o.a.c.m.optimization" into the new "o.a.c.m.optim" and "o.a.c.m.fitting" packages. * All deprecated classes/fields/methods have been removed in the replacement packages. * Simplified API: a single "optimize(OptimizationData... data)" for all optimizer types. * Simplified class hierarchy, merged interfaces and abstract classes, only base classes are generic. * The new classes do not use the "DerivativeStructure" type. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1420684 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
63623c9236
commit
7ee7843ffe
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.exception;
|
||||
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
|
||||
/**
|
||||
* Exception to be thrown when the maximal number of iterations is exceeded.
|
||||
*
|
||||
* @since 3.1
|
||||
* @version $Id$
|
||||
*/
|
||||
public class TooManyIterationsException extends MaxCountExceededException {
|
||||
/** Serializable version Id. */
|
||||
private static final long serialVersionUID = 20121211L;
|
||||
|
||||
/**
|
||||
* Construct the exception.
|
||||
*
|
||||
* @param max Maximum number of evaluations.
|
||||
*/
|
||||
public TooManyIterationsException(Number max) {
|
||||
super(max);
|
||||
getContext().addMessage(LocalizedFormats.ITERATIONS);
|
||||
}
|
||||
}
|
|
@ -148,6 +148,7 @@ public enum LocalizedFormats implements Localizable {
|
|||
INVALID_REGRESSION_OBSERVATION("length of regressor array = {0} does not match the number of variables = {1} in the model"),
|
||||
INVALID_ROUNDING_METHOD("invalid rounding method {0}, valid methods: {1} ({2}), {3} ({4}), {5} ({6}), {7} ({8}), {9} ({10}), {11} ({12}), {13} ({14}), {15} ({16})"),
|
||||
ITERATOR_EXHAUSTED("iterator exhausted"),
|
||||
ITERATIONS("iterations"), /* keep */
|
||||
LCM_OVERFLOW_32_BITS("overflow: lcm({0}, {1}) is 2^31"),
|
||||
LCM_OVERFLOW_64_BITS("overflow: lcm({0}, {1}) is 2^63"),
|
||||
LIST_OF_CHROMOSOMES_BIGGER_THAN_POPULATION_SIZE("list of chromosomes bigger than maxPopulationSize"),
|
||||
|
|
|
@ -0,0 +1,230 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.ModelFunction;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.Target;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.Weight;
|
||||
|
||||
/**
|
||||
* Fitter for parametric univariate real functions y = f(x).
|
||||
* <br/>
|
||||
* When a univariate real function y = f(x) does depend on some
|
||||
* unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
|
||||
* this class can be used to find these parameters. It does this
|
||||
* by <em>fitting</em> the curve so it remains very close to a set of
|
||||
* observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
|
||||
* y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
|
||||
* is done by finding the parameters values that minimizes the objective
|
||||
* function ∑(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
|
||||
* really a least squares problem.
|
||||
*
|
||||
* @param <T> Function to use for the fit.
|
||||
*
|
||||
* @version $Id: CurveFitter.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class CurveFitter<T extends ParametricUnivariateFunction> {
|
||||
/** Optimizer to use for the fitting. */
|
||||
private final MultivariateVectorOptimizer optimizer;
|
||||
/** Observed points. */
|
||||
private final List<WeightedObservedPoint> observations;
|
||||
|
||||
/**
|
||||
* Simple constructor.
|
||||
*
|
||||
* @param optimizer Optimizer to use for the fitting.
|
||||
* @since 3.1
|
||||
*/
|
||||
public CurveFitter(final MultivariateVectorOptimizer optimizer) {
|
||||
this.optimizer = optimizer;
|
||||
observations = new ArrayList<WeightedObservedPoint>();
|
||||
}
|
||||
|
||||
/** Add an observed (x,y) point to the sample with unit weight.
|
||||
* <p>Calling this method is equivalent to call
|
||||
* {@code addObservedPoint(1.0, x, y)}.</p>
|
||||
* @param x abscissa of the point
|
||||
* @param y observed value of the point at x, after fitting we should
|
||||
* have f(x) as close as possible to this value
|
||||
* @see #addObservedPoint(double, double, double)
|
||||
* @see #addObservedPoint(WeightedObservedPoint)
|
||||
* @see #getObservations()
|
||||
*/
|
||||
public void addObservedPoint(double x, double y) {
|
||||
addObservedPoint(1.0, x, y);
|
||||
}
|
||||
|
||||
/** Add an observed weighted (x,y) point to the sample.
|
||||
* @param weight weight of the observed point in the fit
|
||||
* @param x abscissa of the point
|
||||
* @param y observed value of the point at x, after fitting we should
|
||||
* have f(x) as close as possible to this value
|
||||
* @see #addObservedPoint(double, double)
|
||||
* @see #addObservedPoint(WeightedObservedPoint)
|
||||
* @see #getObservations()
|
||||
*/
|
||||
public void addObservedPoint(double weight, double x, double y) {
|
||||
observations.add(new WeightedObservedPoint(weight, x, y));
|
||||
}
|
||||
|
||||
/** Add an observed weighted (x,y) point to the sample.
|
||||
* @param observed observed point to add
|
||||
* @see #addObservedPoint(double, double)
|
||||
* @see #addObservedPoint(double, double, double)
|
||||
* @see #getObservations()
|
||||
*/
|
||||
public void addObservedPoint(WeightedObservedPoint observed) {
|
||||
observations.add(observed);
|
||||
}
|
||||
|
||||
/** Get the observed points.
|
||||
* @return observed points
|
||||
* @see #addObservedPoint(double, double)
|
||||
* @see #addObservedPoint(double, double, double)
|
||||
* @see #addObservedPoint(WeightedObservedPoint)
|
||||
*/
|
||||
public WeightedObservedPoint[] getObservations() {
|
||||
return observations.toArray(new WeightedObservedPoint[observations.size()]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove all observations.
|
||||
*/
|
||||
public void clearObservations() {
|
||||
observations.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Fit a curve.
|
||||
* This method compute the coefficients of the curve that best
|
||||
* fit the sample of observed points previously given through calls
|
||||
* to the {@link #addObservedPoint(WeightedObservedPoint)
|
||||
* addObservedPoint} method.
|
||||
*
|
||||
* @param f parametric function to fit.
|
||||
* @param initialGuess first guess of the function parameters.
|
||||
* @return the fitted parameters.
|
||||
* @throws org.apache.commons.math3.exception.DimensionMismatchException
|
||||
* if the start point dimension is wrong.
|
||||
*/
|
||||
public double[] fit(T f, final double[] initialGuess) {
|
||||
return fit(Integer.MAX_VALUE, f, initialGuess);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fit a curve.
|
||||
* This method compute the coefficients of the curve that best
|
||||
* fit the sample of observed points previously given through calls
|
||||
* to the {@link #addObservedPoint(WeightedObservedPoint)
|
||||
* addObservedPoint} method.
|
||||
*
|
||||
* @param f parametric function to fit.
|
||||
* @param initialGuess first guess of the function parameters.
|
||||
* @param maxEval Maximum number of function evaluations.
|
||||
* @return the fitted parameters.
|
||||
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
|
||||
* if the number of allowed evaluations is exceeded.
|
||||
* @throws org.apache.commons.math3.exception.DimensionMismatchException
|
||||
* if the start point dimension is wrong.
|
||||
* @since 3.0
|
||||
*/
|
||||
public double[] fit(int maxEval, T f,
|
||||
final double[] initialGuess) {
|
||||
// Prepare least squares problem.
|
||||
double[] target = new double[observations.size()];
|
||||
double[] weights = new double[observations.size()];
|
||||
int i = 0;
|
||||
for (WeightedObservedPoint point : observations) {
|
||||
target[i] = point.getY();
|
||||
weights[i] = point.getWeight();
|
||||
++i;
|
||||
}
|
||||
|
||||
// Input to the optimizer: the model and its Jacobian.
|
||||
final TheoreticalValuesFunction model = new TheoreticalValuesFunction(f);
|
||||
|
||||
// Perform the fit.
|
||||
final PointVectorValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(maxEval),
|
||||
model.getModelFunction(),
|
||||
model.getModelFunctionJacobian(),
|
||||
new Target(target),
|
||||
new Weight(weights),
|
||||
new InitialGuess(initialGuess));
|
||||
// Extract the coefficients.
|
||||
return optimum.getPointRef();
|
||||
}
|
||||
|
||||
/** Vectorial function computing function theoretical values. */
|
||||
private class TheoreticalValuesFunction {
|
||||
/** Function to fit. */
|
||||
private final ParametricUnivariateFunction f;
|
||||
|
||||
/**
|
||||
* @param f function to fit.
|
||||
*/
|
||||
public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
|
||||
this.f = f;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the model function values.
|
||||
*/
|
||||
public ModelFunction getModelFunction() {
|
||||
return new ModelFunction(new MultivariateVectorFunction() {
|
||||
/** {@inheritDoc} */
|
||||
public double[] value(double[] point) {
|
||||
// compute the residuals
|
||||
final double[] values = new double[observations.size()];
|
||||
int i = 0;
|
||||
for (WeightedObservedPoint observed : observations) {
|
||||
values[i++] = f.value(observed.getX(), point);
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the model function Jacobian.
|
||||
*/
|
||||
public ModelFunctionJacobian getModelFunctionJacobian() {
|
||||
return new ModelFunctionJacobian(new MultivariateMatrixFunction() {
|
||||
public double[][] value(double[] point) {
|
||||
final double[][] jacobian = new double[observations.size()][];
|
||||
int i = 0;
|
||||
for (WeightedObservedPoint observed : observations) {
|
||||
jacobian[i++] = f.gradient(observed.getX(), point);
|
||||
}
|
||||
return jacobian;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,362 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import org.apache.commons.math3.analysis.function.Gaussian;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
import org.apache.commons.math3.exception.ZeroException;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
|
||||
/**
|
||||
* Fits points to a {@link
|
||||
* org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} function.
|
||||
* <p>
|
||||
* Usage example:
|
||||
* <pre>
|
||||
* GaussianFitter fitter = new GaussianFitter(
|
||||
* new LevenbergMarquardtOptimizer());
|
||||
* fitter.addObservedPoint(4.0254623, 531026.0);
|
||||
* fitter.addObservedPoint(4.03128248, 984167.0);
|
||||
* fitter.addObservedPoint(4.03839603, 1887233.0);
|
||||
* fitter.addObservedPoint(4.04421621, 2687152.0);
|
||||
* fitter.addObservedPoint(4.05132976, 3461228.0);
|
||||
* fitter.addObservedPoint(4.05326982, 3580526.0);
|
||||
* fitter.addObservedPoint(4.05779662, 3439750.0);
|
||||
* fitter.addObservedPoint(4.0636168, 2877648.0);
|
||||
* fitter.addObservedPoint(4.06943698, 2175960.0);
|
||||
* fitter.addObservedPoint(4.07525716, 1447024.0);
|
||||
* fitter.addObservedPoint(4.08237071, 717104.0);
|
||||
* fitter.addObservedPoint(4.08366408, 620014.0);
|
||||
* double[] parameters = fitter.fit();
|
||||
* </pre>
|
||||
*
|
||||
* @since 2.2
|
||||
* @version $Id: GaussianFitter.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
*/
|
||||
public class GaussianFitter extends CurveFitter<Gaussian.Parametric> {
|
||||
/**
|
||||
* Constructs an instance using the specified optimizer.
|
||||
*
|
||||
* @param optimizer Optimizer to use for the fitting.
|
||||
*/
|
||||
public GaussianFitter(MultivariateVectorOptimizer optimizer) {
|
||||
super(optimizer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fits a Gaussian function to the observed points.
|
||||
*
|
||||
* @param initialGuess First guess values in the following order:
|
||||
* <ul>
|
||||
* <li>Norm</li>
|
||||
* <li>Mean</li>
|
||||
* <li>Sigma</li>
|
||||
* </ul>
|
||||
* @return the parameters of the Gaussian function that best fits the
|
||||
* observed points (in the same order as above).
|
||||
* @since 3.0
|
||||
*/
|
||||
public double[] fit(double[] initialGuess) {
|
||||
final Gaussian.Parametric f = new Gaussian.Parametric() {
|
||||
@Override
|
||||
public double value(double x, double ... p) {
|
||||
double v = Double.POSITIVE_INFINITY;
|
||||
try {
|
||||
v = super.value(x, p);
|
||||
} catch (NotStrictlyPositiveException e) {
|
||||
// Do nothing.
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] gradient(double x, double ... p) {
|
||||
double[] v = { Double.POSITIVE_INFINITY,
|
||||
Double.POSITIVE_INFINITY,
|
||||
Double.POSITIVE_INFINITY };
|
||||
try {
|
||||
v = super.gradient(x, p);
|
||||
} catch (NotStrictlyPositiveException e) {
|
||||
// Do nothing.
|
||||
}
|
||||
return v;
|
||||
}
|
||||
};
|
||||
|
||||
return fit(f, initialGuess);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fits a Gaussian function to the observed points.
|
||||
*
|
||||
* @return the parameters of the Gaussian function that best fits the
|
||||
* observed points (in the same order as above).
|
||||
*/
|
||||
public double[] fit() {
|
||||
final double[] guess = (new ParameterGuesser(getObservations())).guess();
|
||||
return fit(guess);
|
||||
}
|
||||
|
||||
/**
|
||||
* Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
|
||||
* of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
|
||||
* based on the specified observed points.
|
||||
*/
|
||||
public static class ParameterGuesser {
|
||||
/** Normalization factor. */
|
||||
private final double norm;
|
||||
/** Mean. */
|
||||
private final double mean;
|
||||
/** Standard deviation. */
|
||||
private final double sigma;
|
||||
|
||||
/**
|
||||
* Constructs instance with the specified observed points.
|
||||
*
|
||||
* @param observations Observed points from which to guess the
|
||||
* parameters of the Gaussian.
|
||||
* @throws NullArgumentException if {@code observations} is
|
||||
* {@code null}.
|
||||
* @throws NumberIsTooSmallException if there are less than 3
|
||||
* observations.
|
||||
*/
|
||||
public ParameterGuesser(WeightedObservedPoint[] observations) {
|
||||
if (observations == null) {
|
||||
throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
|
||||
}
|
||||
if (observations.length < 3) {
|
||||
throw new NumberIsTooSmallException(observations.length, 3, true);
|
||||
}
|
||||
|
||||
final WeightedObservedPoint[] sorted = sortObservations(observations);
|
||||
final double[] params = basicGuess(sorted);
|
||||
|
||||
norm = params[0];
|
||||
mean = params[1];
|
||||
sigma = params[2];
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets an estimation of the parameters.
|
||||
*
|
||||
* @return the guessed parameters, in the following order:
|
||||
* <ul>
|
||||
* <li>Normalization factor</li>
|
||||
* <li>Mean</li>
|
||||
* <li>Standard deviation</li>
|
||||
* </ul>
|
||||
*/
|
||||
public double[] guess() {
|
||||
return new double[] { norm, mean, sigma };
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort the observations.
|
||||
*
|
||||
* @param unsorted Input observations.
|
||||
* @return the input observations, sorted.
|
||||
*/
|
||||
private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
|
||||
final WeightedObservedPoint[] observations = unsorted.clone();
|
||||
final Comparator<WeightedObservedPoint> cmp
|
||||
= new Comparator<WeightedObservedPoint>() {
|
||||
public int compare(WeightedObservedPoint p1,
|
||||
WeightedObservedPoint p2) {
|
||||
if (p1 == null && p2 == null) {
|
||||
return 0;
|
||||
}
|
||||
if (p1 == null) {
|
||||
return -1;
|
||||
}
|
||||
if (p2 == null) {
|
||||
return 1;
|
||||
}
|
||||
if (p1.getX() < p2.getX()) {
|
||||
return -1;
|
||||
}
|
||||
if (p1.getX() > p2.getX()) {
|
||||
return 1;
|
||||
}
|
||||
if (p1.getY() < p2.getY()) {
|
||||
return -1;
|
||||
}
|
||||
if (p1.getY() > p2.getY()) {
|
||||
return 1;
|
||||
}
|
||||
if (p1.getWeight() < p2.getWeight()) {
|
||||
return -1;
|
||||
}
|
||||
if (p1.getWeight() > p2.getWeight()) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
Arrays.sort(observations, cmp);
|
||||
return observations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Guesses the parameters based on the specified observed points.
|
||||
*
|
||||
* @param points Observed points, sorted.
|
||||
* @return the guessed parameters (normalization factor, mean and
|
||||
* sigma).
|
||||
*/
|
||||
private double[] basicGuess(WeightedObservedPoint[] points) {
|
||||
final int maxYIdx = findMaxY(points);
|
||||
final double n = points[maxYIdx].getY();
|
||||
final double m = points[maxYIdx].getX();
|
||||
|
||||
double fwhmApprox;
|
||||
try {
|
||||
final double halfY = n + ((m - n) / 2);
|
||||
final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
|
||||
final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
|
||||
fwhmApprox = fwhmX2 - fwhmX1;
|
||||
} catch (OutOfRangeException e) {
|
||||
// TODO: Exceptions should not be used for flow control.
|
||||
fwhmApprox = points[points.length - 1].getX() - points[0].getX();
|
||||
}
|
||||
final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
|
||||
|
||||
return new double[] { n, m, s };
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds index of point in specified points with the largest Y.
|
||||
*
|
||||
* @param points Points to search.
|
||||
* @return the index in specified points array.
|
||||
*/
|
||||
private int findMaxY(WeightedObservedPoint[] points) {
|
||||
int maxYIdx = 0;
|
||||
for (int i = 1; i < points.length; i++) {
|
||||
if (points[i].getY() > points[maxYIdx].getY()) {
|
||||
maxYIdx = i;
|
||||
}
|
||||
}
|
||||
return maxYIdx;
|
||||
}
|
||||
|
||||
/**
|
||||
* Interpolates using the specified points to determine X at the
|
||||
* specified Y.
|
||||
*
|
||||
* @param points Points to use for interpolation.
|
||||
* @param startIdx Index within points from which to start the search for
|
||||
* interpolation bounds points.
|
||||
* @param idxStep Index step for searching interpolation bounds points.
|
||||
* @param y Y value for which X should be determined.
|
||||
* @return the value of X for the specified Y.
|
||||
* @throws ZeroException if {@code idxStep} is 0.
|
||||
* @throws OutOfRangeException if specified {@code y} is not within the
|
||||
* range of the specified {@code points}.
|
||||
*/
|
||||
private double interpolateXAtY(WeightedObservedPoint[] points,
|
||||
int startIdx,
|
||||
int idxStep,
|
||||
double y)
|
||||
throws OutOfRangeException {
|
||||
if (idxStep == 0) {
|
||||
throw new ZeroException();
|
||||
}
|
||||
final WeightedObservedPoint[] twoPoints
|
||||
= getInterpolationPointsForY(points, startIdx, idxStep, y);
|
||||
final WeightedObservedPoint p1 = twoPoints[0];
|
||||
final WeightedObservedPoint p2 = twoPoints[1];
|
||||
if (p1.getY() == y) {
|
||||
return p1.getX();
|
||||
}
|
||||
if (p2.getY() == y) {
|
||||
return p2.getX();
|
||||
}
|
||||
return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
|
||||
(p2.getY() - p1.getY()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the two bounding interpolation points from the specified points
|
||||
* suitable for determining X at the specified Y.
|
||||
*
|
||||
* @param points Points to use for interpolation.
|
||||
* @param startIdx Index within points from which to start search for
|
||||
* interpolation bounds points.
|
||||
* @param idxStep Index step for search for interpolation bounds points.
|
||||
* @param y Y value for which X should be determined.
|
||||
* @return the array containing two points suitable for determining X at
|
||||
* the specified Y.
|
||||
* @throws ZeroException if {@code idxStep} is 0.
|
||||
* @throws OutOfRangeException if specified {@code y} is not within the
|
||||
* range of the specified {@code points}.
|
||||
*/
|
||||
private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
|
||||
int startIdx,
|
||||
int idxStep,
|
||||
double y)
|
||||
throws OutOfRangeException {
|
||||
if (idxStep == 0) {
|
||||
throw new ZeroException();
|
||||
}
|
||||
for (int i = startIdx;
|
||||
idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
|
||||
i += idxStep) {
|
||||
final WeightedObservedPoint p1 = points[i];
|
||||
final WeightedObservedPoint p2 = points[i + idxStep];
|
||||
if (isBetween(y, p1.getY(), p2.getY())) {
|
||||
if (idxStep < 0) {
|
||||
return new WeightedObservedPoint[] { p2, p1 };
|
||||
} else {
|
||||
return new WeightedObservedPoint[] { p1, p2 };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Boundaries are replaced by dummy values because the raised
|
||||
// exception is caught and the message never displayed.
|
||||
// TODO: Exceptions should not be used for flow control.
|
||||
throw new OutOfRangeException(y,
|
||||
Double.NEGATIVE_INFINITY,
|
||||
Double.POSITIVE_INFINITY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines whether a value is between two other values.
|
||||
*
|
||||
* @param value Value to test whether it is between {@code boundary1}
|
||||
* and {@code boundary2}.
|
||||
* @param boundary1 One end of the range.
|
||||
* @param boundary2 Other end of the range.
|
||||
* @return {@code true} if {@code value} is between {@code boundary1} and
|
||||
* {@code boundary2} (inclusive), {@code false} otherwise.
|
||||
*/
|
||||
private boolean isBetween(double value,
|
||||
double boundary1,
|
||||
double boundary2) {
|
||||
return (value >= boundary1 && value <= boundary2) ||
|
||||
(value >= boundary2 && value <= boundary1);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,382 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
|
||||
import org.apache.commons.math3.analysis.function.HarmonicOscillator;
|
||||
import org.apache.commons.math3.exception.ZeroException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
|
||||
/**
|
||||
* Class that implements a curve fitting specialized for sinusoids.
|
||||
*
|
||||
* Harmonic fitting is a very simple case of curve fitting. The
|
||||
* estimated coefficients are the amplitude a, the pulsation ω and
|
||||
* the phase φ: <code>f (t) = a cos (ω t + φ)</code>. They are
|
||||
* searched by a least square estimator initialized with a rough guess
|
||||
* based on integrals.
|
||||
*
|
||||
* @version $Id: HarmonicFitter.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class HarmonicFitter extends CurveFitter<HarmonicOscillator.Parametric> {
|
||||
/**
|
||||
* Simple constructor.
|
||||
* @param optimizer Optimizer to use for the fitting.
|
||||
*/
|
||||
public HarmonicFitter(final MultivariateVectorOptimizer optimizer) {
|
||||
super(optimizer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fit an harmonic function to the observed points.
|
||||
*
|
||||
* @param initialGuess First guess values in the following order:
|
||||
* <ul>
|
||||
* <li>Amplitude</li>
|
||||
* <li>Angular frequency</li>
|
||||
* <li>Phase</li>
|
||||
* </ul>
|
||||
* @return the parameters of the harmonic function that best fits the
|
||||
* observed points (in the same order as above).
|
||||
*/
|
||||
public double[] fit(double[] initialGuess) {
|
||||
return fit(new HarmonicOscillator.Parametric(), initialGuess);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fit an harmonic function to the observed points.
|
||||
* An initial guess will be automatically computed.
|
||||
*
|
||||
* @return the parameters of the harmonic function that best fits the
|
||||
* observed points (see the other {@link #fit(double[]) fit} method.
|
||||
* @throws NumberIsTooSmallException if the sample is too short for the
|
||||
* the first guess to be computed.
|
||||
* @throws ZeroException if the first guess cannot be computed because
|
||||
* the abscissa range is zero.
|
||||
*/
|
||||
public double[] fit() {
|
||||
return fit((new ParameterGuesser(getObservations())).guess());
|
||||
}
|
||||
|
||||
/**
|
||||
* This class guesses harmonic coefficients from a sample.
|
||||
* <p>The algorithm used to guess the coefficients is as follows:</p>
|
||||
*
|
||||
* <p>We know f (t) at some sampling points t<sub>i</sub> and want to find a,
|
||||
* ω and φ such that f (t) = a cos (ω t + φ).
|
||||
* </p>
|
||||
*
|
||||
* <p>From the analytical expression, we can compute two primitives :
|
||||
* <pre>
|
||||
* If2 (t) = ∫ f<sup>2</sup> = a<sup>2</sup> × [t + S (t)] / 2
|
||||
* If'2 (t) = ∫ f'<sup>2</sup> = a<sup>2</sup> ω<sup>2</sup> × [t - S (t)] / 2
|
||||
* where S (t) = sin (2 (ω t + φ)) / (2 ω)
|
||||
* </pre>
|
||||
* </p>
|
||||
*
|
||||
* <p>We can remove S between these expressions :
|
||||
* <pre>
|
||||
* If'2 (t) = a<sup>2</sup> ω<sup>2</sup> t - ω<sup>2</sup> If2 (t)
|
||||
* </pre>
|
||||
* </p>
|
||||
*
|
||||
* <p>The preceding expression shows that If'2 (t) is a linear
|
||||
* combination of both t and If2 (t): If'2 (t) = A × t + B × If2 (t)
|
||||
* </p>
|
||||
*
|
||||
* <p>From the primitive, we can deduce the same form for definite
|
||||
* integrals between t<sub>1</sub> and t<sub>i</sub> for each t<sub>i</sub> :
|
||||
* <pre>
|
||||
* If2 (t<sub>i</sub>) - If2 (t<sub>1</sub>) = A × (t<sub>i</sub> - t<sub>1</sub>) + B × (If2 (t<sub>i</sub>) - If2 (t<sub>1</sub>))
|
||||
* </pre>
|
||||
* </p>
|
||||
*
|
||||
* <p>We can find the coefficients A and B that best fit the sample
|
||||
* to this linear expression by computing the definite integrals for
|
||||
* each sample points.
|
||||
* </p>
|
||||
*
|
||||
* <p>For a bilinear expression z (x<sub>i</sub>, y<sub>i</sub>) = A × x<sub>i</sub> + B × y<sub>i</sub>, the
|
||||
* coefficients A and B that minimize a least square criterion
|
||||
* ∑ (z<sub>i</sub> - z (x<sub>i</sub>, y<sub>i</sub>))<sup>2</sup> are given by these expressions:</p>
|
||||
* <pre>
|
||||
*
|
||||
* ∑y<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub>
|
||||
* A = ------------------------
|
||||
* ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>y<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>y<sub>i</sub>
|
||||
*
|
||||
* ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub>
|
||||
* B = ------------------------
|
||||
* ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>y<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>y<sub>i</sub>
|
||||
* </pre>
|
||||
* </p>
|
||||
*
|
||||
*
|
||||
* <p>In fact, we can assume both a and ω are positive and
|
||||
* compute them directly, knowing that A = a<sup>2</sup> ω<sup>2</sup> and that
|
||||
* B = - ω<sup>2</sup>. The complete algorithm is therefore:</p>
|
||||
* <pre>
|
||||
*
|
||||
* for each t<sub>i</sub> from t<sub>1</sub> to t<sub>n-1</sub>, compute:
|
||||
* f (t<sub>i</sub>)
|
||||
* f' (t<sub>i</sub>) = (f (t<sub>i+1</sub>) - f(t<sub>i-1</sub>)) / (t<sub>i+1</sub> - t<sub>i-1</sub>)
|
||||
* x<sub>i</sub> = t<sub>i</sub> - t<sub>1</sub>
|
||||
* y<sub>i</sub> = ∫ f<sup>2</sup> from t<sub>1</sub> to t<sub>i</sub>
|
||||
* z<sub>i</sub> = ∫ f'<sup>2</sup> from t<sub>1</sub> to t<sub>i</sub>
|
||||
* update the sums ∑x<sub>i</sub>x<sub>i</sub>, ∑y<sub>i</sub>y<sub>i</sub>, ∑x<sub>i</sub>y<sub>i</sub>, ∑x<sub>i</sub>z<sub>i</sub> and ∑y<sub>i</sub>z<sub>i</sub>
|
||||
* end for
|
||||
*
|
||||
* |--------------------------
|
||||
* \ | ∑y<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub>
|
||||
* a = \ | ------------------------
|
||||
* \| ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub>
|
||||
*
|
||||
*
|
||||
* |--------------------------
|
||||
* \ | ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub>
|
||||
* ω = \ | ------------------------
|
||||
* \| ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>y<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>y<sub>i</sub>
|
||||
*
|
||||
* </pre>
|
||||
* </p>
|
||||
*
|
||||
* <p>Once we know ω, we can compute:
|
||||
* <pre>
|
||||
* fc = ω f (t) cos (ω t) - f' (t) sin (ω t)
|
||||
* fs = ω f (t) sin (ω t) + f' (t) cos (ω t)
|
||||
* </pre>
|
||||
* </p>
|
||||
*
|
||||
* <p>It appears that <code>fc = a ω cos (φ)</code> and
|
||||
* <code>fs = -a ω sin (φ)</code>, so we can use these
|
||||
* expressions to compute φ. The best estimate over the sample is
|
||||
* given by averaging these expressions.
|
||||
* </p>
|
||||
*
|
||||
* <p>Since integrals and means are involved in the preceding
|
||||
* estimations, these operations run in O(n) time, where n is the
|
||||
* number of measurements.</p>
|
||||
*/
|
||||
public static class ParameterGuesser {
|
||||
/** Amplitude. */
|
||||
private final double a;
|
||||
/** Angular frequency. */
|
||||
private final double omega;
|
||||
/** Phase. */
|
||||
private final double phi;
|
||||
|
||||
/**
|
||||
* Simple constructor.
|
||||
*
|
||||
* @param observations Sampled observations.
|
||||
* @throws NumberIsTooSmallException if the sample is too short.
|
||||
* @throws ZeroException if the abscissa range is zero.
|
||||
* @throws MathIllegalStateException when the guessing procedure cannot
|
||||
* produce sensible results.
|
||||
*/
|
||||
public ParameterGuesser(WeightedObservedPoint[] observations) {
|
||||
if (observations.length < 4) {
|
||||
throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
|
||||
observations.length, 4, true);
|
||||
}
|
||||
|
||||
final WeightedObservedPoint[] sorted = sortObservations(observations);
|
||||
|
||||
final double aOmega[] = guessAOmega(sorted);
|
||||
a = aOmega[0];
|
||||
omega = aOmega[1];
|
||||
|
||||
phi = guessPhi(sorted);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets an estimation of the parameters.
|
||||
*
|
||||
* @return the guessed parameters, in the following order:
|
||||
* <ul>
|
||||
* <li>Amplitude</li>
|
||||
* <li>Angular frequency</li>
|
||||
* <li>Phase</li>
|
||||
* </ul>
|
||||
*/
|
||||
public double[] guess() {
|
||||
return new double[] { a, omega, phi };
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort the observations with respect to the abscissa.
|
||||
*
|
||||
* @param unsorted Input observations.
|
||||
* @return the input observations, sorted.
|
||||
*/
|
||||
private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
|
||||
final WeightedObservedPoint[] observations = unsorted.clone();
|
||||
|
||||
// Since the samples are almost always already sorted, this
|
||||
// method is implemented as an insertion sort that reorders the
|
||||
// elements in place. Insertion sort is very efficient in this case.
|
||||
WeightedObservedPoint curr = observations[0];
|
||||
for (int j = 1; j < observations.length; ++j) {
|
||||
WeightedObservedPoint prec = curr;
|
||||
curr = observations[j];
|
||||
if (curr.getX() < prec.getX()) {
|
||||
// the current element should be inserted closer to the beginning
|
||||
int i = j - 1;
|
||||
WeightedObservedPoint mI = observations[i];
|
||||
while ((i >= 0) && (curr.getX() < mI.getX())) {
|
||||
observations[i + 1] = mI;
|
||||
if (i-- != 0) {
|
||||
mI = observations[i];
|
||||
}
|
||||
}
|
||||
observations[i + 1] = curr;
|
||||
curr = observations[j];
|
||||
}
|
||||
}
|
||||
|
||||
return observations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate a first guess of the amplitude and angular frequency.
|
||||
* This method assumes that the {@link #sortObservations()} method
|
||||
* has been called previously.
|
||||
*
|
||||
* @param observations Observations, sorted w.r.t. abscissa.
|
||||
* @throws ZeroException if the abscissa range is zero.
|
||||
* @throws MathIllegalStateException when the guessing procedure cannot
|
||||
* produce sensible results.
|
||||
* @return the guessed amplitude (at index 0) and circular frequency
|
||||
* (at index 1).
|
||||
*/
|
||||
private double[] guessAOmega(WeightedObservedPoint[] observations) {
|
||||
final double[] aOmega = new double[2];
|
||||
|
||||
// initialize the sums for the linear model between the two integrals
|
||||
double sx2 = 0;
|
||||
double sy2 = 0;
|
||||
double sxy = 0;
|
||||
double sxz = 0;
|
||||
double syz = 0;
|
||||
|
||||
double currentX = observations[0].getX();
|
||||
double currentY = observations[0].getY();
|
||||
double f2Integral = 0;
|
||||
double fPrime2Integral = 0;
|
||||
final double startX = currentX;
|
||||
for (int i = 1; i < observations.length; ++i) {
|
||||
// one step forward
|
||||
final double previousX = currentX;
|
||||
final double previousY = currentY;
|
||||
currentX = observations[i].getX();
|
||||
currentY = observations[i].getY();
|
||||
|
||||
// update the integrals of f<sup>2</sup> and f'<sup>2</sup>
|
||||
// considering a linear model for f (and therefore constant f')
|
||||
final double dx = currentX - previousX;
|
||||
final double dy = currentY - previousY;
|
||||
final double f2StepIntegral =
|
||||
dx * (previousY * previousY + previousY * currentY + currentY * currentY) / 3;
|
||||
final double fPrime2StepIntegral = dy * dy / dx;
|
||||
|
||||
final double x = currentX - startX;
|
||||
f2Integral += f2StepIntegral;
|
||||
fPrime2Integral += fPrime2StepIntegral;
|
||||
|
||||
sx2 += x * x;
|
||||
sy2 += f2Integral * f2Integral;
|
||||
sxy += x * f2Integral;
|
||||
sxz += x * fPrime2Integral;
|
||||
syz += f2Integral * fPrime2Integral;
|
||||
}
|
||||
|
||||
// compute the amplitude and pulsation coefficients
|
||||
double c1 = sy2 * sxz - sxy * syz;
|
||||
double c2 = sxy * sxz - sx2 * syz;
|
||||
double c3 = sx2 * sy2 - sxy * sxy;
|
||||
if ((c1 / c2 < 0) || (c2 / c3 < 0)) {
|
||||
final int last = observations.length - 1;
|
||||
// Range of the observations, assuming that the
|
||||
// observations are sorted.
|
||||
final double xRange = observations[last].getX() - observations[0].getX();
|
||||
if (xRange == 0) {
|
||||
throw new ZeroException();
|
||||
}
|
||||
aOmega[1] = 2 * Math.PI / xRange;
|
||||
|
||||
double yMin = Double.POSITIVE_INFINITY;
|
||||
double yMax = Double.NEGATIVE_INFINITY;
|
||||
for (int i = 1; i < observations.length; ++i) {
|
||||
final double y = observations[i].getY();
|
||||
if (y < yMin) {
|
||||
yMin = y;
|
||||
}
|
||||
if (y > yMax) {
|
||||
yMax = y;
|
||||
}
|
||||
}
|
||||
aOmega[0] = 0.5 * (yMax - yMin);
|
||||
} else {
|
||||
if (c2 == 0) {
|
||||
// In some ill-conditioned cases (cf. MATH-844), the guesser
|
||||
// procedure cannot produce sensible results.
|
||||
throw new MathIllegalStateException(LocalizedFormats.ZERO_DENOMINATOR);
|
||||
}
|
||||
|
||||
aOmega[0] = FastMath.sqrt(c1 / c2);
|
||||
aOmega[1] = FastMath.sqrt(c2 / c3);
|
||||
}
|
||||
|
||||
return aOmega;
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate a first guess of the phase.
|
||||
*
|
||||
* @param observations Observations, sorted w.r.t. abscissa.
|
||||
* @return the guessed phase.
|
||||
*/
|
||||
private double guessPhi(WeightedObservedPoint[] observations) {
|
||||
// initialize the means
|
||||
double fcMean = 0;
|
||||
double fsMean = 0;
|
||||
|
||||
double currentX = observations[0].getX();
|
||||
double currentY = observations[0].getY();
|
||||
for (int i = 1; i < observations.length; ++i) {
|
||||
// one step forward
|
||||
final double previousX = currentX;
|
||||
final double previousY = currentY;
|
||||
currentX = observations[i].getX();
|
||||
currentY = observations[i].getY();
|
||||
final double currentYPrime = (currentY - previousY) / (currentX - previousX);
|
||||
|
||||
double omegaX = omega * currentX;
|
||||
double cosine = FastMath.cos(omegaX);
|
||||
double sine = FastMath.sin(omegaX);
|
||||
fcMean += omega * currentY * cosine - currentYPrime * sine;
|
||||
fsMean += omega * currentY * sine + currentYPrime * cosine;
|
||||
}
|
||||
|
||||
return FastMath.atan2(-fsMean, fcMean);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
|
||||
|
||||
/**
|
||||
* Polynomial fitting is a very simple case of {@link CurveFitter curve fitting}.
|
||||
* The estimated coefficients are the polynomial coefficients (see the
|
||||
* {@link #fit(double[]) fit} method).
|
||||
*
|
||||
* @version $Id: PolynomialFitter.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class PolynomialFitter extends CurveFitter<PolynomialFunction.Parametric> {
|
||||
/**
|
||||
* Simple constructor.
|
||||
*
|
||||
* @param optimizer Optimizer to use for the fitting.
|
||||
*/
|
||||
public PolynomialFitter(MultivariateVectorOptimizer optimizer) {
|
||||
super(optimizer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the coefficients of the polynomial fitting the weighted data points.
|
||||
* The degree of the fitting polynomial is {@code guess.length - 1}.
|
||||
*
|
||||
* @param guess First guess for the coefficients. They must be sorted in
|
||||
* increasing order of the polynomial's degree.
|
||||
* @param maxEval Maximum number of evaluations of the polynomial.
|
||||
* @return the coefficients of the polynomial that best fits the observed points.
|
||||
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException if
|
||||
* the number of evaluations exceeds {@code maxEval}.
|
||||
* @throws org.apache.commons.math3.exception.ConvergenceException
|
||||
* if the algorithm failed to converge.
|
||||
*/
|
||||
public double[] fit(int maxEval, double[] guess) {
|
||||
return fit(maxEval, new PolynomialFunction.Parametric(), guess);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the coefficients of the polynomial fitting the weighted data points.
|
||||
* The degree of the fitting polynomial is {@code guess.length - 1}.
|
||||
*
|
||||
* @param guess First guess for the coefficients. They must be sorted in
|
||||
* increasing order of the polynomial's degree.
|
||||
* @return the coefficients of the polynomial that best fits the observed points.
|
||||
* @throws org.apache.commons.math3.exception.ConvergenceException
|
||||
* if the algorithm failed to converge.
|
||||
*/
|
||||
public double[] fit(double[] guess) {
|
||||
return fit(new PolynomialFunction.Parametric(), guess);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* This class is a simple container for weighted observed point in
|
||||
* {@link CurveFitter curve fitting}.
|
||||
* <p>Instances of this class are guaranteed to be immutable.</p>
|
||||
* @version $Id: WeightedObservedPoint.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class WeightedObservedPoint implements Serializable {
|
||||
/** Serializable version id. */
|
||||
private static final long serialVersionUID = 5306874947404636157L;
|
||||
/** Weight of the measurement in the fitting process. */
|
||||
private final double weight;
|
||||
/** Abscissa of the point. */
|
||||
private final double x;
|
||||
/** Observed value of the function at x. */
|
||||
private final double y;
|
||||
|
||||
/**
|
||||
* Simple constructor.
|
||||
*
|
||||
* @param weight Weight of the measurement in the fitting process.
|
||||
* @param x Abscissa of the measurement.
|
||||
* @param y Ordinate of the measurement.
|
||||
*/
|
||||
public WeightedObservedPoint(final double weight, final double x, final double y) {
|
||||
this.weight = weight;
|
||||
this.x = x;
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the weight of the measurement in the fitting process.
|
||||
*
|
||||
* @return the weight of the measurement in the fitting process.
|
||||
*/
|
||||
public double getWeight() {
|
||||
return weight;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the abscissa of the point.
|
||||
*
|
||||
* @return the abscissa of the point.
|
||||
*/
|
||||
public double getX() {
|
||||
return x;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the observed value of the function at x.
|
||||
*
|
||||
* @return the observed value of the function at x.
|
||||
*/
|
||||
public double getY() {
|
||||
return y;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* Classes to perform curve fitting.
|
||||
*
|
||||
* Curve fitting is a special case of a least squares problem
|
||||
* were the parameters are the coefficients of a function {@code f}
|
||||
* whose graph {@code y = f(x)} should pass through sample points, and
|
||||
* were the objective function is the squared sum of the residuals
|
||||
* <code>f(x<sub>i</sub>) - y<sub>i</sub></code> for observed points
|
||||
* <code>(x<sub>i</sub>, y<sub>i</sub>)</code>.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
/**
|
||||
* Base class for all convergence checker implementations.
|
||||
*
|
||||
* @param <PAIR> Type of (point, value) pair.
|
||||
*
|
||||
* @version $Id: AbstractConvergenceChecker.java 1370215 2012-08-07 12:38:59Z sebb $
|
||||
* @since 3.0
|
||||
*/
|
||||
public abstract class AbstractConvergenceChecker<PAIR>
|
||||
implements ConvergenceChecker<PAIR> {
|
||||
/**
|
||||
* Relative tolerance threshold.
|
||||
*/
|
||||
private final double relativeThreshold;
|
||||
/**
|
||||
* Absolute tolerance threshold.
|
||||
*/
|
||||
private final double absoluteThreshold;
|
||||
|
||||
/**
|
||||
* Build an instance with a specified thresholds.
|
||||
*
|
||||
* @param relativeThreshold relative tolerance threshold
|
||||
* @param absoluteThreshold absolute tolerance threshold
|
||||
*/
|
||||
public AbstractConvergenceChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold) {
|
||||
this.relativeThreshold = relativeThreshold;
|
||||
this.absoluteThreshold = absoluteThreshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the relative threshold.
|
||||
*/
|
||||
public double getRelativeThreshold() {
|
||||
return relativeThreshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the absolute threshold.
|
||||
*/
|
||||
public double getAbsoluteThreshold() {
|
||||
return absoluteThreshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
public abstract boolean converged(int iteration,
|
||||
PAIR previous,
|
||||
PAIR current);
|
||||
}
|
|
@ -0,0 +1,214 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.random.RandomVectorGenerator;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
|
||||
/**
|
||||
* Base class multi-start optimizer for a multivariate function.
|
||||
* <br/>
|
||||
* This class wraps an optimizer in order to use it several times in
|
||||
* turn with different starting points (trying to avoid being trapped
|
||||
* in a local extremum when looking for a global one).
|
||||
* <em>It is not a "user" class.</em>
|
||||
*
|
||||
* @param <PAIR> Type of the point/value pair returned by the optimization
|
||||
* algorithm.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.0
|
||||
*/
|
||||
public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
|
||||
extends BaseMultivariateOptimizer<PAIR> {
|
||||
/** Underlying classical optimizer. */
|
||||
private final BaseMultivariateOptimizer<PAIR> optimizer;
|
||||
/** Number of evaluations already performed for all starts. */
|
||||
private int totalEvaluations;
|
||||
/** Number of starts to go. */
|
||||
private int starts;
|
||||
/** Random generator for multi-start. */
|
||||
private RandomVectorGenerator generator;
|
||||
/** Optimization data. */
|
||||
private OptimizationData[] optimData;
|
||||
/**
|
||||
* Location in {@link #optimData} where the updated maximum
|
||||
* number of evaluations will be stored.
|
||||
*/
|
||||
private int maxEvalIndex = -1;
|
||||
/**
|
||||
* Location in {@link #optimData} where the updated start value
|
||||
* will be stored.
|
||||
*/
|
||||
private int initialGuessIndex = -1;
|
||||
|
||||
/**
|
||||
* Create a multi-start optimizer from a single-start optimizer.
|
||||
*
|
||||
* @param optimizer Single-start optimizer to wrap.
|
||||
* @param starts Number of starts to perform. If {@code starts == 1},
|
||||
* the {@link #optimize(OptimizationData[]) optimize} will return the
|
||||
* same solution as the given {@code optimizer} would return.
|
||||
* @param generator Random vector generator to use for restarts.
|
||||
* @throws NullArgumentException if {@code optimizer} or {@code generator}
|
||||
* is {@code null}.
|
||||
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
||||
*/
|
||||
public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer,
|
||||
final int starts,
|
||||
final RandomVectorGenerator generator) {
|
||||
super(optimizer.getConvergenceChecker());
|
||||
|
||||
if (optimizer == null ||
|
||||
generator == null) {
|
||||
throw new NullArgumentException();
|
||||
}
|
||||
if (starts < 1) {
|
||||
throw new NotStrictlyPositiveException(starts);
|
||||
}
|
||||
|
||||
this.optimizer = optimizer;
|
||||
this.starts = starts;
|
||||
this.generator = generator;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int getEvaluations() {
|
||||
return totalEvaluations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all the optima found during the last call to {@code optimize}.
|
||||
* The optimizer stores all the optima found during a set of
|
||||
* restarts. The {@code optimize} method returns the best point only.
|
||||
* This method returns all the points found at the end of each starts,
|
||||
* including the best one already returned by the {@code optimize} method.
|
||||
* <br/>
|
||||
* The returned array as one element for each start as specified
|
||||
* in the constructor. It is ordered with the results from the
|
||||
* runs that did converge first, sorted from best to worst
|
||||
* objective value (i.e in ascending order if minimizing and in
|
||||
* descending order if maximizing), followed by {@code null} elements
|
||||
* corresponding to the runs that did not converge. This means all
|
||||
* elements will be {@code null} if the {@code optimize} method did throw
|
||||
* an exception.
|
||||
* This also means that if the first element is not {@code null}, it is
|
||||
* the best point found across all starts.
|
||||
* <br/>
|
||||
* The behaviour is undefined if this method is called before
|
||||
* {@code optimize}; it will likely throw {@code NullPointerException}.
|
||||
*
|
||||
* @return an array containing the optima sorted from best to worst.
|
||||
*/
|
||||
public abstract PAIR[] getOptima();
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @throws MathIllegalStateException if {@code optData} does not contain an
|
||||
* instance of {@link MaxEval} or {@link InitialGuess}.
|
||||
*/
|
||||
@Override
|
||||
public PAIR optimize(OptimizationData... optData) {
|
||||
// Store arguments in order to pass them to the internal optimizer.
|
||||
optimData = optData;
|
||||
// Set up base class and perform computations.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected PAIR doOptimize() {
|
||||
// Remove all instances of "MaxEval" and "InitialGuess" from the
|
||||
// array that will be passed to the internal optimizer.
|
||||
// The former is to enforce smaller numbers of allowed evaluations
|
||||
// (according to how many have been used up already), and the latter
|
||||
// to impose a different start value for each start.
|
||||
for (int i = 0; i < optimData.length; i++) {
|
||||
if (optimData[i] instanceof MaxEval) {
|
||||
optimData[i] = null;
|
||||
maxEvalIndex = i;
|
||||
}
|
||||
if (optimData[i] instanceof InitialGuess) {
|
||||
optimData[i] = null;
|
||||
initialGuessIndex = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (maxEvalIndex == -1) {
|
||||
throw new MathIllegalStateException();
|
||||
}
|
||||
if (initialGuessIndex == -1) {
|
||||
throw new MathIllegalStateException();
|
||||
}
|
||||
|
||||
RuntimeException lastException = null;
|
||||
totalEvaluations = 0;
|
||||
clear();
|
||||
|
||||
final int maxEval = getMaxEvaluations();
|
||||
final double[] min = getLowerBound();
|
||||
final double[] max = getUpperBound();
|
||||
final double[] startPoint = getStartPoint();
|
||||
|
||||
// Multi-start loop.
|
||||
for (int i = 0; i < starts; i++) {
|
||||
// CHECKSTYLE: stop IllegalCatch
|
||||
try {
|
||||
// Decrease number of allowed evaluations.
|
||||
optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
|
||||
// New start value.
|
||||
final double[] s = (i == 0) ?
|
||||
startPoint :
|
||||
generator.nextVector(); // XXX This does not enforce bounds!
|
||||
optimData[initialGuessIndex] = new InitialGuess(s);
|
||||
// Optimize.
|
||||
final PAIR result = optimizer.optimize(optimData);
|
||||
store(result);
|
||||
} catch (RuntimeException mue) {
|
||||
lastException = mue;
|
||||
}
|
||||
// CHECKSTYLE: resume IllegalCatch
|
||||
|
||||
totalEvaluations += optimizer.getEvaluations();
|
||||
}
|
||||
|
||||
final PAIR[] optima = getOptima();
|
||||
if (optima.length == 0) {
|
||||
// All runs failed.
|
||||
throw lastException; // Cannot be null if starts >= 1.
|
||||
}
|
||||
|
||||
// Return the best optimum.
|
||||
return optima[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Method that will be called in order to store each found optimum.
|
||||
*
|
||||
* @param optimum Result of an optimization run.
|
||||
*/
|
||||
protected abstract void store(PAIR optimum);
|
||||
/**
|
||||
* Method that will called in order to clear all stored optima.
|
||||
*/
|
||||
protected abstract void clear();
|
||||
}
|
|
@ -0,0 +1,157 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.SimpleBounds;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
|
||||
/**
|
||||
* Base class for implementing optimizers for multivariate functions.
|
||||
* It contains the boiler-plate code for initial guess and bounds
|
||||
* specifications.
|
||||
* <em>It is not a "user" class.</em>
|
||||
*
|
||||
* @param <PAIR> Type of the point/value pair returned by the optimization
|
||||
* algorithm.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class BaseMultivariateOptimizer<PAIR>
|
||||
extends BaseOptimizer<PAIR> {
|
||||
/** Initial guess. */
|
||||
private double[] start;
|
||||
/** Lower bounds. */
|
||||
private double[] lowerBound;
|
||||
/** Upper bounds. */
|
||||
private double[] upperBound;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected BaseMultivariateOptimizer(ConvergenceChecker<PAIR> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link MaxEval}</li>
|
||||
* <li>{@link InitialGuess}</li>
|
||||
* <li>{@link SimpleBounds}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public PAIR optimize(OptimizationData... optData) {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Check input consistency.
|
||||
checkParameters();
|
||||
// Perform optimization.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link InitialGuess}</li>
|
||||
* <li>{@link SimpleBounds}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof InitialGuess) {
|
||||
start = ((InitialGuess) data).getInitialGuess();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof SimpleBounds) {
|
||||
final SimpleBounds bounds = (SimpleBounds) data;
|
||||
lowerBound = bounds.getLower();
|
||||
upperBound = bounds.getUpper();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the initial guess.
|
||||
*
|
||||
* @return the initial guess, or {@code null} if not set.
|
||||
*/
|
||||
public double[] getStartPoint() {
|
||||
return start == null ? null : start.clone();
|
||||
}
|
||||
/**
|
||||
* @return the lower bounds, or {@code null} if not set.
|
||||
*/
|
||||
public double[] getLowerBound() {
|
||||
return lowerBound == null ? null : lowerBound.clone();
|
||||
}
|
||||
/**
|
||||
* @return the upper bounds, or {@code null} if not set.
|
||||
*/
|
||||
public double[] getUpperBound() {
|
||||
return upperBound == null ? null : upperBound.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check parameters consistency.
|
||||
*/
|
||||
private void checkParameters() {
|
||||
if (start != null) {
|
||||
final int dim = start.length;
|
||||
if (lowerBound != null) {
|
||||
if (lowerBound.length != dim) {
|
||||
throw new DimensionMismatchException(lowerBound.length, dim);
|
||||
}
|
||||
for (int i = 0; i < dim; i++) {
|
||||
final double v = start[i];
|
||||
final double lo = lowerBound[i];
|
||||
if (v < lo) {
|
||||
throw new NumberIsTooSmallException(v, lo, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (upperBound != null) {
|
||||
if (upperBound.length != dim) {
|
||||
throw new DimensionMismatchException(upperBound.length, dim);
|
||||
}
|
||||
for (int i = 0; i < dim; i++) {
|
||||
final double v = start[i];
|
||||
final double hi = upperBound[i];
|
||||
if (v > hi) {
|
||||
throw new NumberIsTooLargeException(v, hi, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,218 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.util.Incrementor;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.exception.TooManyIterationsException;
|
||||
|
||||
/**
|
||||
* Base class for implementing optimizers.
|
||||
* It contains the boiler-plate code for counting the number of evaluations
|
||||
* of the objective function and the number of iterations of the algorithm,
|
||||
* and storing the convergence checker.
|
||||
* <em>It is not a "user" class.</em>
|
||||
*
|
||||
* @param <PAIR> Type of the point/value pair returned by the optimization
|
||||
* algorithm.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class BaseOptimizer<PAIR> {
|
||||
/** Evaluations counter. */
|
||||
protected final Incrementor evaluations;
|
||||
/** Iterations counter. */
|
||||
protected final Incrementor iterations;
|
||||
/** Convergence checker. */
|
||||
private ConvergenceChecker<PAIR> checker;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected BaseOptimizer(ConvergenceChecker<PAIR> checker) {
|
||||
this.checker = checker;
|
||||
|
||||
evaluations = new Incrementor(0, new MaxEvalCallback());
|
||||
iterations = new Incrementor(0, new MaxIterCallback());
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the maximal number of function evaluations.
|
||||
*
|
||||
* @return the maximal number of function evaluations.
|
||||
*/
|
||||
public int getMaxEvaluations() {
|
||||
return evaluations.getMaximalCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of evaluations of the objective function.
|
||||
* The number of evaluations corresponds to the last call to the
|
||||
* {@code optimize} method. It is 0 if the method has not been
|
||||
* called yet.
|
||||
*
|
||||
* @return the number of evaluations of the objective function.
|
||||
*/
|
||||
public int getEvaluations() {
|
||||
return evaluations.getCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the maximal number of iterations.
|
||||
*
|
||||
* @return the maximal number of iterations.
|
||||
*/
|
||||
public int getMaxIterations() {
|
||||
return iterations.getMaximalCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of iterations performed by the algorithm.
|
||||
* The number iterations corresponds to the last call to the
|
||||
* {@code optimize} method. It is 0 if the method has not been
|
||||
* called yet.
|
||||
*
|
||||
* @return the number of evaluations of the objective function.
|
||||
*/
|
||||
public int getIterations() {
|
||||
return iterations.getCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the convergence checker.
|
||||
*
|
||||
* @return the object used to check for convergence.
|
||||
*/
|
||||
public ConvergenceChecker<PAIR> getConvergenceChecker() {
|
||||
return checker;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores data and performs the optimization.
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link MaxEval}</li>
|
||||
* <li>{@link MaxIter}</li>
|
||||
* </ul>
|
||||
* @return a point/value pair that satifies the convergence criteria.
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
* @throws TooManyIterationsException if the maximal number of
|
||||
* iterations is exceeded.
|
||||
*/
|
||||
public PAIR optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException,
|
||||
TooManyIterationsException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Reset counters.
|
||||
evaluations.resetCount();
|
||||
iterations.resetCount();
|
||||
// Perform optimization.
|
||||
return doOptimize();
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the bulk of the optimization algorithm.
|
||||
*
|
||||
* @return the point/value pair giving the optimal value of the
|
||||
* objective function.
|
||||
*/
|
||||
protected abstract PAIR doOptimize();
|
||||
|
||||
/**
|
||||
* Increment the evaluation count.
|
||||
*
|
||||
* @throws TooManyEvaluationsException if the allowed evaluations
|
||||
* have been exhausted.
|
||||
*/
|
||||
protected void incrementEvaluationCount()
|
||||
throws TooManyEvaluationsException {
|
||||
evaluations.incrementCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the iteration count.
|
||||
*
|
||||
* @throws TooManyIterationsException if the allowed iterations
|
||||
* have been exhausted.
|
||||
*/
|
||||
protected void incrementIterationCount()
|
||||
throws TooManyIterationsException {
|
||||
iterations.incrementCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link MaxEval}</li>
|
||||
* <li>{@link MaxIter}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof MaxEval) {
|
||||
evaluations.setMaximalCount(((MaxEval) data).getMaxEval());
|
||||
continue;
|
||||
}
|
||||
if (data instanceof MaxIter) {
|
||||
iterations.setMaximalCount(((MaxIter) data).getMaxIter());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines the action to perform when reaching the maximum number
|
||||
* of evaluations.
|
||||
*/
|
||||
private static class MaxEvalCallback
|
||||
implements Incrementor.MaxCountExceededCallback {
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException.
|
||||
*/
|
||||
public void trigger(int max) {
|
||||
throw new TooManyEvaluationsException(max);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines the action to perform when reaching the maximum number
|
||||
* of evaluations.
|
||||
*/
|
||||
private static class MaxIterCallback
|
||||
implements Incrementor.MaxCountExceededCallback {
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
* @throws TooManyIterationsException.
|
||||
*/
|
||||
public void trigger(int max) {
|
||||
throw new TooManyIterationsException(max);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
/**
|
||||
* This interface specifies how to check if an optimization algorithm has
|
||||
* converged.
|
||||
* <br/>
|
||||
* Deciding if convergence has been reached is a problem-dependent issue. The
|
||||
* user should provide a class implementing this interface to allow the
|
||||
* optimization algorithm to stop its search according to the problem at hand.
|
||||
* <br/>
|
||||
* For convenience, three implementations that fit simple needs are already
|
||||
* provided: {@link SimpleValueChecker}, {@link SimpleVectorValueChecker} and
|
||||
* {@link SimplePointChecker}. The first two consider that convergence is
|
||||
* reached when the objective function value does not change much anymore, it
|
||||
* does not use the point set at all.
|
||||
* The third one considers that convergence is reached when the input point
|
||||
* set does not change much anymore, it does not use objective function value
|
||||
* at all.
|
||||
*
|
||||
* @param <PAIR> Type of the (point, objective value) pair.
|
||||
*
|
||||
* @see org.apache.commons.math3.optim.SimplePointChecker
|
||||
* @see org.apache.commons.math3.optim.SimpleValueChecker
|
||||
* @see org.apache.commons.math3.optim.SimpleVectorValueChecker
|
||||
*
|
||||
* @version $Id: ConvergenceChecker.java 1364392 2012-07-22 18:27:12Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public interface ConvergenceChecker<PAIR> {
|
||||
/**
|
||||
* Check if the optimization algorithm has converged.
|
||||
*
|
||||
* @param iteration Current iteration.
|
||||
* @param previous Best point in the previous iteration.
|
||||
* @param current Best point in the current iteration.
|
||||
* @return {@code true} if the algorithm is considered to have converged.
|
||||
*/
|
||||
boolean converged(int iteration, PAIR previous, PAIR current);
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
/**
|
||||
* Goal type for an optimization problem (minimization or maximization of
|
||||
* a scalar function.
|
||||
*
|
||||
* @version $Id: GoalType.java 1364392 2012-07-22 18:27:12Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public enum GoalType implements OptimizationData {
|
||||
/** Maximization. */
|
||||
MAXIMIZE,
|
||||
/** Minimization. */
|
||||
MINIMIZE
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
/**
|
||||
* Starting point (first guess) of the optimization procedure.
|
||||
* <br/>
|
||||
* Immutable class.
|
||||
*
|
||||
* @version $Id: InitialGuess.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.1
|
||||
*/
|
||||
public class InitialGuess implements OptimizationData {
|
||||
/** Initial guess. */
|
||||
private final double[] init;
|
||||
|
||||
/**
|
||||
* @param startPoint Initial guess.
|
||||
*/
|
||||
public InitialGuess(double[] startPoint) {
|
||||
init = startPoint.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the initial guess.
|
||||
*
|
||||
* @return the initial guess.
|
||||
*/
|
||||
public double[] getInitialGuess() {
|
||||
return init.clone();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
|
||||
/**
|
||||
* Maximum number of evaluations of the function to be optimized.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class MaxEval implements OptimizationData {
|
||||
/** Allowed number of evalutations. */
|
||||
private final int maxEval;
|
||||
|
||||
/**
|
||||
* @param max Allowed number of evalutations.
|
||||
* @throws NotStrictlyPositiveException if {@code max <= 0}.
|
||||
*/
|
||||
public MaxEval(int max) {
|
||||
if (max <= 0) {
|
||||
throw new NotStrictlyPositiveException(max);
|
||||
}
|
||||
|
||||
maxEval = max;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the maximum number of evaluations.
|
||||
*
|
||||
* @return the allowed number of evaluations.
|
||||
*/
|
||||
public int getMaxEval() {
|
||||
return maxEval;
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory method that creates instance of this class that represents
|
||||
* a virtually unlimited number of evaluations.
|
||||
*
|
||||
* @return a new instance suitable for allowing {@link Integer#MAX_VALUE}
|
||||
* evaluations.
|
||||
*/
|
||||
public static MaxEval unlimited() {
|
||||
return new MaxEval(Integer.MAX_VALUE);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
|
||||
/**
|
||||
* Maximum number of iterations performed by an (iterative) algorithm.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class MaxIter implements OptimizationData {
|
||||
/** Allowed number of evalutations. */
|
||||
private final int maxIter;
|
||||
|
||||
/**
|
||||
* @param max Allowed number of iterations.
|
||||
* @throws NotStrictlyPositiveException if {@code max <= 0}.
|
||||
*/
|
||||
public MaxIter(int max) {
|
||||
if (max <= 0) {
|
||||
throw new NotStrictlyPositiveException(max);
|
||||
}
|
||||
|
||||
maxIter = max;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the maximum number of evaluations.
|
||||
*
|
||||
* @return the allowed number of evaluations.
|
||||
*/
|
||||
public int getMaxIter() {
|
||||
return maxIter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory method that creates instance of this class that represents
|
||||
* a virtually unlimited number of iterations.
|
||||
*
|
||||
* @return a new instance suitable for allowing {@link Integer#MAX_VALUE}
|
||||
* evaluations.
|
||||
*/
|
||||
public static MaxIter unlimited() {
|
||||
return new MaxIter(Integer.MAX_VALUE);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
|
||||
/**
|
||||
* Scalar function to be optimized.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class ObjectiveFunction implements OptimizationData {
|
||||
/** Function to be optimized. */
|
||||
private final MultivariateFunction function;
|
||||
|
||||
/**
|
||||
* @param f Function to be optimized.
|
||||
*/
|
||||
public ObjectiveFunction(MultivariateFunction f) {
|
||||
function = f;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the function to be optimized.
|
||||
*
|
||||
* @return the objective function.
|
||||
*/
|
||||
public MultivariateFunction getObjectiveFunction() {
|
||||
return function;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
/**
|
||||
* Marker interface.
|
||||
* Implementations will provide functionality (optional or required) needed
|
||||
* by the optimizers, and those will need to check the actual type of the
|
||||
* arguments and perform the appropriate cast in order to access the data
|
||||
* they need.
|
||||
*
|
||||
* @version $Id: OptimizationData.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.1
|
||||
*/
|
||||
public interface OptimizationData {}
|
|
@ -0,0 +1,122 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import java.io.Serializable;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* This class holds a point and the value of an objective function at
|
||||
* that point.
|
||||
*
|
||||
* @see PointVectorValuePair
|
||||
* @see org.apache.commons.math3.analysis.MultivariateFunction
|
||||
* @version $Id: PointValuePair.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class PointValuePair extends Pair<double[], Double> implements Serializable {
|
||||
/** Serializable UID. */
|
||||
private static final long serialVersionUID = 20120513L;
|
||||
|
||||
/**
|
||||
* Builds a point/objective function value pair.
|
||||
*
|
||||
* @param point Point coordinates. This instance will store
|
||||
* a copy of the array, not the array passed as argument.
|
||||
* @param value Value of the objective function at the point.
|
||||
*/
|
||||
public PointValuePair(final double[] point,
|
||||
final double value) {
|
||||
this(point, value, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a point/objective function value pair.
|
||||
*
|
||||
* @param point Point coordinates.
|
||||
* @param value Value of the objective function at the point.
|
||||
* @param copyArray if {@code true}, the input array will be copied,
|
||||
* otherwise it will be referenced.
|
||||
*/
|
||||
public PointValuePair(final double[] point,
|
||||
final double value,
|
||||
final boolean copyArray) {
|
||||
super(copyArray ? ((point == null) ? null :
|
||||
point.clone()) :
|
||||
point,
|
||||
value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the point.
|
||||
*
|
||||
* @return a copy of the stored point.
|
||||
*/
|
||||
public double[] getPoint() {
|
||||
final double[] p = getKey();
|
||||
return p == null ? null : p.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a reference to the point.
|
||||
*
|
||||
* @return a reference to the internal array storing the point.
|
||||
*/
|
||||
public double[] getPointRef() {
|
||||
return getKey();
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace the instance with a data transfer object for serialization.
|
||||
* @return data transfer object that will be serialized
|
||||
*/
|
||||
private Object writeReplace() {
|
||||
return new DataTransferObject(getKey(), getValue());
|
||||
}
|
||||
|
||||
/** Internal class used only for serialization. */
|
||||
private static class DataTransferObject implements Serializable {
|
||||
/** Serializable UID. */
|
||||
private static final long serialVersionUID = 20120513L;
|
||||
/**
|
||||
* Point coordinates.
|
||||
* @Serial
|
||||
*/
|
||||
private final double[] point;
|
||||
/**
|
||||
* Value of the objective function at the point.
|
||||
* @Serial
|
||||
*/
|
||||
private final double value;
|
||||
|
||||
/** Simple constructor.
|
||||
* @param point Point coordinates.
|
||||
* @param value Value of the objective function at the point.
|
||||
*/
|
||||
public DataTransferObject(final double[] point, final double value) {
|
||||
this.point = point.clone();
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/** Replace the deserialized data transfer object with a {@link PointValuePair}.
|
||||
* @return replacement {@link PointValuePair}
|
||||
*/
|
||||
private Object readResolve() {
|
||||
return new PointValuePair(point, value, false);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import java.io.Serializable;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* This class holds a point and the vectorial value of an objective function at
|
||||
* that point.
|
||||
*
|
||||
* @see PointValuePair
|
||||
* @see org.apache.commons.math3.analysis.MultivariateVectorFunction
|
||||
* @version $Id: PointVectorValuePair.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class PointVectorValuePair extends Pair<double[], double[]> implements Serializable {
|
||||
/** Serializable UID. */
|
||||
private static final long serialVersionUID = 20120513L;
|
||||
|
||||
/**
|
||||
* Builds a point/objective function value pair.
|
||||
*
|
||||
* @param point Point coordinates. This instance will store
|
||||
* a copy of the array, not the array passed as argument.
|
||||
* @param value Value of the objective function at the point.
|
||||
*/
|
||||
public PointVectorValuePair(final double[] point,
|
||||
final double[] value) {
|
||||
this(point, value, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a point/objective function value pair.
|
||||
*
|
||||
* @param point Point coordinates.
|
||||
* @param value Value of the objective function at the point.
|
||||
* @param copyArray if {@code true}, the input arrays will be copied,
|
||||
* otherwise they will be referenced.
|
||||
*/
|
||||
public PointVectorValuePair(final double[] point,
|
||||
final double[] value,
|
||||
final boolean copyArray) {
|
||||
super(copyArray ?
|
||||
((point == null) ? null :
|
||||
point.clone()) :
|
||||
point,
|
||||
copyArray ?
|
||||
((value == null) ? null :
|
||||
value.clone()) :
|
||||
value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the point.
|
||||
*
|
||||
* @return a copy of the stored point.
|
||||
*/
|
||||
public double[] getPoint() {
|
||||
final double[] p = getKey();
|
||||
return p == null ? null : p.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a reference to the point.
|
||||
*
|
||||
* @return a reference to the internal array storing the point.
|
||||
*/
|
||||
public double[] getPointRef() {
|
||||
return getKey();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value of the objective function.
|
||||
*
|
||||
* @return a copy of the stored value of the objective function.
|
||||
*/
|
||||
@Override
|
||||
public double[] getValue() {
|
||||
final double[] v = super.getValue();
|
||||
return v == null ? null : v.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a reference to the value of the objective function.
|
||||
*
|
||||
* @return a reference to the internal array storing the value of
|
||||
* the objective function.
|
||||
*/
|
||||
public double[] getValueRef() {
|
||||
return super.getValue();
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace the instance with a data transfer object for serialization.
|
||||
* @return data transfer object that will be serialized
|
||||
*/
|
||||
private Object writeReplace() {
|
||||
return new DataTransferObject(getKey(), getValue());
|
||||
}
|
||||
|
||||
/** Internal class used only for serialization. */
|
||||
private static class DataTransferObject implements Serializable {
|
||||
/** Serializable UID. */
|
||||
private static final long serialVersionUID = 20120513L;
|
||||
/**
|
||||
* Point coordinates.
|
||||
* @Serial
|
||||
*/
|
||||
private final double[] point;
|
||||
/**
|
||||
* Value of the objective function at the point.
|
||||
* @Serial
|
||||
*/
|
||||
private final double[] value;
|
||||
|
||||
/** Simple constructor.
|
||||
* @param point Point coordinates.
|
||||
* @param value Value of the objective function at the point.
|
||||
*/
|
||||
public DataTransferObject(final double[] point, final double[] value) {
|
||||
this.point = point.clone();
|
||||
this.value = value.clone();
|
||||
}
|
||||
|
||||
/** Replace the deserialized data transfer object with a {@link PointValuePair}.
|
||||
* @return replacement {@link PointValuePair}
|
||||
*/
|
||||
private Object readResolve() {
|
||||
return new PointVectorValuePair(point, value, false);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* Simple optimization constraints: lower and upper bounds.
|
||||
* The valid range of the parameters is an interval that can be infinite
|
||||
* (in one or both directions).
|
||||
* <br/>
|
||||
* Immutable class.
|
||||
*
|
||||
* @version $Id: SimpleBounds.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.1
|
||||
*/
|
||||
public class SimpleBounds implements OptimizationData {
|
||||
/** Lower bounds. */
|
||||
private final double[] lower;
|
||||
/** Upper bounds. */
|
||||
private final double[] upper;
|
||||
|
||||
/**
|
||||
* @param lB Lower bounds.
|
||||
* @param uB Upper bounds.
|
||||
*/
|
||||
public SimpleBounds(double[] lB,
|
||||
double[] uB) {
|
||||
lower = lB.clone();
|
||||
upper = uB.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the lower bounds.
|
||||
*
|
||||
* @return the lower bounds.
|
||||
*/
|
||||
public double[] getLower() {
|
||||
return lower.clone();
|
||||
}
|
||||
/**
|
||||
* Gets the upper bounds.
|
||||
*
|
||||
* @return the upper bounds.
|
||||
*/
|
||||
public double[] getUpper() {
|
||||
return upper.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory method that creates instance of this class that represents
|
||||
* unbounded ranges.
|
||||
*
|
||||
* @param dim Number of parameters.
|
||||
* @return a new instance suitable for passing to an optimizer that
|
||||
* requires bounds specification.
|
||||
*/
|
||||
public static SimpleBounds unbounded(int dim) {
|
||||
final double[] lB = new double[dim];
|
||||
Arrays.fill(lB, Double.NEGATIVE_INFINITY);
|
||||
final double[] uB = new double[dim];
|
||||
Arrays.fill(uB, Double.POSITIVE_INFINITY);
|
||||
|
||||
return new SimpleBounds(lB, uB);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
|
||||
/**
|
||||
* Simple implementation of the {@link ConvergenceChecker} interface using
|
||||
* only point coordinates.
|
||||
*
|
||||
* Convergence is considered to have been reached if either the relative
|
||||
* difference between each point coordinate are smaller than a threshold
|
||||
* or if either the absolute difference between the point coordinates are
|
||||
* smaller than another threshold.
|
||||
* <br/>
|
||||
* The {@link #converged(int,Pair,Pair) converged} method will also return
|
||||
* {@code true} if the number of iterations has been set (see
|
||||
* {@link #SimplePointChecker(double,double,int) this constructor}).
|
||||
*
|
||||
* @param <PAIR> Type of the (point, value) pair.
|
||||
* The type of the "value" part of the pair (not used by this class).
|
||||
*
|
||||
* @version $Id: SimplePointChecker.java 1413127 2012-11-24 04:37:30Z psteitz $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
|
||||
extends AbstractConvergenceChecker<PAIR> {
|
||||
/**
|
||||
* If {@link #maxIterationCount} is set to this value, the number of
|
||||
* iterations will never cause {@link #converged(int, Pair, Pair)}
|
||||
* to return {@code true}.
|
||||
*/
|
||||
private static final int ITERATION_CHECK_DISABLED = -1;
|
||||
/**
|
||||
* Number of iterations after which the
|
||||
* {@link #converged(int, Pair, Pair)} method
|
||||
* will return true (unless the check is disabled).
|
||||
*/
|
||||
private final int maxIterationCount;
|
||||
|
||||
/**
|
||||
* Build an instance with specified thresholds.
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold relative tolerance threshold
|
||||
* @param absoluteThreshold absolute tolerance threshold
|
||||
*/
|
||||
public SimplePointChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an instance with specified thresholds.
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold Relative tolerance threshold.
|
||||
* @param absoluteThreshold Absolute tolerance threshold.
|
||||
* @param maxIter Maximum iteration count.
|
||||
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
public SimplePointChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold,
|
||||
final int maxIter) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
|
||||
if (maxIter <= 0) {
|
||||
throw new NotStrictlyPositiveException(maxIter);
|
||||
}
|
||||
maxIterationCount = maxIter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the optimization algorithm has converged considering the
|
||||
* last two points.
|
||||
* This method may be called several times from the same algorithm
|
||||
* iteration with different points. This can be detected by checking the
|
||||
* iteration number at each call if needed. Each time this method is
|
||||
* called, the previous and current point correspond to points with the
|
||||
* same role at each iteration, so they can be compared. As an example,
|
||||
* simplex-based algorithms call this method for all points of the simplex,
|
||||
* not only for the best or worst ones.
|
||||
*
|
||||
* @param iteration Index of current iteration
|
||||
* @param previous Best point in the previous iteration.
|
||||
* @param current Best point in the current iteration.
|
||||
* @return {@code true} if the arguments satify the convergence criterion.
|
||||
*/
|
||||
@Override
|
||||
public boolean converged(final int iteration,
|
||||
final PAIR previous,
|
||||
final PAIR current) {
|
||||
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
|
||||
if (iteration >= maxIterationCount) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
final double[] p = previous.getKey();
|
||||
final double[] c = current.getKey();
|
||||
for (int i = 0; i < p.length; ++i) {
|
||||
final double pi = p[i];
|
||||
final double ci = c[i];
|
||||
final double difference = FastMath.abs(pi - ci);
|
||||
final double size = FastMath.max(FastMath.abs(pi), FastMath.abs(ci));
|
||||
if (difference > size * getRelativeThreshold() &&
|
||||
difference > getAbsoluteThreshold()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
|
||||
/**
|
||||
* Simple implementation of the {@link ConvergenceChecker} interface using
|
||||
* only objective function values.
|
||||
*
|
||||
* Convergence is considered to have been reached if either the relative
|
||||
* difference between the objective function values is smaller than a
|
||||
* threshold or if either the absolute difference between the objective
|
||||
* function values is smaller than another threshold.
|
||||
* <br/>
|
||||
* The {@link #converged(int,PointValuePair,PointValuePair) converged}
|
||||
* method will also return {@code true} if the number of iterations has been set
|
||||
* (see {@link #SimpleValueChecker(double,double,int) this constructor}).
|
||||
*
|
||||
* @version $Id: SimpleValueChecker.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class SimpleValueChecker
|
||||
extends AbstractConvergenceChecker<PointValuePair> {
|
||||
/**
|
||||
* If {@link #maxIterationCount} is set to this value, the number of
|
||||
* iterations will never cause
|
||||
* {@link #converged(int,PointValuePair,PointValuePair)}
|
||||
* to return {@code true}.
|
||||
*/
|
||||
private static final int ITERATION_CHECK_DISABLED = -1;
|
||||
/**
|
||||
* Number of iterations after which the
|
||||
* {@link #converged(int,PointValuePair,PointValuePair)} method
|
||||
* will return true (unless the check is disabled).
|
||||
*/
|
||||
private final int maxIterationCount;
|
||||
|
||||
/** Build an instance with specified thresholds.
|
||||
*
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold relative tolerance threshold
|
||||
* @param absoluteThreshold absolute tolerance threshold
|
||||
*/
|
||||
public SimpleValueChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an instance with specified thresholds.
|
||||
*
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold relative tolerance threshold
|
||||
* @param absoluteThreshold absolute tolerance threshold
|
||||
* @param maxIter Maximum iteration count.
|
||||
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
public SimpleValueChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold,
|
||||
final int maxIter) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
|
||||
if (maxIter <= 0) {
|
||||
throw new NotStrictlyPositiveException(maxIter);
|
||||
}
|
||||
maxIterationCount = maxIter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the optimization algorithm has converged considering the
|
||||
* last two points.
|
||||
* This method may be called several time from the same algorithm
|
||||
* iteration with different points. This can be detected by checking the
|
||||
* iteration number at each call if needed. Each time this method is
|
||||
* called, the previous and current point correspond to points with the
|
||||
* same role at each iteration, so they can be compared. As an example,
|
||||
* simplex-based algorithms call this method for all points of the simplex,
|
||||
* not only for the best or worst ones.
|
||||
*
|
||||
* @param iteration Index of current iteration
|
||||
* @param previous Best point in the previous iteration.
|
||||
* @param current Best point in the current iteration.
|
||||
* @return {@code true} if the algorithm has converged.
|
||||
*/
|
||||
@Override
|
||||
public boolean converged(final int iteration,
|
||||
final PointValuePair previous,
|
||||
final PointValuePair current) {
|
||||
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
|
||||
if (iteration >= maxIterationCount) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
final double p = previous.getValue();
|
||||
final double c = current.getValue();
|
||||
final double difference = FastMath.abs(p - c);
|
||||
final double size = FastMath.max(FastMath.abs(p), FastMath.abs(c));
|
||||
return difference <= size * getRelativeThreshold() ||
|
||||
difference <= getAbsoluteThreshold();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
|
||||
/**
|
||||
* Simple implementation of the {@link ConvergenceChecker} interface using
|
||||
* only objective function values.
|
||||
*
|
||||
* Convergence is considered to have been reached if either the relative
|
||||
* difference between the objective function values is smaller than a
|
||||
* threshold or if either the absolute difference between the objective
|
||||
* function values is smaller than another threshold for all vectors elements.
|
||||
* <br/>
|
||||
* The {@link #converged(int,PointVectorValuePair,PointVectorValuePair) converged}
|
||||
* method will also return {@code true} if the number of iterations has been set
|
||||
* (see {@link #SimpleVectorValueChecker(double,double,int) this constructor}).
|
||||
*
|
||||
* @version $Id: SimpleVectorValueChecker.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class SimpleVectorValueChecker
|
||||
extends AbstractConvergenceChecker<PointVectorValuePair> {
|
||||
/**
|
||||
* If {@link #maxIterationCount} is set to this value, the number of
|
||||
* iterations will never cause
|
||||
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)}
|
||||
* to return {@code true}.
|
||||
*/
|
||||
private static final int ITERATION_CHECK_DISABLED = -1;
|
||||
/**
|
||||
* Number of iterations after which the
|
||||
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)} method
|
||||
* will return true (unless the check is disabled).
|
||||
*/
|
||||
private final int maxIterationCount;
|
||||
|
||||
/**
|
||||
* Build an instance with specified thresholds.
|
||||
*
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold relative tolerance threshold
|
||||
* @param absoluteThreshold absolute tolerance threshold
|
||||
*/
|
||||
public SimpleVectorValueChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an instance with specified tolerance thresholds and
|
||||
* iteration count.
|
||||
*
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold Relative tolerance threshold.
|
||||
* @param absoluteThreshold Absolute tolerance threshold.
|
||||
* @param maxIter Maximum iteration count.
|
||||
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
public SimpleVectorValueChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold,
|
||||
final int maxIter) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
|
||||
if (maxIter <= 0) {
|
||||
throw new NotStrictlyPositiveException(maxIter);
|
||||
}
|
||||
maxIterationCount = maxIter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the optimization algorithm has converged considering the
|
||||
* last two points.
|
||||
* This method may be called several times from the same algorithm
|
||||
* iteration with different points. This can be detected by checking the
|
||||
* iteration number at each call if needed. Each time this method is
|
||||
* called, the previous and current point correspond to points with the
|
||||
* same role at each iteration, so they can be compared. As an example,
|
||||
* simplex-based algorithms call this method for all points of the simplex,
|
||||
* not only for the best or worst ones.
|
||||
*
|
||||
* @param iteration Index of current iteration
|
||||
* @param previous Best point in the previous iteration.
|
||||
* @param current Best point in the current iteration.
|
||||
* @return {@code true} if the arguments satify the convergence criterion.
|
||||
*/
|
||||
@Override
|
||||
public boolean converged(final int iteration,
|
||||
final PointVectorValuePair previous,
|
||||
final PointVectorValuePair current) {
|
||||
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
|
||||
if (iteration >= maxIterationCount) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
final double[] p = previous.getValueRef();
|
||||
final double[] c = current.getValueRef();
|
||||
for (int i = 0; i < p.length; ++i) {
|
||||
final double pi = p[i];
|
||||
final double ci = c[i];
|
||||
final double difference = FastMath.abs(pi - ci);
|
||||
final double size = FastMath.max(FastMath.abs(pi), FastMath.abs(ci));
|
||||
if (difference > size * getRelativeThreshold() &&
|
||||
difference > getAbsoluteThreshold()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,229 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.Serializable;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
|
||||
/**
|
||||
* A linear constraint for a linear optimization problem.
|
||||
* <p>
|
||||
* A linear constraint has one of the forms:
|
||||
* <ul>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> = v</li>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> <= v</li>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> >= v</li>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> =
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> <=
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> >=
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* </ul>
|
||||
* The c<sub>i</sub>, l<sub>i</sub> or r<sub>i</sub> are the coefficients of the constraints, the x<sub>i</sub>
|
||||
* are the coordinates of the current point and v is the value of the constraint.
|
||||
* </p>
|
||||
*
|
||||
* @version $Id: LinearConstraint.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class LinearConstraint implements Serializable {
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = -764632794033034092L;
|
||||
/** Coefficients of the constraint (left hand side). */
|
||||
private final transient RealVector coefficients;
|
||||
/** Relationship between left and right hand sides (=, <=, >=). */
|
||||
private final Relationship relationship;
|
||||
/** Value of the constraint (right hand side). */
|
||||
private final double value;
|
||||
|
||||
/**
|
||||
* Build a constraint involving a single linear equation.
|
||||
* <p>
|
||||
* A linear constraint with a single linear equation has one of the forms:
|
||||
* <ul>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> = v</li>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> <= v</li>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> >= v</li>
|
||||
* </ul>
|
||||
* </p>
|
||||
* @param coefficients The coefficients of the constraint (left hand side)
|
||||
* @param relationship The type of (in)equality used in the constraint
|
||||
* @param value The value of the constraint (right hand side)
|
||||
*/
|
||||
public LinearConstraint(final double[] coefficients,
|
||||
final Relationship relationship,
|
||||
final double value) {
|
||||
this(new ArrayRealVector(coefficients), relationship, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a constraint involving a single linear equation.
|
||||
* <p>
|
||||
* A linear constraint with a single linear equation has one of the forms:
|
||||
* <ul>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> = v</li>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> <= v</li>
|
||||
* <li>c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> >= v</li>
|
||||
* </ul>
|
||||
* </p>
|
||||
* @param coefficients The coefficients of the constraint (left hand side)
|
||||
* @param relationship The type of (in)equality used in the constraint
|
||||
* @param value The value of the constraint (right hand side)
|
||||
*/
|
||||
public LinearConstraint(final RealVector coefficients,
|
||||
final Relationship relationship,
|
||||
final double value) {
|
||||
this.coefficients = coefficients;
|
||||
this.relationship = relationship;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a constraint involving two linear equations.
|
||||
* <p>
|
||||
* A linear constraint with two linear equation has one of the forms:
|
||||
* <ul>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> =
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> <=
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> >=
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* </ul>
|
||||
* </p>
|
||||
* @param lhsCoefficients The coefficients of the linear expression on the left hand side of the constraint
|
||||
* @param lhsConstant The constant term of the linear expression on the left hand side of the constraint
|
||||
* @param relationship The type of (in)equality used in the constraint
|
||||
* @param rhsCoefficients The coefficients of the linear expression on the right hand side of the constraint
|
||||
* @param rhsConstant The constant term of the linear expression on the right hand side of the constraint
|
||||
*/
|
||||
public LinearConstraint(final double[] lhsCoefficients, final double lhsConstant,
|
||||
final Relationship relationship,
|
||||
final double[] rhsCoefficients, final double rhsConstant) {
|
||||
double[] sub = new double[lhsCoefficients.length];
|
||||
for (int i = 0; i < sub.length; ++i) {
|
||||
sub[i] = lhsCoefficients[i] - rhsCoefficients[i];
|
||||
}
|
||||
this.coefficients = new ArrayRealVector(sub, false);
|
||||
this.relationship = relationship;
|
||||
this.value = rhsConstant - lhsConstant;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a constraint involving two linear equations.
|
||||
* <p>
|
||||
* A linear constraint with two linear equation has one of the forms:
|
||||
* <ul>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> =
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> <=
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* <li>l<sub>1</sub>x<sub>1</sub> + ... l<sub>n</sub>x<sub>n</sub> + l<sub>cst</sub> >=
|
||||
* r<sub>1</sub>x<sub>1</sub> + ... r<sub>n</sub>x<sub>n</sub> + r<sub>cst</sub></li>
|
||||
* </ul>
|
||||
* </p>
|
||||
* @param lhsCoefficients The coefficients of the linear expression on the left hand side of the constraint
|
||||
* @param lhsConstant The constant term of the linear expression on the left hand side of the constraint
|
||||
* @param relationship The type of (in)equality used in the constraint
|
||||
* @param rhsCoefficients The coefficients of the linear expression on the right hand side of the constraint
|
||||
* @param rhsConstant The constant term of the linear expression on the right hand side of the constraint
|
||||
*/
|
||||
public LinearConstraint(final RealVector lhsCoefficients, final double lhsConstant,
|
||||
final Relationship relationship,
|
||||
final RealVector rhsCoefficients, final double rhsConstant) {
|
||||
this.coefficients = lhsCoefficients.subtract(rhsCoefficients);
|
||||
this.relationship = relationship;
|
||||
this.value = rhsConstant - lhsConstant;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the coefficients of the constraint (left hand side).
|
||||
*
|
||||
* @return the coefficients of the constraint (left hand side).
|
||||
*/
|
||||
public RealVector getCoefficients() {
|
||||
return coefficients;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the relationship between left and right hand sides.
|
||||
*
|
||||
* @return the relationship between left and right hand sides.
|
||||
*/
|
||||
public Relationship getRelationship() {
|
||||
return relationship;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value of the constraint (right hand side).
|
||||
*
|
||||
* @return the value of the constraint (right hand side).
|
||||
*/
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
if (this == other) {
|
||||
return true;
|
||||
}
|
||||
if (other instanceof LinearConstraint) {
|
||||
LinearConstraint rhs = (LinearConstraint) other;
|
||||
return relationship == rhs.relationship &&
|
||||
value == rhs.value &&
|
||||
coefficients.equals(rhs.coefficients);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return relationship.hashCode() ^
|
||||
Double.valueOf(value).hashCode() ^
|
||||
coefficients.hashCode();
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the instance.
|
||||
* @param oos stream where object should be written
|
||||
* @throws IOException if object cannot be written to stream
|
||||
*/
|
||||
private void writeObject(ObjectOutputStream oos)
|
||||
throws IOException {
|
||||
oos.defaultWriteObject();
|
||||
MatrixUtils.serializeRealVector(coefficients, oos);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deserialize the instance.
|
||||
* @param ois stream from which the object should be read
|
||||
* @throws ClassNotFoundException if a class in the stream cannot be found
|
||||
* @throws IOException if object cannot be read from the stream
|
||||
*/
|
||||
private void readObject(ObjectInputStream ois)
|
||||
throws ClassNotFoundException, IOException {
|
||||
ois.defaultReadObject();
|
||||
MatrixUtils.deserializeRealVector(this, "coefficients", ois);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* Class that represents a set of {@link LinearConstraint linear constraints}.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class LinearConstraintSet implements OptimizationData {
|
||||
/** Set of constraints. */
|
||||
private final Set<LinearConstraint> linearConstraints
|
||||
= new HashSet<LinearConstraint>();
|
||||
|
||||
/**
|
||||
* Creates a set containing the given constraints.
|
||||
*
|
||||
* @param constraints Constraints.
|
||||
*/
|
||||
public LinearConstraintSet(LinearConstraint... constraints) {
|
||||
for (LinearConstraint c : constraints) {
|
||||
linearConstraints.add(c);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a set containing the given constraints.
|
||||
*
|
||||
* @param constraints Constraints.
|
||||
*/
|
||||
public LinearConstraintSet(Collection<LinearConstraint> constraints) {
|
||||
linearConstraints.addAll(constraints);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the set of linear constraints.
|
||||
*
|
||||
* @return the constraints.
|
||||
*/
|
||||
public Collection<LinearConstraint> getConstraints() {
|
||||
return Collections.unmodifiableSet(linearConstraints);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.Serializable;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* An objective function for a linear optimization problem.
|
||||
* <p>
|
||||
* A linear objective function has one the form:
|
||||
* <pre>
|
||||
* c<sub>1</sub>x<sub>1</sub> + ... c<sub>n</sub>x<sub>n</sub> + d
|
||||
* </pre>
|
||||
* The c<sub>i</sub> and d are the coefficients of the equation,
|
||||
* the x<sub>i</sub> are the coordinates of the current point.
|
||||
* </p>
|
||||
*
|
||||
* @version $Id: LinearObjectiveFunction.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class LinearObjectiveFunction
|
||||
implements MultivariateFunction,
|
||||
OptimizationData,
|
||||
Serializable {
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = -4531815507568396090L;
|
||||
/** Coefficients of the linear equation (c<sub>i</sub>). */
|
||||
private final transient RealVector coefficients;
|
||||
/** Constant term of the linear equation. */
|
||||
private final double constantTerm;
|
||||
|
||||
/**
|
||||
* @param coefficients Coefficients for the linear equation being optimized.
|
||||
* @param constantTerm Constant term of the linear equation.
|
||||
*/
|
||||
public LinearObjectiveFunction(double[] coefficients, double constantTerm) {
|
||||
this(new ArrayRealVector(coefficients), constantTerm);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param coefficients Coefficients for the linear equation being optimized.
|
||||
* @param constantTerm Constant term of the linear equation.
|
||||
*/
|
||||
public LinearObjectiveFunction(RealVector coefficients, double constantTerm) {
|
||||
this.coefficients = coefficients;
|
||||
this.constantTerm = constantTerm;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the coefficients of the linear equation being optimized.
|
||||
*
|
||||
* @return coefficients of the linear equation being optimized.
|
||||
*/
|
||||
public RealVector getCoefficients() {
|
||||
return coefficients;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the constant of the linear equation being optimized.
|
||||
*
|
||||
* @return constant of the linear equation being optimized.
|
||||
*/
|
||||
public double getConstantTerm() {
|
||||
return constantTerm;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the value of the linear equation at the current point.
|
||||
*
|
||||
* @param point Point at which linear equation must be evaluated.
|
||||
* @return the value of the linear equation at the current point.
|
||||
*/
|
||||
public double value(final double[] point) {
|
||||
return value(new ArrayRealVector(point, false));
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the value of the linear equation at the current point.
|
||||
*
|
||||
* @param point Point at which linear equation must be evaluated.
|
||||
* @return the value of the linear equation at the current point.
|
||||
*/
|
||||
public double value(final RealVector point) {
|
||||
return coefficients.dotProduct(point) + constantTerm;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
if (this == other) {
|
||||
return true;
|
||||
}
|
||||
if (other instanceof LinearObjectiveFunction) {
|
||||
LinearObjectiveFunction rhs = (LinearObjectiveFunction) other;
|
||||
return (constantTerm == rhs.constantTerm) && coefficients.equals(rhs.coefficients);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Double.valueOf(constantTerm).hashCode() ^ coefficients.hashCode();
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the instance.
|
||||
* @param oos stream where object should be written
|
||||
* @throws IOException if object cannot be written to stream
|
||||
*/
|
||||
private void writeObject(ObjectOutputStream oos)
|
||||
throws IOException {
|
||||
oos.defaultWriteObject();
|
||||
MatrixUtils.serializeRealVector(coefficients, oos);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deserialize the instance.
|
||||
* @param ois stream from which the object should be read
|
||||
* @throws ClassNotFoundException if a class in the stream cannot be found
|
||||
* @throws IOException if object cannot be read from the stream
|
||||
*/
|
||||
private void readObject(ObjectInputStream ois)
|
||||
throws ClassNotFoundException, IOException {
|
||||
ois.defaultReadObject();
|
||||
MatrixUtils.deserializeRealVector(this, "coefficients", ois);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import org.apache.commons.math3.exception.TooManyIterationsException;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
|
||||
|
||||
/**
|
||||
* Base class for implementing linear optimizers.
|
||||
*
|
||||
* @version $Id: AbstractLinearOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class LinearOptimizer
|
||||
extends MultivariateOptimizer {
|
||||
/**
|
||||
* Linear objective function.
|
||||
*/
|
||||
private LinearObjectiveFunction function;
|
||||
/**
|
||||
* Linear constraints.
|
||||
*/
|
||||
private Collection<LinearConstraint> linearConstraints;
|
||||
/**
|
||||
* Whether to restrict the variables to non-negative values.
|
||||
*/
|
||||
private boolean nonNegative;
|
||||
|
||||
/**
|
||||
* Simple constructor with default settings.
|
||||
*
|
||||
*/
|
||||
protected LinearOptimizer() {
|
||||
super(null); // No convergence checker.
|
||||
}
|
||||
|
||||
/**
|
||||
* @return {@code true} if the variables are restricted to non-negative values.
|
||||
*/
|
||||
protected boolean isRestrictedToNonNegative() {
|
||||
return nonNegative;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the optimization type.
|
||||
*/
|
||||
protected LinearObjectiveFunction getFunction() {
|
||||
return function;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the optimization type.
|
||||
*/
|
||||
protected Collection<LinearConstraint> getConstraints() {
|
||||
return Collections.unmodifiableCollection(linearConstraints);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxIter}</li>
|
||||
* <li>{@link LinearObjectiveFunction}</li>
|
||||
* <li>{@link LinearConstraintSet}</li>
|
||||
* <li>{@link NonNegativeConstraint}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyIterationsException if the maximal number of
|
||||
* iterations is exceeded.
|
||||
*/
|
||||
@Override
|
||||
public PointValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyIterationsException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link LinearObjectiveFunction}</li>
|
||||
* <li>{@link LinearConstraintSet}</li>
|
||||
* <li>{@link NonNegativeConstraint}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof LinearObjectiveFunction) {
|
||||
function = (LinearObjectiveFunction) data;
|
||||
continue;
|
||||
}
|
||||
if (data instanceof LinearConstraintSet) {
|
||||
linearConstraints = ((LinearConstraintSet) data).getConstraints();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof NonNegativeConstraint) {
|
||||
nonNegative = ((NonNegativeConstraint) data).isRestrictedToNonNegative();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
|
||||
/**
|
||||
* This class represents exceptions thrown by optimizers when no solution fulfills the constraints.
|
||||
*
|
||||
* @version $Id: NoFeasibleSolutionException.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class NoFeasibleSolutionException extends MathIllegalStateException {
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = -3044253632189082760L;
|
||||
|
||||
/**
|
||||
* Simple constructor using a default message.
|
||||
*/
|
||||
public NoFeasibleSolutionException() {
|
||||
super(LocalizedFormats.NO_FEASIBLE_SOLUTION);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* A constraint for a linear optimization problem indicating whether all
|
||||
* variables must be restricted to non-negative values.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class NonNegativeConstraint implements OptimizationData {
|
||||
/** Whether the variables are all positive. */
|
||||
private final boolean isRestricted;
|
||||
|
||||
/**
|
||||
* @param restricted If {@code true}, all the variables must be positive.
|
||||
*/
|
||||
public NonNegativeConstraint(boolean restricted) {
|
||||
isRestricted = restricted;
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates whether all the variables must be restricted to non-negative
|
||||
* values.
|
||||
*
|
||||
* @return {@code true} if all the variables must be positive.
|
||||
*/
|
||||
public boolean isRestrictedToNonNegative() {
|
||||
return isRestricted;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
/**
|
||||
* Types of relationships between two cells in a Solver {@link LinearConstraint}.
|
||||
*
|
||||
* @version $Id: Relationship.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public enum Relationship {
|
||||
/** Equality relationship. */
|
||||
EQ("="),
|
||||
/** Lesser than or equal relationship. */
|
||||
LEQ("<="),
|
||||
/** Greater than or equal relationship. */
|
||||
GEQ(">=");
|
||||
|
||||
/** Display string for the relationship. */
|
||||
private final String stringValue;
|
||||
|
||||
/**
|
||||
* Simple constructor.
|
||||
*
|
||||
* @param stringValue Display string for the relationship.
|
||||
*/
|
||||
private Relationship(String stringValue) {
|
||||
this.stringValue = stringValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return stringValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the relationship obtained when multiplying all coefficients by -1.
|
||||
*
|
||||
* @return the opposite relationship.
|
||||
*/
|
||||
public Relationship oppositeRelationship() {
|
||||
switch (this) {
|
||||
case LEQ :
|
||||
return GEQ;
|
||||
case GEQ :
|
||||
return LEQ;
|
||||
default :
|
||||
return EQ;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,245 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.math3.exception.TooManyIterationsException;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
|
||||
/**
|
||||
* Solves a linear problem using the "Two-Phase Simplex" method.
|
||||
*
|
||||
* @version $Id: SimplexSolver.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class SimplexSolver extends LinearOptimizer {
|
||||
/** Default amount of error to accept for algorithm convergence. */
|
||||
private static final double DEFAULT_EPSILON = 1.0e-6;
|
||||
|
||||
/** Default amount of error to accept in floating point comparisons (as ulps). */
|
||||
private static final int DEFAULT_ULPS = 10;
|
||||
|
||||
/** Amount of error to accept for algorithm convergence. */
|
||||
private final double epsilon;
|
||||
|
||||
/** Amount of error to accept in floating point comparisons (as ulps). */
|
||||
private final int maxUlps;
|
||||
|
||||
/**
|
||||
* Builds a simplex solver with default settings.
|
||||
*/
|
||||
public SimplexSolver() {
|
||||
this(DEFAULT_EPSILON, DEFAULT_ULPS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a simplex solver with a specified accepted amount of error.
|
||||
*
|
||||
* @param epsilon Amount of error to accept for algorithm convergence.
|
||||
* @param maxUlps Amount of error to accept in floating point comparisons.
|
||||
*/
|
||||
public SimplexSolver(final double epsilon,
|
||||
final int maxUlps) {
|
||||
this.epsilon = epsilon;
|
||||
this.maxUlps = maxUlps;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the column with the most negative coefficient in the objective function row.
|
||||
*
|
||||
* @param tableau Simple tableau for the problem.
|
||||
* @return the column with the most negative coefficient.
|
||||
*/
|
||||
private Integer getPivotColumn(SimplexTableau tableau) {
|
||||
double minValue = 0;
|
||||
Integer minPos = null;
|
||||
for (int i = tableau.getNumObjectiveFunctions(); i < tableau.getWidth() - 1; i++) {
|
||||
final double entry = tableau.getEntry(0, i);
|
||||
// check if the entry is strictly smaller than the current minimum
|
||||
// do not use a ulp/epsilon check
|
||||
if (entry < minValue) {
|
||||
minValue = entry;
|
||||
minPos = i;
|
||||
}
|
||||
}
|
||||
return minPos;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the row with the minimum ratio as given by the minimum ratio test (MRT).
|
||||
*
|
||||
* @param tableau Simple tableau for the problem.
|
||||
* @param col Column to test the ratio of (see {@link #getPivotColumn(SimplexTableau)}).
|
||||
* @return the row with the minimum ratio.
|
||||
*/
|
||||
private Integer getPivotRow(SimplexTableau tableau, final int col) {
|
||||
// create a list of all the rows that tie for the lowest score in the minimum ratio test
|
||||
List<Integer> minRatioPositions = new ArrayList<Integer>();
|
||||
double minRatio = Double.MAX_VALUE;
|
||||
for (int i = tableau.getNumObjectiveFunctions(); i < tableau.getHeight(); i++) {
|
||||
final double rhs = tableau.getEntry(i, tableau.getWidth() - 1);
|
||||
final double entry = tableau.getEntry(i, col);
|
||||
|
||||
if (Precision.compareTo(entry, 0d, maxUlps) > 0) {
|
||||
final double ratio = rhs / entry;
|
||||
// check if the entry is strictly equal to the current min ratio
|
||||
// do not use a ulp/epsilon check
|
||||
final int cmp = Double.compare(ratio, minRatio);
|
||||
if (cmp == 0) {
|
||||
minRatioPositions.add(i);
|
||||
} else if (cmp < 0) {
|
||||
minRatio = ratio;
|
||||
minRatioPositions = new ArrayList<Integer>();
|
||||
minRatioPositions.add(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (minRatioPositions.size() == 0) {
|
||||
return null;
|
||||
} else if (minRatioPositions.size() > 1) {
|
||||
// there's a degeneracy as indicated by a tie in the minimum ratio test
|
||||
|
||||
// 1. check if there's an artificial variable that can be forced out of the basis
|
||||
if (tableau.getNumArtificialVariables() > 0) {
|
||||
for (Integer row : minRatioPositions) {
|
||||
for (int i = 0; i < tableau.getNumArtificialVariables(); i++) {
|
||||
int column = i + tableau.getArtificialVariableOffset();
|
||||
final double entry = tableau.getEntry(row, column);
|
||||
if (Precision.equals(entry, 1d, maxUlps) && row.equals(tableau.getBasicRow(column))) {
|
||||
return row;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. apply Bland's rule to prevent cycling:
|
||||
// take the row for which the corresponding basic variable has the smallest index
|
||||
//
|
||||
// see http://www.stanford.edu/class/msande310/blandrule.pdf
|
||||
// see http://en.wikipedia.org/wiki/Bland%27s_rule (not equivalent to the above paper)
|
||||
//
|
||||
// Additional heuristic: if we did not get a solution after half of maxIterations
|
||||
// revert to the simple case of just returning the top-most row
|
||||
// This heuristic is based on empirical data gathered while investigating MATH-828.
|
||||
if (getEvaluations() < getMaxEvaluations() / 2) {
|
||||
Integer minRow = null;
|
||||
int minIndex = tableau.getWidth();
|
||||
final int varStart = tableau.getNumObjectiveFunctions();
|
||||
final int varEnd = tableau.getWidth() - 1;
|
||||
for (Integer row : minRatioPositions) {
|
||||
for (int i = varStart; i < varEnd && !row.equals(minRow); i++) {
|
||||
final Integer basicRow = tableau.getBasicRow(i);
|
||||
if (basicRow != null && basicRow.equals(row)) {
|
||||
if (i < minIndex) {
|
||||
minIndex = i;
|
||||
minRow = row;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return minRow;
|
||||
}
|
||||
}
|
||||
return minRatioPositions.get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs one iteration of the Simplex method on the given model.
|
||||
*
|
||||
* @param tableau Simple tableau for the problem.
|
||||
* @throws TooManyIterationsException if the allowed number of iterations has been exhausted.
|
||||
* @throws UnboundedSolutionException if the model is found not to have a bounded solution.
|
||||
*/
|
||||
protected void doIteration(final SimplexTableau tableau)
|
||||
throws TooManyIterationsException,
|
||||
UnboundedSolutionException {
|
||||
|
||||
incrementIterationCount();
|
||||
|
||||
Integer pivotCol = getPivotColumn(tableau);
|
||||
Integer pivotRow = getPivotRow(tableau, pivotCol);
|
||||
if (pivotRow == null) {
|
||||
throw new UnboundedSolutionException();
|
||||
}
|
||||
|
||||
// set the pivot element to 1
|
||||
double pivotVal = tableau.getEntry(pivotRow, pivotCol);
|
||||
tableau.divideRow(pivotRow, pivotVal);
|
||||
|
||||
// set the rest of the pivot column to 0
|
||||
for (int i = 0; i < tableau.getHeight(); i++) {
|
||||
if (i != pivotRow) {
|
||||
final double multiplier = tableau.getEntry(i, pivotCol);
|
||||
tableau.subtractRow(i, pivotRow, multiplier);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves Phase 1 of the Simplex method.
|
||||
*
|
||||
* @param tableau Simple tableau for the problem.
|
||||
* @throws TooManyIterationsException if the allowed number of iterations has been exhausted.
|
||||
* @throws UnboundedSolutionException if the model is found not to have a bounded solution.
|
||||
* @throws NoFeasibleSolutionException if there is no feasible solution?
|
||||
*/
|
||||
protected void solvePhase1(final SimplexTableau tableau)
|
||||
throws TooManyIterationsException,
|
||||
UnboundedSolutionException,
|
||||
NoFeasibleSolutionException {
|
||||
|
||||
// make sure we're in Phase 1
|
||||
if (tableau.getNumArtificialVariables() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
while (!tableau.isOptimal()) {
|
||||
doIteration(tableau);
|
||||
}
|
||||
|
||||
// if W is not zero then we have no feasible solution
|
||||
if (!Precision.equals(tableau.getEntry(0, tableau.getRhsOffset()), 0d, epsilon)) {
|
||||
throw new NoFeasibleSolutionException();
|
||||
}
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public PointValuePair doOptimize()
|
||||
throws TooManyIterationsException,
|
||||
UnboundedSolutionException,
|
||||
NoFeasibleSolutionException {
|
||||
final SimplexTableau tableau =
|
||||
new SimplexTableau(getFunction(),
|
||||
getConstraints(),
|
||||
getGoalType(),
|
||||
isRestrictedToNonNegative(),
|
||||
epsilon,
|
||||
maxUlps);
|
||||
|
||||
solvePhase1(tableau);
|
||||
tableau.dropPhase1Objective();
|
||||
|
||||
while (!tableau.isOptimal()) {
|
||||
doIteration(tableau);
|
||||
}
|
||||
return tableau.getSolution();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,637 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
|
||||
/**
|
||||
* A tableau for use in the Simplex method.
|
||||
*
|
||||
* <p>
|
||||
* Example:
|
||||
* <pre>
|
||||
* W | Z | x1 | x2 | x- | s1 | s2 | a1 | RHS
|
||||
* ---------------------------------------------------
|
||||
* -1 0 0 0 0 0 0 1 0 <= phase 1 objective
|
||||
* 0 1 -15 -10 0 0 0 0 0 <= phase 2 objective
|
||||
* 0 0 1 0 0 1 0 0 2 <= constraint 1
|
||||
* 0 0 0 1 0 0 1 0 3 <= constraint 2
|
||||
* 0 0 1 1 0 0 0 1 4 <= constraint 3
|
||||
* </pre>
|
||||
* W: Phase 1 objective function</br>
|
||||
* Z: Phase 2 objective function</br>
|
||||
* x1 & x2: Decision variables</br>
|
||||
* x-: Extra decision variable to allow for negative values</br>
|
||||
* s1 & s2: Slack/Surplus variables</br>
|
||||
* a1: Artificial variable</br>
|
||||
* RHS: Right hand side</br>
|
||||
* </p>
|
||||
* @version $Id: SimplexTableau.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
class SimplexTableau implements Serializable {
|
||||
|
||||
/** Column label for negative vars. */
|
||||
private static final String NEGATIVE_VAR_COLUMN_LABEL = "x-";
|
||||
|
||||
/** Default amount of error to accept in floating point comparisons (as ulps). */
|
||||
private static final int DEFAULT_ULPS = 10;
|
||||
|
||||
/** The cut-off threshold to zero-out entries. */
|
||||
private static final double CUTOFF_THRESHOLD = 1e-12;
|
||||
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = -1369660067587938365L;
|
||||
|
||||
/** Linear objective function. */
|
||||
private final LinearObjectiveFunction f;
|
||||
|
||||
/** Linear constraints. */
|
||||
private final List<LinearConstraint> constraints;
|
||||
|
||||
/** Whether to restrict the variables to non-negative values. */
|
||||
private final boolean restrictToNonNegative;
|
||||
|
||||
/** The variables each column represents */
|
||||
private final List<String> columnLabels = new ArrayList<String>();
|
||||
|
||||
/** Simple tableau. */
|
||||
private transient RealMatrix tableau;
|
||||
|
||||
/** Number of decision variables. */
|
||||
private final int numDecisionVariables;
|
||||
|
||||
/** Number of slack variables. */
|
||||
private final int numSlackVariables;
|
||||
|
||||
/** Number of artificial variables. */
|
||||
private int numArtificialVariables;
|
||||
|
||||
/** Amount of error to accept when checking for optimality. */
|
||||
private final double epsilon;
|
||||
|
||||
/** Amount of error to accept in floating point comparisons. */
|
||||
private final int maxUlps;
|
||||
|
||||
/**
|
||||
* Builds a tableau for a linear problem.
|
||||
*
|
||||
* @param f Linear objective function.
|
||||
* @param constraints Linear constraints.
|
||||
* @param goalType Optimization goal: either {@link GoalType#MAXIMIZE}
|
||||
* or {@link GoalType#MINIMIZE}.
|
||||
* @param restrictToNonNegative Whether to restrict the variables to non-negative values.
|
||||
* @param epsilon Amount of error to accept when checking for optimality.
|
||||
*/
|
||||
SimplexTableau(final LinearObjectiveFunction f,
|
||||
final Collection<LinearConstraint> constraints,
|
||||
final GoalType goalType,
|
||||
final boolean restrictToNonNegative,
|
||||
final double epsilon) {
|
||||
this(f, constraints, goalType, restrictToNonNegative, epsilon, DEFAULT_ULPS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a tableau for a linear problem.
|
||||
* @param f linear objective function
|
||||
* @param constraints linear constraints
|
||||
* @param goalType type of optimization goal: either {@link GoalType#MAXIMIZE} or {@link GoalType#MINIMIZE}
|
||||
* @param restrictToNonNegative whether to restrict the variables to non-negative values
|
||||
* @param epsilon amount of error to accept when checking for optimality
|
||||
* @param maxUlps amount of error to accept in floating point comparisons
|
||||
*/
|
||||
SimplexTableau(final LinearObjectiveFunction f,
|
||||
final Collection<LinearConstraint> constraints,
|
||||
final GoalType goalType,
|
||||
final boolean restrictToNonNegative,
|
||||
final double epsilon,
|
||||
final int maxUlps) {
|
||||
this.f = f;
|
||||
this.constraints = normalizeConstraints(constraints);
|
||||
this.restrictToNonNegative = restrictToNonNegative;
|
||||
this.epsilon = epsilon;
|
||||
this.maxUlps = maxUlps;
|
||||
this.numDecisionVariables = f.getCoefficients().getDimension() +
|
||||
(restrictToNonNegative ? 0 : 1);
|
||||
this.numSlackVariables = getConstraintTypeCounts(Relationship.LEQ) +
|
||||
getConstraintTypeCounts(Relationship.GEQ);
|
||||
this.numArtificialVariables = getConstraintTypeCounts(Relationship.EQ) +
|
||||
getConstraintTypeCounts(Relationship.GEQ);
|
||||
this.tableau = createTableau(goalType == GoalType.MAXIMIZE);
|
||||
initializeColumnLabels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the labels for the columns.
|
||||
*/
|
||||
protected void initializeColumnLabels() {
|
||||
if (getNumObjectiveFunctions() == 2) {
|
||||
columnLabels.add("W");
|
||||
}
|
||||
columnLabels.add("Z");
|
||||
for (int i = 0; i < getOriginalNumDecisionVariables(); i++) {
|
||||
columnLabels.add("x" + i);
|
||||
}
|
||||
if (!restrictToNonNegative) {
|
||||
columnLabels.add(NEGATIVE_VAR_COLUMN_LABEL);
|
||||
}
|
||||
for (int i = 0; i < getNumSlackVariables(); i++) {
|
||||
columnLabels.add("s" + i);
|
||||
}
|
||||
for (int i = 0; i < getNumArtificialVariables(); i++) {
|
||||
columnLabels.add("a" + i);
|
||||
}
|
||||
columnLabels.add("RHS");
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the tableau by itself.
|
||||
* @param maximize if true, goal is to maximize the objective function
|
||||
* @return created tableau
|
||||
*/
|
||||
protected RealMatrix createTableau(final boolean maximize) {
|
||||
|
||||
// create a matrix of the correct size
|
||||
int width = numDecisionVariables + numSlackVariables +
|
||||
numArtificialVariables + getNumObjectiveFunctions() + 1; // + 1 is for RHS
|
||||
int height = constraints.size() + getNumObjectiveFunctions();
|
||||
Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(height, width);
|
||||
|
||||
// initialize the objective function rows
|
||||
if (getNumObjectiveFunctions() == 2) {
|
||||
matrix.setEntry(0, 0, -1);
|
||||
}
|
||||
int zIndex = (getNumObjectiveFunctions() == 1) ? 0 : 1;
|
||||
matrix.setEntry(zIndex, zIndex, maximize ? 1 : -1);
|
||||
RealVector objectiveCoefficients =
|
||||
maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
|
||||
copyArray(objectiveCoefficients.toArray(), matrix.getDataRef()[zIndex]);
|
||||
matrix.setEntry(zIndex, width - 1,
|
||||
maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
|
||||
|
||||
if (!restrictToNonNegative) {
|
||||
matrix.setEntry(zIndex, getSlackVariableOffset() - 1,
|
||||
getInvertedCoefficientSum(objectiveCoefficients));
|
||||
}
|
||||
|
||||
// initialize the constraint rows
|
||||
int slackVar = 0;
|
||||
int artificialVar = 0;
|
||||
for (int i = 0; i < constraints.size(); i++) {
|
||||
LinearConstraint constraint = constraints.get(i);
|
||||
int row = getNumObjectiveFunctions() + i;
|
||||
|
||||
// decision variable coefficients
|
||||
copyArray(constraint.getCoefficients().toArray(), matrix.getDataRef()[row]);
|
||||
|
||||
// x-
|
||||
if (!restrictToNonNegative) {
|
||||
matrix.setEntry(row, getSlackVariableOffset() - 1,
|
||||
getInvertedCoefficientSum(constraint.getCoefficients()));
|
||||
}
|
||||
|
||||
// RHS
|
||||
matrix.setEntry(row, width - 1, constraint.getValue());
|
||||
|
||||
// slack variables
|
||||
if (constraint.getRelationship() == Relationship.LEQ) {
|
||||
matrix.setEntry(row, getSlackVariableOffset() + slackVar++, 1); // slack
|
||||
} else if (constraint.getRelationship() == Relationship.GEQ) {
|
||||
matrix.setEntry(row, getSlackVariableOffset() + slackVar++, -1); // excess
|
||||
}
|
||||
|
||||
// artificial variables
|
||||
if ((constraint.getRelationship() == Relationship.EQ) ||
|
||||
(constraint.getRelationship() == Relationship.GEQ)) {
|
||||
matrix.setEntry(0, getArtificialVariableOffset() + artificialVar, 1);
|
||||
matrix.setEntry(row, getArtificialVariableOffset() + artificialVar++, 1);
|
||||
matrix.setRowVector(0, matrix.getRowVector(0).subtract(matrix.getRowVector(row)));
|
||||
}
|
||||
}
|
||||
|
||||
return matrix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get new versions of the constraints which have positive right hand sides.
|
||||
* @param originalConstraints original (not normalized) constraints
|
||||
* @return new versions of the constraints
|
||||
*/
|
||||
public List<LinearConstraint> normalizeConstraints(Collection<LinearConstraint> originalConstraints) {
|
||||
List<LinearConstraint> normalized = new ArrayList<LinearConstraint>();
|
||||
for (LinearConstraint constraint : originalConstraints) {
|
||||
normalized.add(normalize(constraint));
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a new equation equivalent to this one with a positive right hand side.
|
||||
* @param constraint reference constraint
|
||||
* @return new equation
|
||||
*/
|
||||
private LinearConstraint normalize(final LinearConstraint constraint) {
|
||||
if (constraint.getValue() < 0) {
|
||||
return new LinearConstraint(constraint.getCoefficients().mapMultiply(-1),
|
||||
constraint.getRelationship().oppositeRelationship(),
|
||||
-1 * constraint.getValue());
|
||||
}
|
||||
return new LinearConstraint(constraint.getCoefficients(),
|
||||
constraint.getRelationship(), constraint.getValue());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of objective functions in this tableau.
|
||||
* @return 2 for Phase 1. 1 for Phase 2.
|
||||
*/
|
||||
protected final int getNumObjectiveFunctions() {
|
||||
return this.numArtificialVariables > 0 ? 2 : 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a count of constraints corresponding to a specified relationship.
|
||||
* @param relationship relationship to count
|
||||
* @return number of constraint with the specified relationship
|
||||
*/
|
||||
private int getConstraintTypeCounts(final Relationship relationship) {
|
||||
int count = 0;
|
||||
for (final LinearConstraint constraint : constraints) {
|
||||
if (constraint.getRelationship() == relationship) {
|
||||
++count;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the -1 times the sum of all coefficients in the given array.
|
||||
* @param coefficients coefficients to sum
|
||||
* @return the -1 times the sum of all coefficients in the given array.
|
||||
*/
|
||||
protected static double getInvertedCoefficientSum(final RealVector coefficients) {
|
||||
double sum = 0;
|
||||
for (double coefficient : coefficients.toArray()) {
|
||||
sum -= coefficient;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks whether the given column is basic.
|
||||
* @param col index of the column to check
|
||||
* @return the row that the variable is basic in. null if the column is not basic
|
||||
*/
|
||||
protected Integer getBasicRow(final int col) {
|
||||
Integer row = null;
|
||||
for (int i = 0; i < getHeight(); i++) {
|
||||
final double entry = getEntry(i, col);
|
||||
if (Precision.equals(entry, 1d, maxUlps) && (row == null)) {
|
||||
row = i;
|
||||
} else if (!Precision.equals(entry, 0d, maxUlps)) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return row;
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes the phase 1 objective function, positive cost non-artificial variables,
|
||||
* and the non-basic artificial variables from this tableau.
|
||||
*/
|
||||
protected void dropPhase1Objective() {
|
||||
if (getNumObjectiveFunctions() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
Set<Integer> columnsToDrop = new TreeSet<Integer>();
|
||||
columnsToDrop.add(0);
|
||||
|
||||
// positive cost non-artificial variables
|
||||
for (int i = getNumObjectiveFunctions(); i < getArtificialVariableOffset(); i++) {
|
||||
final double entry = tableau.getEntry(0, i);
|
||||
if (Precision.compareTo(entry, 0d, epsilon) > 0) {
|
||||
columnsToDrop.add(i);
|
||||
}
|
||||
}
|
||||
|
||||
// non-basic artificial variables
|
||||
for (int i = 0; i < getNumArtificialVariables(); i++) {
|
||||
int col = i + getArtificialVariableOffset();
|
||||
if (getBasicRow(col) == null) {
|
||||
columnsToDrop.add(col);
|
||||
}
|
||||
}
|
||||
|
||||
double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
|
||||
for (int i = 1; i < getHeight(); i++) {
|
||||
int col = 0;
|
||||
for (int j = 0; j < getWidth(); j++) {
|
||||
if (!columnsToDrop.contains(j)) {
|
||||
matrix[i - 1][col++] = tableau.getEntry(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove the columns in reverse order so the indices are correct
|
||||
Integer[] drop = columnsToDrop.toArray(new Integer[columnsToDrop.size()]);
|
||||
for (int i = drop.length - 1; i >= 0; i--) {
|
||||
columnLabels.remove((int) drop[i]);
|
||||
}
|
||||
|
||||
this.tableau = new Array2DRowRealMatrix(matrix);
|
||||
this.numArtificialVariables = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param src the source array
|
||||
* @param dest the destination array
|
||||
*/
|
||||
private void copyArray(final double[] src, final double[] dest) {
|
||||
System.arraycopy(src, 0, dest, getNumObjectiveFunctions(), src.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns whether the problem is at an optimal state.
|
||||
* @return whether the model has been solved
|
||||
*/
|
||||
boolean isOptimal() {
|
||||
for (int i = getNumObjectiveFunctions(); i < getWidth() - 1; i++) {
|
||||
final double entry = tableau.getEntry(0, i);
|
||||
if (Precision.compareTo(entry, 0d, epsilon) < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current solution.
|
||||
* @return current solution
|
||||
*/
|
||||
protected PointValuePair getSolution() {
|
||||
int negativeVarColumn = columnLabels.indexOf(NEGATIVE_VAR_COLUMN_LABEL);
|
||||
Integer negativeVarBasicRow = negativeVarColumn > 0 ? getBasicRow(negativeVarColumn) : null;
|
||||
double mostNegative = negativeVarBasicRow == null ? 0 : getEntry(negativeVarBasicRow, getRhsOffset());
|
||||
|
||||
Set<Integer> basicRows = new HashSet<Integer>();
|
||||
double[] coefficients = new double[getOriginalNumDecisionVariables()];
|
||||
for (int i = 0; i < coefficients.length; i++) {
|
||||
int colIndex = columnLabels.indexOf("x" + i);
|
||||
if (colIndex < 0) {
|
||||
coefficients[i] = 0;
|
||||
continue;
|
||||
}
|
||||
Integer basicRow = getBasicRow(colIndex);
|
||||
if (basicRow != null && basicRow == 0) {
|
||||
// if the basic row is found to be the objective function row
|
||||
// set the coefficient to 0 -> this case handles unconstrained
|
||||
// variables that are still part of the objective function
|
||||
coefficients[i] = 0;
|
||||
} else if (basicRows.contains(basicRow)) {
|
||||
// if multiple variables can take a given value
|
||||
// then we choose the first and set the rest equal to 0
|
||||
coefficients[i] = 0 - (restrictToNonNegative ? 0 : mostNegative);
|
||||
} else {
|
||||
basicRows.add(basicRow);
|
||||
coefficients[i] =
|
||||
(basicRow == null ? 0 : getEntry(basicRow, getRhsOffset())) -
|
||||
(restrictToNonNegative ? 0 : mostNegative);
|
||||
}
|
||||
}
|
||||
return new PointValuePair(coefficients, f.value(coefficients));
|
||||
}
|
||||
|
||||
/**
|
||||
* Subtracts a multiple of one row from another.
|
||||
* <p>
|
||||
* After application of this operation, the following will hold:
|
||||
* <pre>minuendRow = minuendRow - multiple * subtrahendRow</pre>
|
||||
*
|
||||
* @param dividendRow index of the row
|
||||
* @param divisor value of the divisor
|
||||
*/
|
||||
protected void divideRow(final int dividendRow, final double divisor) {
|
||||
for (int j = 0; j < getWidth(); j++) {
|
||||
tableau.setEntry(dividendRow, j, tableau.getEntry(dividendRow, j) / divisor);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Subtracts a multiple of one row from another.
|
||||
* <p>
|
||||
* After application of this operation, the following will hold:
|
||||
* <pre>minuendRow = minuendRow - multiple * subtrahendRow</pre>
|
||||
*
|
||||
* @param minuendRow row index
|
||||
* @param subtrahendRow row index
|
||||
* @param multiple multiplication factor
|
||||
*/
|
||||
protected void subtractRow(final int minuendRow, final int subtrahendRow,
|
||||
final double multiple) {
|
||||
for (int i = 0; i < getWidth(); i++) {
|
||||
double result = tableau.getEntry(minuendRow, i) - tableau.getEntry(subtrahendRow, i) * multiple;
|
||||
// cut-off values smaller than the CUTOFF_THRESHOLD, otherwise may lead to numerical instabilities
|
||||
if (FastMath.abs(result) < CUTOFF_THRESHOLD) {
|
||||
result = 0.0;
|
||||
}
|
||||
tableau.setEntry(minuendRow, i, result);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the width of the tableau.
|
||||
* @return width of the tableau
|
||||
*/
|
||||
protected final int getWidth() {
|
||||
return tableau.getColumnDimension();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the height of the tableau.
|
||||
* @return height of the tableau
|
||||
*/
|
||||
protected final int getHeight() {
|
||||
return tableau.getRowDimension();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an entry of the tableau.
|
||||
* @param row row index
|
||||
* @param column column index
|
||||
* @return entry at (row, column)
|
||||
*/
|
||||
protected final double getEntry(final int row, final int column) {
|
||||
return tableau.getEntry(row, column);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set an entry of the tableau.
|
||||
* @param row row index
|
||||
* @param column column index
|
||||
* @param value for the entry
|
||||
*/
|
||||
protected final void setEntry(final int row, final int column,
|
||||
final double value) {
|
||||
tableau.setEntry(row, column, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the offset of the first slack variable.
|
||||
* @return offset of the first slack variable
|
||||
*/
|
||||
protected final int getSlackVariableOffset() {
|
||||
return getNumObjectiveFunctions() + numDecisionVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the offset of the first artificial variable.
|
||||
* @return offset of the first artificial variable
|
||||
*/
|
||||
protected final int getArtificialVariableOffset() {
|
||||
return getNumObjectiveFunctions() + numDecisionVariables + numSlackVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the offset of the right hand side.
|
||||
* @return offset of the right hand side
|
||||
*/
|
||||
protected final int getRhsOffset() {
|
||||
return getWidth() - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of decision variables.
|
||||
* <p>
|
||||
* If variables are not restricted to positive values, this will include 1 extra decision variable to represent
|
||||
* the absolute value of the most negative variable.
|
||||
*
|
||||
* @return number of decision variables
|
||||
* @see #getOriginalNumDecisionVariables()
|
||||
*/
|
||||
protected final int getNumDecisionVariables() {
|
||||
return numDecisionVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the original number of decision variables.
|
||||
* @return original number of decision variables
|
||||
* @see #getNumDecisionVariables()
|
||||
*/
|
||||
protected final int getOriginalNumDecisionVariables() {
|
||||
return f.getCoefficients().getDimension();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of slack variables.
|
||||
* @return number of slack variables
|
||||
*/
|
||||
protected final int getNumSlackVariables() {
|
||||
return numSlackVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of artificial variables.
|
||||
* @return number of artificial variables
|
||||
*/
|
||||
protected final int getNumArtificialVariables() {
|
||||
return numArtificialVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the tableau data.
|
||||
* @return tableau data
|
||||
*/
|
||||
protected final double[][] getData() {
|
||||
return tableau.getData();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
|
||||
if (this == other) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (other instanceof SimplexTableau) {
|
||||
SimplexTableau rhs = (SimplexTableau) other;
|
||||
return (restrictToNonNegative == rhs.restrictToNonNegative) &&
|
||||
(numDecisionVariables == rhs.numDecisionVariables) &&
|
||||
(numSlackVariables == rhs.numSlackVariables) &&
|
||||
(numArtificialVariables == rhs.numArtificialVariables) &&
|
||||
(epsilon == rhs.epsilon) &&
|
||||
(maxUlps == rhs.maxUlps) &&
|
||||
f.equals(rhs.f) &&
|
||||
constraints.equals(rhs.constraints) &&
|
||||
tableau.equals(rhs.tableau);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Boolean.valueOf(restrictToNonNegative).hashCode() ^
|
||||
numDecisionVariables ^
|
||||
numSlackVariables ^
|
||||
numArtificialVariables ^
|
||||
Double.valueOf(epsilon).hashCode() ^
|
||||
maxUlps ^
|
||||
f.hashCode() ^
|
||||
constraints.hashCode() ^
|
||||
tableau.hashCode();
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the instance.
|
||||
* @param oos stream where object should be written
|
||||
* @throws IOException if object cannot be written to stream
|
||||
*/
|
||||
private void writeObject(ObjectOutputStream oos)
|
||||
throws IOException {
|
||||
oos.defaultWriteObject();
|
||||
MatrixUtils.serializeRealMatrix(tableau, oos);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deserialize the instance.
|
||||
* @param ois stream from which the object should be read
|
||||
* @throws ClassNotFoundException if a class in the stream cannot be found
|
||||
* @throws IOException if object cannot be read from the stream
|
||||
*/
|
||||
private void readObject(ObjectInputStream ois)
|
||||
throws ClassNotFoundException, IOException {
|
||||
ois.defaultReadObject();
|
||||
MatrixUtils.deserializeRealMatrix(this, "tableau", ois);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
|
||||
/**
|
||||
* This class represents exceptions thrown by optimizers when a solution escapes to infinity.
|
||||
*
|
||||
* @version $Id: UnboundedSolutionException.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class UnboundedSolutionException extends MathIllegalStateException {
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = 940539497277290619L;
|
||||
|
||||
/**
|
||||
* Simple constructor using a default message.
|
||||
*/
|
||||
public UnboundedSolutionException() {
|
||||
super(LocalizedFormats.UNBOUNDED_SOLUTION);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
/**
|
||||
* Optimization algorithms for linear constrained problems.
|
||||
*/
|
|
@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
|
||||
/**
|
||||
* Base class for implementing optimizers for multivariate scalar
|
||||
* differentiable functions.
|
||||
* It contains boiler-plate code for dealing with gradient evaluation.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class GradientMultivariateOptimizer
|
||||
extends MultivariateOptimizer {
|
||||
/**
|
||||
* Gradient of the objective function.
|
||||
*/
|
||||
private MultivariateVectorFunction gradient;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected GradientMultivariateOptimizer(ConvergenceChecker<PointValuePair> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the gradient vector.
|
||||
*
|
||||
* @param params Point at which the gradient must be evaluated.
|
||||
* @return the gradient at the specified point.
|
||||
*/
|
||||
protected double[] computeObjectiveGradient(final double[] params) {
|
||||
return gradient.value(params);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.GoalType}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.ObjectiveFunction}</li>
|
||||
* <li>{@link ObjectiveFunctionGradient}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations (of the objective function) is exceeded.
|
||||
*/
|
||||
@Override
|
||||
public PointValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link ObjectiveFunction}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof ObjectiveFunctionGradient) {
|
||||
gradient = ((ObjectiveFunctionGradient) data).getObjectiveFunctionGradient();
|
||||
// If more data must be parsed, this statement _must_ be
|
||||
// changed to "continue".
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
|
||||
/**
|
||||
* This class converts
|
||||
* {@link MultivariateVectorFunction vectorial objective functions} to
|
||||
* {@link MultivariateFunction scalar objective functions}
|
||||
* when the goal is to minimize them.
|
||||
* <br/>
|
||||
* This class is mostly used when the vectorial objective function represents
|
||||
* a theoretical result computed from a point set applied to a model and
|
||||
* the models point must be adjusted to fit the theoretical result to some
|
||||
* reference observations. The observations may be obtained for example from
|
||||
* physical measurements whether the model is built from theoretical
|
||||
* considerations.
|
||||
* <br/>
|
||||
* This class computes a possibly weighted squared sum of the residuals, which is
|
||||
* a scalar value. The residuals are the difference between the theoretical model
|
||||
* (i.e. the output of the vectorial objective function) and the observations. The
|
||||
* class implements the {@link MultivariateFunction} interface and can therefore be
|
||||
* minimized by any optimizer supporting scalar objectives functions.This is one way
|
||||
* to perform a least square estimation. There are other ways to do this without using
|
||||
* this converter, as some optimization algorithms directly support vectorial objective
|
||||
* functions.
|
||||
* <br/>
|
||||
* This class support combination of residuals with or without weights and correlations.
|
||||
*
|
||||
* @see MultivariateFunction
|
||||
* @see MultivariateVectorFunction
|
||||
* @version $Id: LeastSquaresConverter.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
|
||||
public class LeastSquaresConverter implements MultivariateFunction {
|
||||
/** Underlying vectorial function. */
|
||||
private final MultivariateVectorFunction function;
|
||||
/** Observations to be compared to objective function to compute residuals. */
|
||||
private final double[] observations;
|
||||
/** Optional weights for the residuals. */
|
||||
private final double[] weights;
|
||||
/** Optional scaling matrix (weight and correlations) for the residuals. */
|
||||
private final RealMatrix scale;
|
||||
|
||||
/**
|
||||
* Builds a simple converter for uncorrelated residuals with identical
|
||||
* weights.
|
||||
*
|
||||
* @param function vectorial residuals function to wrap
|
||||
* @param observations observations to be compared to objective function to compute residuals
|
||||
*/
|
||||
public LeastSquaresConverter(final MultivariateVectorFunction function,
|
||||
final double[] observations) {
|
||||
this.function = function;
|
||||
this.observations = observations.clone();
|
||||
this.weights = null;
|
||||
this.scale = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a simple converter for uncorrelated residuals with the
|
||||
* specified weights.
|
||||
* <p>
|
||||
* The scalar objective function value is computed as:
|
||||
* <pre>
|
||||
* objective = ∑weight<sub>i</sub>(observation<sub>i</sub>-objective<sub>i</sub>)<sup>2</sup>
|
||||
* </pre>
|
||||
* </p>
|
||||
* <p>
|
||||
* Weights can be used for example to combine residuals with different standard
|
||||
* deviations. As an example, consider a residuals array in which even elements
|
||||
* are angular measurements in degrees with a 0.01° standard deviation and
|
||||
* odd elements are distance measurements in meters with a 15m standard deviation.
|
||||
* In this case, the weights array should be initialized with value
|
||||
* 1.0/(0.01<sup>2</sup>) in the even elements and 1.0/(15.0<sup>2</sup>) in the
|
||||
* odd elements (i.e. reciprocals of variances).
|
||||
* </p>
|
||||
* <p>
|
||||
* The array computed by the objective function, the observations array and the
|
||||
* weights array must have consistent sizes or a {@link DimensionMismatchException}
|
||||
* will be triggered while computing the scalar objective.
|
||||
* </p>
|
||||
*
|
||||
* @param function vectorial residuals function to wrap
|
||||
* @param observations observations to be compared to objective function to compute residuals
|
||||
* @param weights weights to apply to the residuals
|
||||
* @throws DimensionMismatchException if the observations vector and the weights
|
||||
* vector dimensions do not match (objective function dimension is checked only when
|
||||
* the {@link #value(double[])} method is called)
|
||||
*/
|
||||
public LeastSquaresConverter(final MultivariateVectorFunction function,
|
||||
final double[] observations,
|
||||
final double[] weights) {
|
||||
if (observations.length != weights.length) {
|
||||
throw new DimensionMismatchException(observations.length, weights.length);
|
||||
}
|
||||
this.function = function;
|
||||
this.observations = observations.clone();
|
||||
this.weights = weights.clone();
|
||||
this.scale = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a simple converter for correlated residuals with the
|
||||
* specified weights.
|
||||
* <p>
|
||||
* The scalar objective function value is computed as:
|
||||
* <pre>
|
||||
* objective = y<sup>T</sup>y with y = scale×(observation-objective)
|
||||
* </pre>
|
||||
* </p>
|
||||
* <p>
|
||||
* The array computed by the objective function, the observations array and the
|
||||
* the scaling matrix must have consistent sizes or a {@link DimensionMismatchException}
|
||||
* will be triggered while computing the scalar objective.
|
||||
* </p>
|
||||
*
|
||||
* @param function vectorial residuals function to wrap
|
||||
* @param observations observations to be compared to objective function to compute residuals
|
||||
* @param scale scaling matrix
|
||||
* @throws DimensionMismatchException if the observations vector and the scale
|
||||
* matrix dimensions do not match (objective function dimension is checked only when
|
||||
* the {@link #value(double[])} method is called)
|
||||
*/
|
||||
public LeastSquaresConverter(final MultivariateVectorFunction function,
|
||||
final double[] observations,
|
||||
final RealMatrix scale) {
|
||||
if (observations.length != scale.getColumnDimension()) {
|
||||
throw new DimensionMismatchException(observations.length, scale.getColumnDimension());
|
||||
}
|
||||
this.function = function;
|
||||
this.observations = observations.clone();
|
||||
this.weights = null;
|
||||
this.scale = scale.copy();
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double value(final double[] point) {
|
||||
// compute residuals
|
||||
final double[] residuals = function.value(point);
|
||||
if (residuals.length != observations.length) {
|
||||
throw new DimensionMismatchException(residuals.length, observations.length);
|
||||
}
|
||||
for (int i = 0; i < residuals.length; ++i) {
|
||||
residuals[i] -= observations[i];
|
||||
}
|
||||
|
||||
// compute sum of squares
|
||||
double sumSquares = 0;
|
||||
if (weights != null) {
|
||||
for (int i = 0; i < residuals.length; ++i) {
|
||||
final double ri = residuals[i];
|
||||
sumSquares += weights[i] * ri * ri;
|
||||
}
|
||||
} else if (scale != null) {
|
||||
for (final double yi : scale.operate(residuals)) {
|
||||
sumSquares += yi * yi;
|
||||
}
|
||||
} else {
|
||||
for (final double ri : residuals) {
|
||||
sumSquares += ri * ri;
|
||||
}
|
||||
}
|
||||
|
||||
return sumSquares;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.random.RandomVectorGenerator;
|
||||
import org.apache.commons.math3.optim.BaseMultiStartMultivariateOptimizer;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
|
||||
/**
|
||||
* Multi-start optimizer.
|
||||
*
|
||||
* This class wraps an optimizer in order to use it several times in
|
||||
* turn with different starting points (trying to avoid being trapped
|
||||
* in a local extremum when looking for a global one).
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.0
|
||||
*/
|
||||
public class MultiStartMultivariateOptimizer
|
||||
extends BaseMultiStartMultivariateOptimizer<PointValuePair> {
|
||||
/** Underlying optimizer. */
|
||||
private final MultivariateOptimizer optimizer;
|
||||
/** Found optima. */
|
||||
private final List<PointValuePair> optima = new ArrayList<PointValuePair>();
|
||||
|
||||
/**
|
||||
* Create a multi-start optimizer from a single-start optimizer.
|
||||
*
|
||||
* @param optimizer Single-start optimizer to wrap.
|
||||
* @param starts Number of starts to perform.
|
||||
* If {@code starts == 1}, the result will be same as if {@code optimizer}
|
||||
* is called directly.
|
||||
* @param generator Random vector generator to use for restarts.
|
||||
* @throws NullArgumentException if {@code optimizer} or {@code generator}
|
||||
* is {@code null}.
|
||||
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
||||
*/
|
||||
public MultiStartMultivariateOptimizer(final MultivariateOptimizer optimizer,
|
||||
final int starts,
|
||||
final RandomVectorGenerator generator)
|
||||
throws NullArgumentException,
|
||||
NotStrictlyPositiveException {
|
||||
super(optimizer, starts, generator);
|
||||
this.optimizer = optimizer;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public PointValuePair[] getOptima() {
|
||||
Collections.sort(optima, getPairComparator());
|
||||
return optima.toArray(new PointValuePair[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
protected void store(PointValuePair optimum) {
|
||||
optima.add(optimum);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
protected void clear() {
|
||||
optima.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a comparator for sorting the optima.
|
||||
*/
|
||||
private Comparator<PointValuePair> getPairComparator() {
|
||||
return new Comparator<PointValuePair>() {
|
||||
public int compare(final PointValuePair o1,
|
||||
final PointValuePair o2) {
|
||||
if (o1 == null) {
|
||||
return (o2 == null) ? 0 : 1;
|
||||
} else if (o2 == null) {
|
||||
return -1;
|
||||
}
|
||||
final double v1 = o1.getValue();
|
||||
final double v2 = o2.getValue();
|
||||
return (optimizer.getGoalType() == GoalType.MINIMIZE) ?
|
||||
Double.compare(v1, v2) : Double.compare(v2, v1);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,295 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.function.Logit;
|
||||
import org.apache.commons.math3.analysis.function.Sigmoid;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.MathUtils;
|
||||
|
||||
/**
|
||||
* <p>Adapter for mapping bounded {@link MultivariateFunction} to unbounded ones.</p>
|
||||
*
|
||||
* <p>
|
||||
* This adapter can be used to wrap functions subject to simple bounds on
|
||||
* parameters so they can be used by optimizers that do <em>not</em> directly
|
||||
* support simple bounds.
|
||||
* </p>
|
||||
* <p>
|
||||
* The principle is that the user function that will be wrapped will see its
|
||||
* parameters bounded as required, i.e when its {@code value} method is called
|
||||
* with argument array {@code point}, the elements array will fulfill requirement
|
||||
* {@code lower[i] <= point[i] <= upper[i]} for all i. Some of the components
|
||||
* may be unbounded or bounded only on one side if the corresponding bound is
|
||||
* set to an infinite value. The optimizer will not manage the user function by
|
||||
* itself, but it will handle this adapter and it is this adapter that will take
|
||||
* care the bounds are fulfilled. The adapter {@link #value(double[])} method will
|
||||
* be called by the optimizer with unbound parameters, and the adapter will map
|
||||
* the unbounded value to the bounded range using appropriate functions like
|
||||
* {@link Sigmoid} for double bounded elements for example.
|
||||
* </p>
|
||||
* <p>
|
||||
* As the optimizer sees only unbounded parameters, it should be noted that the
|
||||
* start point or simplex expected by the optimizer should be unbounded, so the
|
||||
* user is responsible for converting his bounded point to unbounded by calling
|
||||
* {@link #boundedToUnbounded(double[])} before providing them to the optimizer.
|
||||
* For the same reason, the point returned by the {@link
|
||||
* org.apache.commons.math3.optimization.BaseMultivariateOptimizer#optimize(int,
|
||||
* MultivariateFunction, org.apache.commons.math3.optimization.GoalType, double[])}
|
||||
* method is unbounded. So to convert this point to bounded, users must call
|
||||
* {@link #unboundedToBounded(double[])} by themselves!</p>
|
||||
* <p>
|
||||
* This adapter is only a poor man solution to simple bounds optimization constraints
|
||||
* that can be used with simple optimizers like
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||
* SimplexOptimizer}.
|
||||
* A better solution is to use an optimizer that directly supports simple bounds like
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer
|
||||
* CMAESOptimizer} or
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer
|
||||
* BOBYQAOptimizer}.
|
||||
* One caveat of this poor-man's solution is that behavior near the bounds may be
|
||||
* numerically unstable as bounds are mapped from infinite values.
|
||||
* Another caveat is that convergence values are evaluated by the optimizer with
|
||||
* respect to unbounded variables, so there will be scales differences when
|
||||
* converted to bounded variables.
|
||||
* </p>
|
||||
*
|
||||
* @see MultivariateFunctionPenaltyAdapter
|
||||
*
|
||||
* @version $Id: MultivariateFunctionMappingAdapter.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class MultivariateFunctionMappingAdapter
|
||||
implements MultivariateFunction {
|
||||
/** Underlying bounded function. */
|
||||
private final MultivariateFunction bounded;
|
||||
/** Mapping functions. */
|
||||
private final Mapper[] mappers;
|
||||
|
||||
/** Simple constructor.
|
||||
* @param bounded bounded function
|
||||
* @param lower lower bounds for each element of the input parameters array
|
||||
* (some elements may be set to {@code Double.NEGATIVE_INFINITY} for
|
||||
* unbounded values)
|
||||
* @param upper upper bounds for each element of the input parameters array
|
||||
* (some elements may be set to {@code Double.POSITIVE_INFINITY} for
|
||||
* unbounded values)
|
||||
* @exception DimensionMismatchException if lower and upper bounds are not
|
||||
* consistent, either according to dimension or to values
|
||||
*/
|
||||
public MultivariateFunctionMappingAdapter(final MultivariateFunction bounded,
|
||||
final double[] lower, final double[] upper) {
|
||||
// safety checks
|
||||
MathUtils.checkNotNull(lower);
|
||||
MathUtils.checkNotNull(upper);
|
||||
if (lower.length != upper.length) {
|
||||
throw new DimensionMismatchException(lower.length, upper.length);
|
||||
}
|
||||
for (int i = 0; i < lower.length; ++i) {
|
||||
// note the following test is written in such a way it also fails for NaN
|
||||
if (!(upper[i] >= lower[i])) {
|
||||
throw new NumberIsTooSmallException(upper[i], lower[i], true);
|
||||
}
|
||||
}
|
||||
|
||||
this.bounded = bounded;
|
||||
this.mappers = new Mapper[lower.length];
|
||||
for (int i = 0; i < mappers.length; ++i) {
|
||||
if (Double.isInfinite(lower[i])) {
|
||||
if (Double.isInfinite(upper[i])) {
|
||||
// element is unbounded, no transformation is needed
|
||||
mappers[i] = new NoBoundsMapper();
|
||||
} else {
|
||||
// element is simple-bounded on the upper side
|
||||
mappers[i] = new UpperBoundMapper(upper[i]);
|
||||
}
|
||||
} else {
|
||||
if (Double.isInfinite(upper[i])) {
|
||||
// element is simple-bounded on the lower side
|
||||
mappers[i] = new LowerBoundMapper(lower[i]);
|
||||
} else {
|
||||
// element is double-bounded
|
||||
mappers[i] = new LowerUpperBoundMapper(lower[i], upper[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps an array from unbounded to bounded.
|
||||
*
|
||||
* @param point Unbounded values.
|
||||
* @return the bounded values.
|
||||
*/
|
||||
public double[] unboundedToBounded(double[] point) {
|
||||
// Map unbounded input point to bounded point.
|
||||
final double[] mapped = new double[mappers.length];
|
||||
for (int i = 0; i < mappers.length; ++i) {
|
||||
mapped[i] = mappers[i].unboundedToBounded(point[i]);
|
||||
}
|
||||
|
||||
return mapped;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps an array from bounded to unbounded.
|
||||
*
|
||||
* @param point Bounded values.
|
||||
* @return the unbounded values.
|
||||
*/
|
||||
public double[] boundedToUnbounded(double[] point) {
|
||||
// Map bounded input point to unbounded point.
|
||||
final double[] mapped = new double[mappers.length];
|
||||
for (int i = 0; i < mappers.length; ++i) {
|
||||
mapped[i] = mappers[i].boundedToUnbounded(point[i]);
|
||||
}
|
||||
|
||||
return mapped;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the underlying function value from an unbounded point.
|
||||
* <p>
|
||||
* This method simply bounds the unbounded point using the mappings
|
||||
* set up at construction and calls the underlying function using
|
||||
* the bounded point.
|
||||
* </p>
|
||||
* @param point unbounded value
|
||||
* @return underlying function value
|
||||
* @see #unboundedToBounded(double[])
|
||||
*/
|
||||
public double value(double[] point) {
|
||||
return bounded.value(unboundedToBounded(point));
|
||||
}
|
||||
|
||||
/** Mapping interface. */
|
||||
private interface Mapper {
|
||||
/**
|
||||
* Maps a value from unbounded to bounded.
|
||||
*
|
||||
* @param y Unbounded value.
|
||||
* @return the bounded value.
|
||||
*/
|
||||
double unboundedToBounded(double y);
|
||||
|
||||
/**
|
||||
* Maps a value from bounded to unbounded.
|
||||
*
|
||||
* @param x Bounded value.
|
||||
* @return the unbounded value.
|
||||
*/
|
||||
double boundedToUnbounded(double x);
|
||||
}
|
||||
|
||||
/** Local class for no bounds mapping. */
|
||||
private static class NoBoundsMapper implements Mapper {
|
||||
/** {@inheritDoc} */
|
||||
public double unboundedToBounded(final double y) {
|
||||
return y;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double boundedToUnbounded(final double x) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
/** Local class for lower bounds mapping. */
|
||||
private static class LowerBoundMapper implements Mapper {
|
||||
/** Low bound. */
|
||||
private final double lower;
|
||||
|
||||
/**
|
||||
* Simple constructor.
|
||||
*
|
||||
* @param lower lower bound
|
||||
*/
|
||||
public LowerBoundMapper(final double lower) {
|
||||
this.lower = lower;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double unboundedToBounded(final double y) {
|
||||
return lower + FastMath.exp(y);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double boundedToUnbounded(final double x) {
|
||||
return FastMath.log(x - lower);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Local class for upper bounds mapping. */
|
||||
private static class UpperBoundMapper implements Mapper {
|
||||
|
||||
/** Upper bound. */
|
||||
private final double upper;
|
||||
|
||||
/** Simple constructor.
|
||||
* @param upper upper bound
|
||||
*/
|
||||
public UpperBoundMapper(final double upper) {
|
||||
this.upper = upper;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double unboundedToBounded(final double y) {
|
||||
return upper - FastMath.exp(-y);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double boundedToUnbounded(final double x) {
|
||||
return -FastMath.log(upper - x);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Local class for lower and bounds mapping. */
|
||||
private static class LowerUpperBoundMapper implements Mapper {
|
||||
/** Function from unbounded to bounded. */
|
||||
private final UnivariateFunction boundingFunction;
|
||||
/** Function from bounded to unbounded. */
|
||||
private final UnivariateFunction unboundingFunction;
|
||||
|
||||
/**
|
||||
* Simple constructor.
|
||||
*
|
||||
* @param lower lower bound
|
||||
* @param upper upper bound
|
||||
*/
|
||||
public LowerUpperBoundMapper(final double lower, final double upper) {
|
||||
boundingFunction = new Sigmoid(lower, upper);
|
||||
unboundingFunction = new Logit(lower, upper);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double unboundedToBounded(final double y) {
|
||||
return boundingFunction.value(y);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double boundedToUnbounded(final double x) {
|
||||
return unboundingFunction.value(x);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.MathUtils;
|
||||
|
||||
/**
|
||||
* <p>Adapter extending bounded {@link MultivariateFunction} to an unbouded
|
||||
* domain using a penalty function.</p>
|
||||
*
|
||||
* <p>
|
||||
* This adapter can be used to wrap functions subject to simple bounds on
|
||||
* parameters so they can be used by optimizers that do <em>not</em> directly
|
||||
* support simple bounds.
|
||||
* </p>
|
||||
* <p>
|
||||
* The principle is that the user function that will be wrapped will see its
|
||||
* parameters bounded as required, i.e when its {@code value} method is called
|
||||
* with argument array {@code point}, the elements array will fulfill requirement
|
||||
* {@code lower[i] <= point[i] <= upper[i]} for all i. Some of the components
|
||||
* may be unbounded or bounded only on one side if the corresponding bound is
|
||||
* set to an infinite value. The optimizer will not manage the user function by
|
||||
* itself, but it will handle this adapter and it is this adapter that will take
|
||||
* care the bounds are fulfilled. The adapter {@link #value(double[])} method will
|
||||
* be called by the optimizer with unbound parameters, and the adapter will check
|
||||
* if the parameters is within range or not. If it is in range, then the underlying
|
||||
* user function will be called, and if it is not the value of a penalty function
|
||||
* will be returned instead.
|
||||
* </p>
|
||||
* <p>
|
||||
* This adapter is only a poor-man's solution to simple bounds optimization
|
||||
* constraints that can be used with simple optimizers like
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||
* SimplexOptimizer}.
|
||||
* A better solution is to use an optimizer that directly supports simple bounds like
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer
|
||||
* CMAESOptimizer} or
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer
|
||||
* BOBYQAOptimizer}.
|
||||
* One caveat of this poor-man's solution is that if start point or start simplex
|
||||
* is completely outside of the allowed range, only the penalty function is used,
|
||||
* and the optimizer may converge without ever entering the range.
|
||||
* </p>
|
||||
*
|
||||
* @see MultivariateFunctionMappingAdapter
|
||||
*
|
||||
* @version $Id: MultivariateFunctionPenaltyAdapter.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class MultivariateFunctionPenaltyAdapter
|
||||
implements MultivariateFunction {
|
||||
/** Underlying bounded function. */
|
||||
private final MultivariateFunction bounded;
|
||||
/** Lower bounds. */
|
||||
private final double[] lower;
|
||||
/** Upper bounds. */
|
||||
private final double[] upper;
|
||||
/** Penalty offset. */
|
||||
private final double offset;
|
||||
/** Penalty scales. */
|
||||
private final double[] scale;
|
||||
|
||||
/**
|
||||
* Simple constructor.
|
||||
* <p>
|
||||
* When the optimizer provided points are out of range, the value of the
|
||||
* penalty function will be used instead of the value of the underlying
|
||||
* function. In order for this penalty to be effective in rejecting this
|
||||
* point during the optimization process, the penalty function value should
|
||||
* be defined with care. This value is computed as:
|
||||
* <pre>
|
||||
* penalty(point) = offset + ∑<sub>i</sub>[scale[i] * √|point[i]-boundary[i]|]
|
||||
* </pre>
|
||||
* where indices i correspond to all the components that violates their boundaries.
|
||||
* </p>
|
||||
* <p>
|
||||
* So when attempting a function minimization, offset should be larger than
|
||||
* the maximum expected value of the underlying function and scale components
|
||||
* should all be positive. When attempting a function maximization, offset
|
||||
* should be lesser than the minimum expected value of the underlying function
|
||||
* and scale components should all be negative.
|
||||
* minimization, and lesser than the minimum expected value of the underlying
|
||||
* function when attempting maximization.
|
||||
* </p>
|
||||
* <p>
|
||||
* These choices for the penalty function have two properties. First, all out
|
||||
* of range points will return a function value that is worse than the value
|
||||
* returned by any in range point. Second, the penalty is worse for large
|
||||
* boundaries violation than for small violations, so the optimizer has an hint
|
||||
* about the direction in which it should search for acceptable points.
|
||||
* </p>
|
||||
* @param bounded bounded function
|
||||
* @param lower lower bounds for each element of the input parameters array
|
||||
* (some elements may be set to {@code Double.NEGATIVE_INFINITY} for
|
||||
* unbounded values)
|
||||
* @param upper upper bounds for each element of the input parameters array
|
||||
* (some elements may be set to {@code Double.POSITIVE_INFINITY} for
|
||||
* unbounded values)
|
||||
* @param offset base offset of the penalty function
|
||||
* @param scale scale of the penalty function
|
||||
* @exception DimensionMismatchException if lower bounds, upper bounds and
|
||||
* scales are not consistent, either according to dimension or to bounadary
|
||||
* values
|
||||
*/
|
||||
public MultivariateFunctionPenaltyAdapter(final MultivariateFunction bounded,
|
||||
final double[] lower, final double[] upper,
|
||||
final double offset, final double[] scale) {
|
||||
|
||||
// safety checks
|
||||
MathUtils.checkNotNull(lower);
|
||||
MathUtils.checkNotNull(upper);
|
||||
MathUtils.checkNotNull(scale);
|
||||
if (lower.length != upper.length) {
|
||||
throw new DimensionMismatchException(lower.length, upper.length);
|
||||
}
|
||||
if (lower.length != scale.length) {
|
||||
throw new DimensionMismatchException(lower.length, scale.length);
|
||||
}
|
||||
for (int i = 0; i < lower.length; ++i) {
|
||||
// note the following test is written in such a way it also fails for NaN
|
||||
if (!(upper[i] >= lower[i])) {
|
||||
throw new NumberIsTooSmallException(upper[i], lower[i], true);
|
||||
}
|
||||
}
|
||||
|
||||
this.bounded = bounded;
|
||||
this.lower = lower.clone();
|
||||
this.upper = upper.clone();
|
||||
this.offset = offset;
|
||||
this.scale = scale.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the underlying function value from an unbounded point.
|
||||
* <p>
|
||||
* This method simply returns the value of the underlying function
|
||||
* if the unbounded point already fulfills the bounds, and compute
|
||||
* a replacement value using the offset and scale if bounds are
|
||||
* violated, without calling the function at all.
|
||||
* </p>
|
||||
* @param point unbounded point
|
||||
* @return either underlying function value or penalty function value
|
||||
*/
|
||||
public double value(double[] point) {
|
||||
|
||||
for (int i = 0; i < scale.length; ++i) {
|
||||
if ((point[i] < lower[i]) || (point[i] > upper[i])) {
|
||||
// bound violation starting at this component
|
||||
double sum = 0;
|
||||
for (int j = i; j < scale.length; ++j) {
|
||||
final double overshoot;
|
||||
if (point[j] < lower[j]) {
|
||||
overshoot = scale[j] * (lower[j] - point[j]);
|
||||
} else if (point[j] > upper[j]) {
|
||||
overshoot = scale[j] * (point[j] - upper[j]);
|
||||
} else {
|
||||
overshoot = 0;
|
||||
}
|
||||
sum += FastMath.sqrt(overshoot);
|
||||
}
|
||||
return offset + sum;
|
||||
}
|
||||
}
|
||||
|
||||
// all boundaries are fulfilled, we are in the expected
|
||||
// domain of the underlying function
|
||||
return bounded.value(point);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.optim.BaseMultivariateOptimizer;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
|
||||
/**
|
||||
* Base class for a multivariate scalar function optimizer.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class MultivariateOptimizer
|
||||
extends BaseMultivariateOptimizer<PointValuePair> {
|
||||
/** Objective function. */
|
||||
private MultivariateFunction function;
|
||||
/** Type of optimization. */
|
||||
private GoalType goal;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected MultivariateOptimizer(ConvergenceChecker<PointValuePair> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
|
||||
* <li>{@link ObjectiveFunction}</li>
|
||||
* <li>{@link GoalType}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
*/
|
||||
@Override
|
||||
public PointValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link ObjectiveFunction}</li>
|
||||
* <li>{@link GoalType}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof GoalType) {
|
||||
goal = (GoalType) data;
|
||||
continue;
|
||||
}
|
||||
if (data instanceof ObjectiveFunction) {
|
||||
function = ((ObjectiveFunction) data).getObjectiveFunction();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the optimization type.
|
||||
*/
|
||||
public GoalType getGoalType() {
|
||||
return goal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the objective function value.
|
||||
* This method <em>must</em> be called by subclasses to enforce the
|
||||
* evaluation counter limit.
|
||||
*
|
||||
* @param params Point at which the objective function must be evaluated.
|
||||
* @return the objective function value at the specified point.
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
*/
|
||||
protected double computeObjectiveValue(double[] params) {
|
||||
super.incrementEvaluationCount();
|
||||
return function.value(params);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* Gradient of the scalar function to be optimized.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class ObjectiveFunctionGradient implements OptimizationData {
|
||||
/** Function to be optimized. */
|
||||
private final MultivariateVectorFunction gradient;
|
||||
|
||||
/**
|
||||
* @param g Gradient of the function to be optimized.
|
||||
*/
|
||||
public ObjectiveFunctionGradient(MultivariateVectorFunction g) {
|
||||
gradient = g;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the gradient of the function to be optimized.
|
||||
*
|
||||
* @return the objective function gradient.
|
||||
*/
|
||||
public MultivariateVectorFunction getObjectiveFunctionGradient() {
|
||||
return gradient;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,393 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
|
||||
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.solvers.BrentSolver;
|
||||
import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
|
||||
import org.apache.commons.math3.exception.MathInternalError;
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.GradientMultivariateOptimizer;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
|
||||
/**
|
||||
* Non-linear conjugate gradient optimizer.
|
||||
* <p>
|
||||
* This class supports both the Fletcher-Reeves and the Polak-Ribière
|
||||
* update formulas for the conjugate search directions.
|
||||
* It also supports optional preconditioning.
|
||||
* </p>
|
||||
*
|
||||
* @version $Id: NonLinearConjugateGradientOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class NonLinearConjugateGradientOptimizer
|
||||
extends GradientMultivariateOptimizer {
|
||||
/** Update formula for the beta parameter. */
|
||||
private final Formula updateFormula;
|
||||
/** Preconditioner (may be null). */
|
||||
private final Preconditioner preconditioner;
|
||||
/** solver to use in the line search (may be null). */
|
||||
private final UnivariateSolver solver;
|
||||
/** Initial step used to bracket the optimum in line search. */
|
||||
private double initialStep = 1;
|
||||
|
||||
/**
|
||||
* Constructor with default {@link BrentSolver line search solver} and
|
||||
* {@link IdentityPreconditioner preconditioner}.
|
||||
*
|
||||
* @param updateFormula formula to use for updating the β parameter,
|
||||
* must be one of {@link Formula#FLETCHER_REEVES} or
|
||||
* {@link Formula#POLAK_RIBIERE}.
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
|
||||
ConvergenceChecker<PointValuePair> checker) {
|
||||
this(updateFormula,
|
||||
checker,
|
||||
new BrentSolver(),
|
||||
new IdentityPreconditioner());
|
||||
}
|
||||
|
||||
/**
|
||||
* Available choices of update formulas for the updating the parameter
|
||||
* that is used to compute the successive conjugate search directions.
|
||||
* For non-linear conjugate gradients, there are
|
||||
* two formulas:
|
||||
* <ul>
|
||||
* <li>Fletcher-Reeves formula</li>
|
||||
* <li>Polak-Ribière formula</li>
|
||||
* </ul>
|
||||
*
|
||||
* On the one hand, the Fletcher-Reeves formula is guaranteed to converge
|
||||
* if the start point is close enough of the optimum whether the
|
||||
* Polak-Ribière formula may not converge in rare cases. On the
|
||||
* other hand, the Polak-Ribière formula is often faster when it
|
||||
* does converge. Polak-Ribière is often used.
|
||||
*
|
||||
* @since 2.0
|
||||
*/
|
||||
public static enum Formula {
|
||||
/** Fletcher-Reeves formula. */
|
||||
FLETCHER_REEVES,
|
||||
/** Polak-Ribière formula. */
|
||||
POLAK_RIBIERE
|
||||
}
|
||||
|
||||
/**
|
||||
* The initial step is a factor with respect to the search direction
|
||||
* (which itself is roughly related to the gradient of the function).
|
||||
* <br/>
|
||||
* It is used to find an interval that brackets the optimum in line
|
||||
* search.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
public static class BracketingStep implements OptimizationData {
|
||||
/** Initial step. */
|
||||
private final double initialStep;
|
||||
|
||||
/**
|
||||
* @param step Initial step for the bracket search.
|
||||
*/
|
||||
public BracketingStep(double step) {
|
||||
initialStep = step;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the initial step.
|
||||
*
|
||||
* @return the initial step.
|
||||
*/
|
||||
public double getBracketingStep() {
|
||||
return initialStep;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor with default {@link IdentityPreconditioner preconditioner}.
|
||||
*
|
||||
* @param updateFormula formula to use for updating the β parameter,
|
||||
* must be one of {@link Formula#FLETCHER_REEVES} or
|
||||
* {@link Formula#POLAK_RIBIERE}.
|
||||
* @param checker Convergence checker.
|
||||
* @param lineSearchSolver Solver to use during line search.
|
||||
*/
|
||||
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
|
||||
ConvergenceChecker<PointValuePair> checker,
|
||||
final UnivariateSolver lineSearchSolver) {
|
||||
this(updateFormula,
|
||||
checker,
|
||||
lineSearchSolver,
|
||||
new IdentityPreconditioner());
|
||||
}
|
||||
|
||||
/**
|
||||
* @param updateFormula formula to use for updating the β parameter,
|
||||
* must be one of {@link Formula#FLETCHER_REEVES} or
|
||||
* {@link Formula#POLAK_RIBIERE}.
|
||||
* @param checker Convergence checker.
|
||||
* @param lineSearchSolver Solver to use during line search.
|
||||
* @param preconditioner Preconditioner.
|
||||
*/
|
||||
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
|
||||
ConvergenceChecker<PointValuePair> checker,
|
||||
final UnivariateSolver lineSearchSolver,
|
||||
final Preconditioner preconditioner) {
|
||||
super(checker);
|
||||
|
||||
this.updateFormula = updateFormula;
|
||||
solver = lineSearchSolver;
|
||||
this.preconditioner = preconditioner;
|
||||
initialStep = 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.GoalType}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.ObjectiveFunction}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient}</li>
|
||||
* <li>{@link BracketingStep}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations (of the objective function) is exceeded.
|
||||
*/
|
||||
@Override
|
||||
public PointValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected PointValuePair doOptimize() {
|
||||
final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
|
||||
final double[] point = getStartPoint();
|
||||
final GoalType goal = getGoalType();
|
||||
final int n = point.length;
|
||||
double[] r = computeObjectiveGradient(point);
|
||||
if (goal == GoalType.MINIMIZE) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
r[i] = -r[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Initial search direction.
|
||||
double[] steepestDescent = preconditioner.precondition(point, r);
|
||||
double[] searchDirection = steepestDescent.clone();
|
||||
|
||||
double delta = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
delta += r[i] * searchDirection[i];
|
||||
}
|
||||
|
||||
PointValuePair current = null;
|
||||
int iter = 0;
|
||||
int maxEval = getMaxEvaluations();
|
||||
while (true) {
|
||||
++iter;
|
||||
|
||||
final double objective = computeObjectiveValue(point);
|
||||
PointValuePair previous = current;
|
||||
current = new PointValuePair(point, objective);
|
||||
if (previous != null) {
|
||||
if (checker.converged(iter, previous, current)) {
|
||||
// We have found an optimum.
|
||||
return current;
|
||||
}
|
||||
}
|
||||
|
||||
// Find the optimal step in the search direction.
|
||||
final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
|
||||
final double uB = findUpperBound(lsf, 0, initialStep);
|
||||
// XXX Last parameters is set to a value close to zero in order to
|
||||
// work around the divergence problem in the "testCircleFitting"
|
||||
// unit test (see MATH-439).
|
||||
final double step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
|
||||
maxEval -= solver.getEvaluations(); // Subtract used up evaluations.
|
||||
|
||||
// Validate new point.
|
||||
for (int i = 0; i < point.length; ++i) {
|
||||
point[i] += step * searchDirection[i];
|
||||
}
|
||||
|
||||
r = computeObjectiveGradient(point);
|
||||
if (goal == GoalType.MINIMIZE) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
r[i] = -r[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Compute beta.
|
||||
final double deltaOld = delta;
|
||||
final double[] newSteepestDescent = preconditioner.precondition(point, r);
|
||||
delta = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
delta += r[i] * newSteepestDescent[i];
|
||||
}
|
||||
|
||||
final double beta;
|
||||
switch (updateFormula) {
|
||||
case FLETCHER_REEVES:
|
||||
beta = delta / deltaOld;
|
||||
break;
|
||||
case POLAK_RIBIERE:
|
||||
double deltaMid = 0;
|
||||
for (int i = 0; i < r.length; ++i) {
|
||||
deltaMid += r[i] * steepestDescent[i];
|
||||
}
|
||||
beta = (delta - deltaMid) / deltaOld;
|
||||
break;
|
||||
default:
|
||||
// Should never happen.
|
||||
throw new MathInternalError();
|
||||
}
|
||||
steepestDescent = newSteepestDescent;
|
||||
|
||||
// Compute conjugate search direction.
|
||||
if (iter % n == 0 ||
|
||||
beta < 0) {
|
||||
// Break conjugation: reset search direction.
|
||||
searchDirection = steepestDescent.clone();
|
||||
} else {
|
||||
// Compute new conjugate search direction.
|
||||
for (int i = 0; i < n; ++i) {
|
||||
searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link InitialStep}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof BracketingStep) {
|
||||
initialStep = ((BracketingStep) data).getBracketingStep();
|
||||
// If more data must be parsed, this statement _must_ be
|
||||
// changed to "continue".
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the upper bound b ensuring bracketing of a root between a and b.
|
||||
*
|
||||
* @param f function whose root must be bracketed.
|
||||
* @param a lower bound of the interval.
|
||||
* @param h initial step to try.
|
||||
* @return b such that f(a) and f(b) have opposite signs.
|
||||
* @throws MathIllegalStateException if no bracket can be found.
|
||||
*/
|
||||
private double findUpperBound(final UnivariateFunction f,
|
||||
final double a, final double h) {
|
||||
final double yA = f.value(a);
|
||||
double yB = yA;
|
||||
for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
|
||||
final double b = a + step;
|
||||
yB = f.value(b);
|
||||
if (yA * yB <= 0) {
|
||||
return b;
|
||||
}
|
||||
}
|
||||
throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
|
||||
}
|
||||
|
||||
/** Default identity preconditioner. */
|
||||
public static class IdentityPreconditioner implements Preconditioner {
|
||||
/** {@inheritDoc} */
|
||||
public double[] precondition(double[] variables, double[] r) {
|
||||
return r.clone();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal class for line search.
|
||||
* <p>
|
||||
* The function represented by this class is the dot product of
|
||||
* the objective function gradient and the search direction. Its
|
||||
* value is zero when the gradient is orthogonal to the search
|
||||
* direction, i.e. when the objective function value is a local
|
||||
* extremum along the search direction.
|
||||
* </p>
|
||||
*/
|
||||
private class LineSearchFunction implements UnivariateFunction {
|
||||
/** Current point. */
|
||||
private final double[] currentPoint;
|
||||
/** Search direction. */
|
||||
private final double[] searchDirection;
|
||||
|
||||
/**
|
||||
* @param point Current point.
|
||||
* @param direction Search direction.
|
||||
*/
|
||||
public LineSearchFunction(double[] point,
|
||||
double[] direction) {
|
||||
currentPoint = point.clone();
|
||||
searchDirection = direction.clone();
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double value(double x) {
|
||||
// current point in the search direction
|
||||
final double[] shiftedPoint = currentPoint.clone();
|
||||
for (int i = 0; i < shiftedPoint.length; ++i) {
|
||||
shiftedPoint[i] += x * searchDirection[i];
|
||||
}
|
||||
|
||||
// gradient of the objective function
|
||||
final double[] gradient = computeObjectiveGradient(shiftedPoint);
|
||||
|
||||
// dot product with the search direction
|
||||
double dotProduct = 0;
|
||||
for (int i = 0; i < gradient.length; ++i) {
|
||||
dotProduct += gradient[i] * searchDirection[i];
|
||||
}
|
||||
|
||||
return dotProduct;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
|
||||
|
||||
/**
|
||||
* This interface represents a preconditioner for differentiable scalar
|
||||
* objective function optimizers.
|
||||
* @version $Id: Preconditioner.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public interface Preconditioner {
|
||||
/**
|
||||
* Precondition a search direction.
|
||||
* <p>
|
||||
* The returned preconditioned search direction must be computed fast or
|
||||
* the algorithm performances will drop drastically. A classical approach
|
||||
* is to compute only the diagonal elements of the hessian and to divide
|
||||
* the raw search direction by these elements if they are all positive.
|
||||
* If at least one of them is negative, it is safer to return a clone of
|
||||
* the raw search direction as if the hessian was the identity matrix. The
|
||||
* rationale for this simplified choice is that a negative diagonal element
|
||||
* means the current point is far from the optimum and preconditioning will
|
||||
* not be efficient anyway in this case.
|
||||
* </p>
|
||||
* @param point current point at which the search direction was computed
|
||||
* @param r raw search direction (i.e. opposite of the gradient)
|
||||
* @return approximation of H<sup>-1</sup>r where H is the objective function hessian
|
||||
*/
|
||||
double[] precondition(double[] point, double[] r);
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
|
||||
|
||||
/**
|
||||
* This package provides optimization algorithms that require derivatives.
|
||||
*/
|
|
@ -0,0 +1,346 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.ZeroException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* This class implements the simplex concept.
|
||||
* It is intended to be used in conjunction with {@link SimplexOptimizer}.
|
||||
* <br/>
|
||||
* The initial configuration of the simplex is set by the constructors
|
||||
* {@link #AbstractSimplex(double[])} or {@link #AbstractSimplex(double[][])}.
|
||||
* The other {@link #AbstractSimplex(int) constructor} will set all steps
|
||||
* to 1, thus building a default configuration from a unit hypercube.
|
||||
* <br/>
|
||||
* Users <em>must</em> call the {@link #build(double[]) build} method in order
|
||||
* to create the data structure that will be acted on by the other methods of
|
||||
* this class.
|
||||
*
|
||||
* @see SimplexOptimizer
|
||||
* @version $Id: AbstractSimplex.java 1397759 2012-10-13 01:12:58Z erans $
|
||||
* @since 3.0
|
||||
*/
|
||||
public abstract class AbstractSimplex implements OptimizationData {
|
||||
/** Simplex. */
|
||||
private PointValuePair[] simplex;
|
||||
/** Start simplex configuration. */
|
||||
private double[][] startConfiguration;
|
||||
/** Simplex dimension (must be equal to {@code simplex.length - 1}). */
|
||||
private final int dimension;
|
||||
|
||||
/**
|
||||
* Build a unit hypercube simplex.
|
||||
*
|
||||
* @param n Dimension of the simplex.
|
||||
*/
|
||||
protected AbstractSimplex(int n) {
|
||||
this(n, 1d);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a hypercube simplex with the given side length.
|
||||
*
|
||||
* @param n Dimension of the simplex.
|
||||
* @param sideLength Length of the sides of the hypercube.
|
||||
*/
|
||||
protected AbstractSimplex(int n,
|
||||
double sideLength) {
|
||||
this(createHypercubeSteps(n, sideLength));
|
||||
}
|
||||
|
||||
/**
|
||||
* The start configuration for simplex is built from a box parallel to
|
||||
* the canonical axes of the space. The simplex is the subset of vertices
|
||||
* of a box parallel to the canonical axes. It is built as the path followed
|
||||
* while traveling from one vertex of the box to the diagonally opposite
|
||||
* vertex moving only along the box edges. The first vertex of the box will
|
||||
* be located at the start point of the optimization.
|
||||
* As an example, in dimension 3 a simplex has 4 vertices. Setting the
|
||||
* steps to (1, 10, 2) and the start point to (1, 1, 1) would imply the
|
||||
* start simplex would be: { (1, 1, 1), (2, 1, 1), (2, 11, 1), (2, 11, 3) }.
|
||||
* The first vertex would be set to the start point at (1, 1, 1) and the
|
||||
* last vertex would be set to the diagonally opposite vertex at (2, 11, 3).
|
||||
*
|
||||
* @param steps Steps along the canonical axes representing box edges. They
|
||||
* may be negative but not zero.
|
||||
* @throws NullArgumentException if {@code steps} is {@code null}.
|
||||
* @throws ZeroException if one of the steps is zero.
|
||||
*/
|
||||
protected AbstractSimplex(final double[] steps) {
|
||||
if (steps == null) {
|
||||
throw new NullArgumentException();
|
||||
}
|
||||
if (steps.length == 0) {
|
||||
throw new ZeroException();
|
||||
}
|
||||
dimension = steps.length;
|
||||
|
||||
// Only the relative position of the n final vertices with respect
|
||||
// to the first one are stored.
|
||||
startConfiguration = new double[dimension][dimension];
|
||||
for (int i = 0; i < dimension; i++) {
|
||||
final double[] vertexI = startConfiguration[i];
|
||||
for (int j = 0; j < i + 1; j++) {
|
||||
if (steps[j] == 0) {
|
||||
throw new ZeroException(LocalizedFormats.EQUAL_VERTICES_IN_SIMPLEX);
|
||||
}
|
||||
System.arraycopy(steps, 0, vertexI, 0, j + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The real initial simplex will be set up by moving the reference
|
||||
* simplex such that its first point is located at the start point of the
|
||||
* optimization.
|
||||
*
|
||||
* @param referenceSimplex Reference simplex.
|
||||
* @throws NotStrictlyPositiveException if the reference simplex does not
|
||||
* contain at least one point.
|
||||
* @throws DimensionMismatchException if there is a dimension mismatch
|
||||
* in the reference simplex.
|
||||
* @throws IllegalArgumentException if one of its vertices is duplicated.
|
||||
*/
|
||||
protected AbstractSimplex(final double[][] referenceSimplex) {
|
||||
if (referenceSimplex.length <= 0) {
|
||||
throw new NotStrictlyPositiveException(LocalizedFormats.SIMPLEX_NEED_ONE_POINT,
|
||||
referenceSimplex.length);
|
||||
}
|
||||
dimension = referenceSimplex.length - 1;
|
||||
|
||||
// Only the relative position of the n final vertices with respect
|
||||
// to the first one are stored.
|
||||
startConfiguration = new double[dimension][dimension];
|
||||
final double[] ref0 = referenceSimplex[0];
|
||||
|
||||
// Loop over vertices.
|
||||
for (int i = 0; i < referenceSimplex.length; i++) {
|
||||
final double[] refI = referenceSimplex[i];
|
||||
|
||||
// Safety checks.
|
||||
if (refI.length != dimension) {
|
||||
throw new DimensionMismatchException(refI.length, dimension);
|
||||
}
|
||||
for (int j = 0; j < i; j++) {
|
||||
final double[] refJ = referenceSimplex[j];
|
||||
boolean allEquals = true;
|
||||
for (int k = 0; k < dimension; k++) {
|
||||
if (refI[k] != refJ[k]) {
|
||||
allEquals = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (allEquals) {
|
||||
throw new MathIllegalArgumentException(LocalizedFormats.EQUAL_VERTICES_IN_SIMPLEX,
|
||||
i, j);
|
||||
}
|
||||
}
|
||||
|
||||
// Store vertex i position relative to vertex 0 position.
|
||||
if (i > 0) {
|
||||
final double[] confI = startConfiguration[i - 1];
|
||||
for (int k = 0; k < dimension; k++) {
|
||||
confI[k] = refI[k] - ref0[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get simplex dimension.
|
||||
*
|
||||
* @return the dimension of the simplex.
|
||||
*/
|
||||
public int getDimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get simplex size.
|
||||
* After calling the {@link #build(double[]) build} method, this method will
|
||||
* will be equivalent to {@code getDimension() + 1}.
|
||||
*
|
||||
* @return the size of the simplex.
|
||||
*/
|
||||
public int getSize() {
|
||||
return simplex.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the next simplex of the algorithm.
|
||||
*
|
||||
* @param evaluationFunction Evaluation function.
|
||||
* @param comparator Comparator to use to sort simplex vertices from best
|
||||
* to worst.
|
||||
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
|
||||
* if the algorithm fails to converge.
|
||||
*/
|
||||
public abstract void iterate(final MultivariateFunction evaluationFunction,
|
||||
final Comparator<PointValuePair> comparator);
|
||||
|
||||
/**
|
||||
* Build an initial simplex.
|
||||
*
|
||||
* @param startPoint First point of the simplex.
|
||||
* @throws DimensionMismatchException if the start point does not match
|
||||
* simplex dimension.
|
||||
*/
|
||||
public void build(final double[] startPoint) {
|
||||
if (dimension != startPoint.length) {
|
||||
throw new DimensionMismatchException(dimension, startPoint.length);
|
||||
}
|
||||
|
||||
// Set first vertex.
|
||||
simplex = new PointValuePair[dimension + 1];
|
||||
simplex[0] = new PointValuePair(startPoint, Double.NaN);
|
||||
|
||||
// Set remaining vertices.
|
||||
for (int i = 0; i < dimension; i++) {
|
||||
final double[] confI = startConfiguration[i];
|
||||
final double[] vertexI = new double[dimension];
|
||||
for (int k = 0; k < dimension; k++) {
|
||||
vertexI[k] = startPoint[k] + confI[k];
|
||||
}
|
||||
simplex[i + 1] = new PointValuePair(vertexI, Double.NaN);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluate all the non-evaluated points of the simplex.
|
||||
*
|
||||
* @param evaluationFunction Evaluation function.
|
||||
* @param comparator Comparator to use to sort simplex vertices from best to worst.
|
||||
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
|
||||
* if the maximal number of evaluations is exceeded.
|
||||
*/
|
||||
public void evaluate(final MultivariateFunction evaluationFunction,
|
||||
final Comparator<PointValuePair> comparator) {
|
||||
// Evaluate the objective function at all non-evaluated simplex points.
|
||||
for (int i = 0; i < simplex.length; i++) {
|
||||
final PointValuePair vertex = simplex[i];
|
||||
final double[] point = vertex.getPointRef();
|
||||
if (Double.isNaN(vertex.getValue())) {
|
||||
simplex[i] = new PointValuePair(point, evaluationFunction.value(point), false);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the simplex from best to worst.
|
||||
Arrays.sort(simplex, comparator);
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace the worst point of the simplex by a new point.
|
||||
*
|
||||
* @param pointValuePair Point to insert.
|
||||
* @param comparator Comparator to use for sorting the simplex vertices
|
||||
* from best to worst.
|
||||
*/
|
||||
protected void replaceWorstPoint(PointValuePair pointValuePair,
|
||||
final Comparator<PointValuePair> comparator) {
|
||||
for (int i = 0; i < dimension; i++) {
|
||||
if (comparator.compare(simplex[i], pointValuePair) > 0) {
|
||||
PointValuePair tmp = simplex[i];
|
||||
simplex[i] = pointValuePair;
|
||||
pointValuePair = tmp;
|
||||
}
|
||||
}
|
||||
simplex[dimension] = pointValuePair;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the points of the simplex.
|
||||
*
|
||||
* @return all the simplex points.
|
||||
*/
|
||||
public PointValuePair[] getPoints() {
|
||||
final PointValuePair[] copy = new PointValuePair[simplex.length];
|
||||
System.arraycopy(simplex, 0, copy, 0, simplex.length);
|
||||
return copy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the simplex point stored at the requested {@code index}.
|
||||
*
|
||||
* @param index Location.
|
||||
* @return the point at location {@code index}.
|
||||
*/
|
||||
public PointValuePair getPoint(int index) {
|
||||
if (index < 0 ||
|
||||
index >= simplex.length) {
|
||||
throw new OutOfRangeException(index, 0, simplex.length - 1);
|
||||
}
|
||||
return simplex[index];
|
||||
}
|
||||
|
||||
/**
|
||||
* Store a new point at location {@code index}.
|
||||
* Note that no deep-copy of {@code point} is performed.
|
||||
*
|
||||
* @param index Location.
|
||||
* @param point New value.
|
||||
*/
|
||||
protected void setPoint(int index, PointValuePair point) {
|
||||
if (index < 0 ||
|
||||
index >= simplex.length) {
|
||||
throw new OutOfRangeException(index, 0, simplex.length - 1);
|
||||
}
|
||||
simplex[index] = point;
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace all points.
|
||||
* Note that no deep-copy of {@code points} is performed.
|
||||
*
|
||||
* @param points New Points.
|
||||
*/
|
||||
protected void setPoints(PointValuePair[] points) {
|
||||
if (points.length != simplex.length) {
|
||||
throw new DimensionMismatchException(points.length, simplex.length);
|
||||
}
|
||||
simplex = points;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create steps for a unit hypercube.
|
||||
*
|
||||
* @param n Dimension of the hypercube.
|
||||
* @param sideLength Length of the sides of the hypercube.
|
||||
* @return the steps.
|
||||
*/
|
||||
private static double[] createHypercubeSteps(int n,
|
||||
double sideLength) {
|
||||
final double[] steps = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
steps[i] = sideLength;
|
||||
}
|
||||
return steps;
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,216 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import java.util.Comparator;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
|
||||
/**
|
||||
* This class implements the multi-directional direct search method.
|
||||
*
|
||||
* @version $Id: MultiDirectionalSimplex.java 1364392 2012-07-22 18:27:12Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class MultiDirectionalSimplex extends AbstractSimplex {
|
||||
/** Default value for {@link #khi}: {@value}. */
|
||||
private static final double DEFAULT_KHI = 2;
|
||||
/** Default value for {@link #gamma}: {@value}. */
|
||||
private static final double DEFAULT_GAMMA = 0.5;
|
||||
/** Expansion coefficient. */
|
||||
private final double khi;
|
||||
/** Contraction coefficient. */
|
||||
private final double gamma;
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with default coefficients.
|
||||
* The default values are 2.0 for khi and 0.5 for gamma.
|
||||
*
|
||||
* @param n Dimension of the simplex.
|
||||
*/
|
||||
public MultiDirectionalSimplex(final int n) {
|
||||
this(n, 1d);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with default coefficients.
|
||||
* The default values are 2.0 for khi and 0.5 for gamma.
|
||||
*
|
||||
* @param n Dimension of the simplex.
|
||||
* @param sideLength Length of the sides of the default (hypercube)
|
||||
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
|
||||
*/
|
||||
public MultiDirectionalSimplex(final int n, double sideLength) {
|
||||
this(n, sideLength, DEFAULT_KHI, DEFAULT_GAMMA);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with specified coefficients.
|
||||
*
|
||||
* @param n Dimension of the simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(int,double)}.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
*/
|
||||
public MultiDirectionalSimplex(final int n,
|
||||
final double khi, final double gamma) {
|
||||
this(n, 1d, khi, gamma);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with specified coefficients.
|
||||
*
|
||||
* @param n Dimension of the simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(int,double)}.
|
||||
* @param sideLength Length of the sides of the default (hypercube)
|
||||
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
*/
|
||||
public MultiDirectionalSimplex(final int n, double sideLength,
|
||||
final double khi, final double gamma) {
|
||||
super(n, sideLength);
|
||||
|
||||
this.khi = khi;
|
||||
this.gamma = gamma;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with default coefficients.
|
||||
* The default values are 2.0 for khi and 0.5 for gamma.
|
||||
*
|
||||
* @param steps Steps along the canonical axes representing box edges.
|
||||
* They may be negative but not zero. See
|
||||
*/
|
||||
public MultiDirectionalSimplex(final double[] steps) {
|
||||
this(steps, DEFAULT_KHI, DEFAULT_GAMMA);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with specified coefficients.
|
||||
*
|
||||
* @param steps Steps along the canonical axes representing box edges.
|
||||
* They may be negative but not zero. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(double[])}.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
*/
|
||||
public MultiDirectionalSimplex(final double[] steps,
|
||||
final double khi, final double gamma) {
|
||||
super(steps);
|
||||
|
||||
this.khi = khi;
|
||||
this.gamma = gamma;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with default coefficients.
|
||||
* The default values are 2.0 for khi and 0.5 for gamma.
|
||||
*
|
||||
* @param referenceSimplex Reference simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
|
||||
*/
|
||||
public MultiDirectionalSimplex(final double[][] referenceSimplex) {
|
||||
this(referenceSimplex, DEFAULT_KHI, DEFAULT_GAMMA);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a multi-directional simplex with specified coefficients.
|
||||
*
|
||||
* @param referenceSimplex Reference simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if the reference simplex does not contain at least one point.
|
||||
* @throws org.apache.commons.math3.exception.DimensionMismatchException
|
||||
* if there is a dimension mismatch in the reference simplex.
|
||||
*/
|
||||
public MultiDirectionalSimplex(final double[][] referenceSimplex,
|
||||
final double khi, final double gamma) {
|
||||
super(referenceSimplex);
|
||||
|
||||
this.khi = khi;
|
||||
this.gamma = gamma;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public void iterate(final MultivariateFunction evaluationFunction,
|
||||
final Comparator<PointValuePair> comparator) {
|
||||
// Save the original simplex.
|
||||
final PointValuePair[] original = getPoints();
|
||||
final PointValuePair best = original[0];
|
||||
|
||||
// Perform a reflection step.
|
||||
final PointValuePair reflected = evaluateNewSimplex(evaluationFunction,
|
||||
original, 1, comparator);
|
||||
if (comparator.compare(reflected, best) < 0) {
|
||||
// Compute the expanded simplex.
|
||||
final PointValuePair[] reflectedSimplex = getPoints();
|
||||
final PointValuePair expanded = evaluateNewSimplex(evaluationFunction,
|
||||
original, khi, comparator);
|
||||
if (comparator.compare(reflected, expanded) <= 0) {
|
||||
// Keep the reflected simplex.
|
||||
setPoints(reflectedSimplex);
|
||||
}
|
||||
// Keep the expanded simplex.
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute the contracted simplex.
|
||||
evaluateNewSimplex(evaluationFunction, original, gamma, comparator);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute and evaluate a new simplex.
|
||||
*
|
||||
* @param evaluationFunction Evaluation function.
|
||||
* @param original Original simplex (to be preserved).
|
||||
* @param coeff Linear coefficient.
|
||||
* @param comparator Comparator to use to sort simplex vertices from best
|
||||
* to poorest.
|
||||
* @return the best point in the transformed simplex.
|
||||
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
|
||||
* if the maximal number of evaluations is exceeded.
|
||||
*/
|
||||
private PointValuePair evaluateNewSimplex(final MultivariateFunction evaluationFunction,
|
||||
final PointValuePair[] original,
|
||||
final double coeff,
|
||||
final Comparator<PointValuePair> comparator) {
|
||||
final double[] xSmallest = original[0].getPointRef();
|
||||
// Perform a linear transformation on all the simplex points,
|
||||
// except the first one.
|
||||
setPoint(0, original[0]);
|
||||
final int dim = getDimension();
|
||||
for (int i = 1; i < getSize(); i++) {
|
||||
final double[] xOriginal = original[i].getPointRef();
|
||||
final double[] xTransformed = new double[dim];
|
||||
for (int j = 0; j < dim; j++) {
|
||||
xTransformed[j] = xSmallest[j] + coeff * (xSmallest[j] - xOriginal[j]);
|
||||
}
|
||||
setPoint(i, new PointValuePair(xTransformed, Double.NaN, false));
|
||||
}
|
||||
|
||||
// Evaluate the simplex.
|
||||
evaluate(evaluationFunction, comparator);
|
||||
|
||||
return getPoint(0);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,281 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import java.util.Comparator;
|
||||
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
|
||||
/**
|
||||
* This class implements the Nelder-Mead simplex algorithm.
|
||||
*
|
||||
* @version $Id: NelderMeadSimplex.java 1364392 2012-07-22 18:27:12Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class NelderMeadSimplex extends AbstractSimplex {
|
||||
/** Default value for {@link #rho}: {@value}. */
|
||||
private static final double DEFAULT_RHO = 1;
|
||||
/** Default value for {@link #khi}: {@value}. */
|
||||
private static final double DEFAULT_KHI = 2;
|
||||
/** Default value for {@link #gamma}: {@value}. */
|
||||
private static final double DEFAULT_GAMMA = 0.5;
|
||||
/** Default value for {@link #sigma}: {@value}. */
|
||||
private static final double DEFAULT_SIGMA = 0.5;
|
||||
/** Reflection coefficient. */
|
||||
private final double rho;
|
||||
/** Expansion coefficient. */
|
||||
private final double khi;
|
||||
/** Contraction coefficient. */
|
||||
private final double gamma;
|
||||
/** Shrinkage coefficient. */
|
||||
private final double sigma;
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with default coefficients.
|
||||
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
|
||||
* for both gamma and sigma.
|
||||
*
|
||||
* @param n Dimension of the simplex.
|
||||
*/
|
||||
public NelderMeadSimplex(final int n) {
|
||||
this(n, 1d);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with default coefficients.
|
||||
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
|
||||
* for both gamma and sigma.
|
||||
*
|
||||
* @param n Dimension of the simplex.
|
||||
* @param sideLength Length of the sides of the default (hypercube)
|
||||
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
|
||||
*/
|
||||
public NelderMeadSimplex(final int n, double sideLength) {
|
||||
this(n, sideLength,
|
||||
DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with specified coefficients.
|
||||
*
|
||||
* @param n Dimension of the simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(int,double)}.
|
||||
* @param sideLength Length of the sides of the default (hypercube)
|
||||
* simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
|
||||
* @param rho Reflection coefficient.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
* @param sigma Shrinkage coefficient.
|
||||
*/
|
||||
public NelderMeadSimplex(final int n, double sideLength,
|
||||
final double rho, final double khi,
|
||||
final double gamma, final double sigma) {
|
||||
super(n, sideLength);
|
||||
|
||||
this.rho = rho;
|
||||
this.khi = khi;
|
||||
this.gamma = gamma;
|
||||
this.sigma = sigma;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with specified coefficients.
|
||||
*
|
||||
* @param n Dimension of the simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(int)}.
|
||||
* @param rho Reflection coefficient.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
* @param sigma Shrinkage coefficient.
|
||||
*/
|
||||
public NelderMeadSimplex(final int n,
|
||||
final double rho, final double khi,
|
||||
final double gamma, final double sigma) {
|
||||
this(n, 1d, rho, khi, gamma, sigma);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with default coefficients.
|
||||
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
|
||||
* for both gamma and sigma.
|
||||
*
|
||||
* @param steps Steps along the canonical axes representing box edges.
|
||||
* They may be negative but not zero. See
|
||||
*/
|
||||
public NelderMeadSimplex(final double[] steps) {
|
||||
this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with specified coefficients.
|
||||
*
|
||||
* @param steps Steps along the canonical axes representing box edges.
|
||||
* They may be negative but not zero. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(double[])}.
|
||||
* @param rho Reflection coefficient.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
* @param sigma Shrinkage coefficient.
|
||||
* @throws IllegalArgumentException if one of the steps is zero.
|
||||
*/
|
||||
public NelderMeadSimplex(final double[] steps,
|
||||
final double rho, final double khi,
|
||||
final double gamma, final double sigma) {
|
||||
super(steps);
|
||||
|
||||
this.rho = rho;
|
||||
this.khi = khi;
|
||||
this.gamma = gamma;
|
||||
this.sigma = sigma;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with default coefficients.
|
||||
* The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
|
||||
* for both gamma and sigma.
|
||||
*
|
||||
* @param referenceSimplex Reference simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
|
||||
*/
|
||||
public NelderMeadSimplex(final double[][] referenceSimplex) {
|
||||
this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Nelder-Mead simplex with specified coefficients.
|
||||
*
|
||||
* @param referenceSimplex Reference simplex. See
|
||||
* {@link AbstractSimplex#AbstractSimplex(double[][])}.
|
||||
* @param rho Reflection coefficient.
|
||||
* @param khi Expansion coefficient.
|
||||
* @param gamma Contraction coefficient.
|
||||
* @param sigma Shrinkage coefficient.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if the reference simplex does not contain at least one point.
|
||||
* @throws org.apache.commons.math3.exception.DimensionMismatchException
|
||||
* if there is a dimension mismatch in the reference simplex.
|
||||
*/
|
||||
public NelderMeadSimplex(final double[][] referenceSimplex,
|
||||
final double rho, final double khi,
|
||||
final double gamma, final double sigma) {
|
||||
super(referenceSimplex);
|
||||
|
||||
this.rho = rho;
|
||||
this.khi = khi;
|
||||
this.gamma = gamma;
|
||||
this.sigma = sigma;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public void iterate(final MultivariateFunction evaluationFunction,
|
||||
final Comparator<PointValuePair> comparator) {
|
||||
// The simplex has n + 1 points if dimension is n.
|
||||
final int n = getDimension();
|
||||
|
||||
// Interesting values.
|
||||
final PointValuePair best = getPoint(0);
|
||||
final PointValuePair secondBest = getPoint(n - 1);
|
||||
final PointValuePair worst = getPoint(n);
|
||||
final double[] xWorst = worst.getPointRef();
|
||||
|
||||
// Compute the centroid of the best vertices (dismissing the worst
|
||||
// point at index n).
|
||||
final double[] centroid = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
final double[] x = getPoint(i).getPointRef();
|
||||
for (int j = 0; j < n; j++) {
|
||||
centroid[j] += x[j];
|
||||
}
|
||||
}
|
||||
final double scaling = 1.0 / n;
|
||||
for (int j = 0; j < n; j++) {
|
||||
centroid[j] *= scaling;
|
||||
}
|
||||
|
||||
// compute the reflection point
|
||||
final double[] xR = new double[n];
|
||||
for (int j = 0; j < n; j++) {
|
||||
xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
|
||||
}
|
||||
final PointValuePair reflected
|
||||
= new PointValuePair(xR, evaluationFunction.value(xR), false);
|
||||
|
||||
if (comparator.compare(best, reflected) <= 0 &&
|
||||
comparator.compare(reflected, secondBest) < 0) {
|
||||
// Accept the reflected point.
|
||||
replaceWorstPoint(reflected, comparator);
|
||||
} else if (comparator.compare(reflected, best) < 0) {
|
||||
// Compute the expansion point.
|
||||
final double[] xE = new double[n];
|
||||
for (int j = 0; j < n; j++) {
|
||||
xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
|
||||
}
|
||||
final PointValuePair expanded
|
||||
= new PointValuePair(xE, evaluationFunction.value(xE), false);
|
||||
|
||||
if (comparator.compare(expanded, reflected) < 0) {
|
||||
// Accept the expansion point.
|
||||
replaceWorstPoint(expanded, comparator);
|
||||
} else {
|
||||
// Accept the reflected point.
|
||||
replaceWorstPoint(reflected, comparator);
|
||||
}
|
||||
} else {
|
||||
if (comparator.compare(reflected, worst) < 0) {
|
||||
// Perform an outside contraction.
|
||||
final double[] xC = new double[n];
|
||||
for (int j = 0; j < n; j++) {
|
||||
xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
|
||||
}
|
||||
final PointValuePair outContracted
|
||||
= new PointValuePair(xC, evaluationFunction.value(xC), false);
|
||||
if (comparator.compare(outContracted, reflected) <= 0) {
|
||||
// Accept the contraction point.
|
||||
replaceWorstPoint(outContracted, comparator);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// Perform an inside contraction.
|
||||
final double[] xC = new double[n];
|
||||
for (int j = 0; j < n; j++) {
|
||||
xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
|
||||
}
|
||||
final PointValuePair inContracted
|
||||
= new PointValuePair(xC, evaluationFunction.value(xC), false);
|
||||
|
||||
if (comparator.compare(inContracted, worst) < 0) {
|
||||
// Accept the contraction point.
|
||||
replaceWorstPoint(inContracted, comparator);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Perform a shrink.
|
||||
final double[] xSmallest = getPoint(0).getPointRef();
|
||||
for (int i = 1; i <= n; i++) {
|
||||
final double[] x = getPoint(i).getPoint();
|
||||
for (int j = 0; j < n; j++) {
|
||||
x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
|
||||
}
|
||||
setPoint(i, new PointValuePair(x, Double.NaN, false));
|
||||
}
|
||||
evaluate(evaluationFunction, comparator);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,356 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
|
||||
import org.apache.commons.math3.optim.univariate.BracketFinder;
|
||||
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
|
||||
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;
|
||||
import org.apache.commons.math3.optim.univariate.SimpleUnivariateValueChecker;
|
||||
import org.apache.commons.math3.optim.univariate.SearchInterval;
|
||||
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
|
||||
|
||||
/**
|
||||
* Powell algorithm.
|
||||
* This code is translated and adapted from the Python version of this
|
||||
* algorithm (as implemented in module {@code optimize.py} v0.5 of
|
||||
* <em>SciPy</em>).
|
||||
* <br/>
|
||||
* The default stopping criterion is based on the differences of the
|
||||
* function value between two successive iterations. It is however possible
|
||||
* to define a custom convergence checker that might terminate the algorithm
|
||||
* earlier.
|
||||
* <br/>
|
||||
* The internal line search optimizer is a {@link BrentOptimizer} with a
|
||||
* convergence checker set to {@link SimpleUnivariateValueChecker}.
|
||||
*
|
||||
* @version $Id: PowellOptimizer.java 1413594 2012-11-26 13:16:39Z erans $
|
||||
* @since 2.2
|
||||
*/
|
||||
public class PowellOptimizer
|
||||
extends MultivariateOptimizer {
|
||||
/**
|
||||
* Minimum relative tolerance.
|
||||
*/
|
||||
private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
|
||||
/**
|
||||
* Relative threshold.
|
||||
*/
|
||||
private final double relativeThreshold;
|
||||
/**
|
||||
* Absolute threshold.
|
||||
*/
|
||||
private final double absoluteThreshold;
|
||||
/**
|
||||
* Line search.
|
||||
*/
|
||||
private final LineSearch line;
|
||||
|
||||
/**
|
||||
* This constructor allows to specify a user-defined convergence checker,
|
||||
* in addition to the parameters that control the default convergence
|
||||
* checking procedure.
|
||||
* <br/>
|
||||
* The internal line search tolerances are set to the square-root of their
|
||||
* corresponding value in the multivariate optimizer.
|
||||
*
|
||||
* @param rel Relative threshold.
|
||||
* @param abs Absolute threshold.
|
||||
* @param checker Convergence checker.
|
||||
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
||||
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
||||
*/
|
||||
public PowellOptimizer(double rel,
|
||||
double abs,
|
||||
ConvergenceChecker<PointValuePair> checker) {
|
||||
this(rel, abs, FastMath.sqrt(rel), FastMath.sqrt(abs), checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* This constructor allows to specify a user-defined convergence checker,
|
||||
* in addition to the parameters that control the default convergence
|
||||
* checking procedure and the line search tolerances.
|
||||
*
|
||||
* @param rel Relative threshold for this optimizer.
|
||||
* @param abs Absolute threshold for this optimizer.
|
||||
* @param lineRel Relative threshold for the internal line search optimizer.
|
||||
* @param lineAbs Absolute threshold for the internal line search optimizer.
|
||||
* @param checker Convergence checker.
|
||||
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
||||
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
||||
*/
|
||||
public PowellOptimizer(double rel,
|
||||
double abs,
|
||||
double lineRel,
|
||||
double lineAbs,
|
||||
ConvergenceChecker<PointValuePair> checker) {
|
||||
super(checker);
|
||||
|
||||
if (rel < MIN_RELATIVE_TOLERANCE) {
|
||||
throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
|
||||
}
|
||||
if (abs <= 0) {
|
||||
throw new NotStrictlyPositiveException(abs);
|
||||
}
|
||||
relativeThreshold = rel;
|
||||
absoluteThreshold = abs;
|
||||
|
||||
// Create the line search optimizer.
|
||||
line = new LineSearch(lineRel,
|
||||
lineAbs);
|
||||
}
|
||||
|
||||
/**
|
||||
* The parameters control the default convergence checking procedure.
|
||||
* <br/>
|
||||
* The internal line search tolerances are set to the square-root of their
|
||||
* corresponding value in the multivariate optimizer.
|
||||
*
|
||||
* @param rel Relative threshold.
|
||||
* @param abs Absolute threshold.
|
||||
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
||||
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
||||
*/
|
||||
public PowellOptimizer(double rel,
|
||||
double abs) {
|
||||
this(rel, abs, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an instance with the default convergence checking procedure.
|
||||
*
|
||||
* @param rel Relative threshold.
|
||||
* @param abs Absolute threshold.
|
||||
* @param lineRel Relative threshold for the internal line search optimizer.
|
||||
* @param lineAbs Absolute threshold for the internal line search optimizer.
|
||||
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
||||
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
||||
*/
|
||||
public PowellOptimizer(double rel,
|
||||
double abs,
|
||||
double lineRel,
|
||||
double lineAbs) {
|
||||
this(rel, abs, lineRel, lineAbs, null);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected PointValuePair doOptimize() {
|
||||
final GoalType goal = getGoalType();
|
||||
final double[] guess = getStartPoint();
|
||||
final int n = guess.length;
|
||||
|
||||
final double[][] direc = new double[n][n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
direc[i][i] = 1;
|
||||
}
|
||||
|
||||
final ConvergenceChecker<PointValuePair> checker
|
||||
= getConvergenceChecker();
|
||||
|
||||
double[] x = guess;
|
||||
double fVal = computeObjectiveValue(x);
|
||||
double[] x1 = x.clone();
|
||||
int iter = 0;
|
||||
while (true) {
|
||||
++iter;
|
||||
|
||||
double fX = fVal;
|
||||
double fX2 = 0;
|
||||
double delta = 0;
|
||||
int bigInd = 0;
|
||||
double alphaMin = 0;
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
final double[] d = MathArrays.copyOf(direc[i]);
|
||||
|
||||
fX2 = fVal;
|
||||
|
||||
final UnivariatePointValuePair optimum = line.search(x, d);
|
||||
fVal = optimum.getValue();
|
||||
alphaMin = optimum.getPoint();
|
||||
final double[][] result = newPointAndDirection(x, d, alphaMin);
|
||||
x = result[0];
|
||||
|
||||
if ((fX2 - fVal) > delta) {
|
||||
delta = fX2 - fVal;
|
||||
bigInd = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Default convergence check.
|
||||
boolean stop = 2 * (fX - fVal) <=
|
||||
(relativeThreshold * (FastMath.abs(fX) + FastMath.abs(fVal)) +
|
||||
absoluteThreshold);
|
||||
|
||||
final PointValuePair previous = new PointValuePair(x1, fX);
|
||||
final PointValuePair current = new PointValuePair(x, fVal);
|
||||
if (!stop) { // User-defined stopping criteria.
|
||||
if (checker != null) {
|
||||
stop = checker.converged(iter, previous, current);
|
||||
}
|
||||
}
|
||||
if (stop) {
|
||||
if (goal == GoalType.MINIMIZE) {
|
||||
return (fVal < fX) ? current : previous;
|
||||
} else {
|
||||
return (fVal > fX) ? current : previous;
|
||||
}
|
||||
}
|
||||
|
||||
final double[] d = new double[n];
|
||||
final double[] x2 = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
d[i] = x[i] - x1[i];
|
||||
x2[i] = 2 * x[i] - x1[i];
|
||||
}
|
||||
|
||||
x1 = x.clone();
|
||||
fX2 = computeObjectiveValue(x2);
|
||||
|
||||
if (fX > fX2) {
|
||||
double t = 2 * (fX + fX2 - 2 * fVal);
|
||||
double temp = fX - fVal - delta;
|
||||
t *= temp * temp;
|
||||
temp = fX - fX2;
|
||||
t -= delta * temp * temp;
|
||||
|
||||
if (t < 0.0) {
|
||||
final UnivariatePointValuePair optimum = line.search(x, d);
|
||||
fVal = optimum.getValue();
|
||||
alphaMin = optimum.getPoint();
|
||||
final double[][] result = newPointAndDirection(x, d, alphaMin);
|
||||
x = result[0];
|
||||
|
||||
final int lastInd = n - 1;
|
||||
direc[bigInd] = direc[lastInd];
|
||||
direc[lastInd] = result[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute a new point (in the original space) and a new direction
|
||||
* vector, resulting from the line search.
|
||||
*
|
||||
* @param p Point used in the line search.
|
||||
* @param d Direction used in the line search.
|
||||
* @param optimum Optimum found by the line search.
|
||||
* @return a 2-element array containing the new point (at index 0) and
|
||||
* the new direction (at index 1).
|
||||
*/
|
||||
private double[][] newPointAndDirection(double[] p,
|
||||
double[] d,
|
||||
double optimum) {
|
||||
final int n = p.length;
|
||||
final double[] nP = new double[n];
|
||||
final double[] nD = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
nD[i] = d[i] * optimum;
|
||||
nP[i] = p[i] + nD[i];
|
||||
}
|
||||
|
||||
final double[][] result = new double[2][];
|
||||
result[0] = nP;
|
||||
result[1] = nD;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Class for finding the minimum of the objective function along a given
|
||||
* direction.
|
||||
*/
|
||||
private class LineSearch extends BrentOptimizer {
|
||||
/**
|
||||
* Value that will pass the precondition check for {@link BrentOptimizer}
|
||||
* but will not pass the convergence check, so that the custom checker
|
||||
* will always decide when to stop the line search.
|
||||
*/
|
||||
private static final double REL_TOL_UNUSED = 1e-15;
|
||||
/**
|
||||
* Value that will pass the precondition check for {@link BrentOptimizer}
|
||||
* but will not pass the convergence check, so that the custom checker
|
||||
* will always decide when to stop the line search.
|
||||
*/
|
||||
private static final double ABS_TOL_UNUSED = Double.MIN_VALUE;
|
||||
/**
|
||||
* Automatic bracketing.
|
||||
*/
|
||||
private final BracketFinder bracket = new BracketFinder();
|
||||
|
||||
/**
|
||||
* The "BrentOptimizer" default stopping criterion uses the tolerances
|
||||
* to check the domain (point) values, not the function values.
|
||||
* We thus create a custom checker to use function values.
|
||||
*
|
||||
* @param rel Relative threshold.
|
||||
* @param abs Absolute threshold.
|
||||
*/
|
||||
LineSearch(double rel,
|
||||
double abs) {
|
||||
super(REL_TOL_UNUSED,
|
||||
ABS_TOL_UNUSED,
|
||||
new SimpleUnivariateValueChecker(rel, abs));
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the minimum of the function {@code f(p + alpha * d)}.
|
||||
*
|
||||
* @param p Starting point.
|
||||
* @param d Search direction.
|
||||
* @return the optimum.
|
||||
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
|
||||
* if the number of evaluations is exceeded.
|
||||
*/
|
||||
public UnivariatePointValuePair search(final double[] p, final double[] d) {
|
||||
final int n = p.length;
|
||||
final UnivariateFunction f = new UnivariateFunction() {
|
||||
public double value(double alpha) {
|
||||
final double[] x = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
x[i] = p[i] + alpha * d[i];
|
||||
}
|
||||
final double obj = PowellOptimizer.this.computeObjectiveValue(x);
|
||||
return obj;
|
||||
}
|
||||
};
|
||||
|
||||
final GoalType goal = PowellOptimizer.this.getGoalType();
|
||||
bracket.search(f, goal, 0, 1);
|
||||
// Passing "MAX_VALUE" as a dummy value because it is the enclosing
|
||||
// class that counts the number of evaluations (and will eventually
|
||||
// generate the exception).
|
||||
return optimize(new MaxEval(Integer.MAX_VALUE),
|
||||
new UnivariateObjectiveFunction(f),
|
||||
goal,
|
||||
new SearchInterval(bracket.getLo(),
|
||||
bracket.getHi(),
|
||||
bracket.getMid()));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import java.util.Comparator;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.SimpleValueChecker;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
|
||||
|
||||
/**
|
||||
* This class implements simplex-based direct search optimization.
|
||||
*
|
||||
* <p>
|
||||
* Direct search methods only use objective function values, they do
|
||||
* not need derivatives and don't either try to compute approximation
|
||||
* of the derivatives. According to a 1996 paper by Margaret H. Wright
|
||||
* (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct
|
||||
* Search Methods: Once Scorned, Now Respectable</a>), they are used
|
||||
* when either the computation of the derivative is impossible (noisy
|
||||
* functions, unpredictable discontinuities) or difficult (complexity,
|
||||
* computation cost). In the first cases, rather than an optimum, a
|
||||
* <em>not too bad</em> point is desired. In the latter cases, an
|
||||
* optimum is desired but cannot be reasonably found. In all cases
|
||||
* direct search methods can be useful.
|
||||
* </p>
|
||||
* <p>
|
||||
* Simplex-based direct search methods are based on comparison of
|
||||
* the objective function values at the vertices of a simplex (which is a
|
||||
* set of n+1 points in dimension n) that is updated by the algorithms
|
||||
* steps.
|
||||
* <p>
|
||||
* <p>
|
||||
* The simplex update procedure ({@link NelderMeadSimplex} or
|
||||
* {@link MultiDirectionalSimplex}) must be passed to the
|
||||
* {@code optimize} method.
|
||||
* </p>
|
||||
* <p>
|
||||
* Each call to {@code optimize} will re-use the start configuration of
|
||||
* the current simplex and move it such that its first vertex is at the
|
||||
* provided start point of the optimization.
|
||||
* If the {@code optimize} method is called to solve a different problem
|
||||
* and the number of parameters change, the simplex must be re-initialized
|
||||
* to one with the appropriate dimensions.
|
||||
* </p>
|
||||
* <p>
|
||||
* Convergence is checked by providing the <em>worst</em> points of
|
||||
* previous and current simplex to the convergence checker, not the best
|
||||
* ones.
|
||||
* </p>
|
||||
* <p>
|
||||
* This simplex optimizer implementation does not directly support constrained
|
||||
* optimization with simple bounds; so, for such optimizations, either a more
|
||||
* dedicated algorithm must be used like
|
||||
* {@link CMAESOptimizer} or {@link BOBYQAOptimizer}, or the objective
|
||||
* function must be wrapped in an adapter like
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter
|
||||
* MultivariateFunctionMappingAdapter} or
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionPenaltyAdapter
|
||||
* MultivariateFunctionPenaltyAdapter}.
|
||||
* </p>
|
||||
*
|
||||
* @version $Id: SimplexOptimizer.java 1397759 2012-10-13 01:12:58Z erans $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class SimplexOptimizer extends MultivariateOptimizer {
|
||||
/** Simplex update rule. */
|
||||
private AbstractSimplex simplex;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
public SimplexOptimizer(ConvergenceChecker<PointValuePair> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param rel Relative threshold.
|
||||
* @param abs Absolute threshold.
|
||||
*/
|
||||
public SimplexOptimizer(double rel, double abs) {
|
||||
this(new SimpleValueChecker(rel, abs));
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
|
||||
* <li>{@link AbstractSimplex}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public PointValuePair optimize(OptimizationData... optData) {
|
||||
// Retrieve settings
|
||||
parseOptimizationData(optData);
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected PointValuePair doOptimize() {
|
||||
if (simplex == null) {
|
||||
throw new NullArgumentException();
|
||||
}
|
||||
|
||||
// Indirect call to "computeObjectiveValue" in order to update the
|
||||
// evaluations counter.
|
||||
final MultivariateFunction evalFunc
|
||||
= new MultivariateFunction() {
|
||||
public double value(double[] point) {
|
||||
return computeObjectiveValue(point);
|
||||
}
|
||||
};
|
||||
|
||||
final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
|
||||
final Comparator<PointValuePair> comparator
|
||||
= new Comparator<PointValuePair>() {
|
||||
public int compare(final PointValuePair o1,
|
||||
final PointValuePair o2) {
|
||||
final double v1 = o1.getValue();
|
||||
final double v2 = o2.getValue();
|
||||
return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1);
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize search.
|
||||
simplex.build(getStartPoint());
|
||||
simplex.evaluate(evalFunc, comparator);
|
||||
|
||||
PointValuePair[] previous = null;
|
||||
int iteration = 0;
|
||||
final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
|
||||
while (true) {
|
||||
if (iteration > 0) {
|
||||
boolean converged = true;
|
||||
for (int i = 0; i < simplex.getSize(); i++) {
|
||||
PointValuePair prev = previous[i];
|
||||
converged = converged &&
|
||||
checker.converged(iteration, prev, simplex.getPoint(i));
|
||||
}
|
||||
if (converged) {
|
||||
// We have found an optimum.
|
||||
return simplex.getPoint(0);
|
||||
}
|
||||
}
|
||||
|
||||
// We still need to search.
|
||||
previous = simplex.getPoints();
|
||||
simplex.iterate(evalFunc, comparator);
|
||||
++iteration;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link AbstractSimplex}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof AbstractSimplex) {
|
||||
simplex = (AbstractSimplex) data;
|
||||
// If more data must be parsed, this statement _must_ be
|
||||
// changed to "continue".
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
/**
|
||||
* This package provides optimization algorithms that do not require derivatives.
|
||||
*/
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
/**
|
||||
* Algorithms for optimizing a scalar function.
|
||||
*/
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
|
||||
/**
|
||||
* Base class for implementing optimizers for multivariate vector
|
||||
* differentiable functions.
|
||||
* It contains boiler-plate code for dealing with Jacobian evaluation.
|
||||
* It assumes that the rows of the Jacobian matrix iterate on the model
|
||||
* functions while the columns iterate on the parameters; thus, the numbers
|
||||
* of rows is equal to the dimension of the {@link Target} while the
|
||||
* number of columns is equal to the dimension of the
|
||||
* {@link org.apache.commons.math3.optim.InitialGuess InitialGuess}.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class JacobianMultivariateVectorOptimizer
|
||||
extends MultivariateVectorOptimizer {
|
||||
/**
|
||||
* Jacobian of the model function.
|
||||
*/
|
||||
private MultivariateMatrixFunction jacobian;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected JacobianMultivariateVectorOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the Jacobian matrix.
|
||||
*
|
||||
* @param params Point at which the Jacobian must be evaluated.
|
||||
* @return the Jacobian at the specified point.
|
||||
*/
|
||||
protected double[][] computeJacobian(final double[] params) {
|
||||
return jacobian.value(params);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
|
||||
* <li>{@link Target}</li>
|
||||
* <li>{@link Weight}</li>
|
||||
* <li>{@link ModelFunction}</li>
|
||||
* <li>{@link ModelFunctionJacobian}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
* @throws DimensionMismatchException if the initial guess, target, and weight
|
||||
* arguments have inconsistent dimensions.
|
||||
*/
|
||||
@Override
|
||||
public PointVectorValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException,
|
||||
DimensionMismatchException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link ModelFunctionJacobian}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof ModelFunctionJacobian) {
|
||||
jacobian = ((ModelFunctionJacobian) data).getModelFunctionJacobian();
|
||||
// If more data must be parsed, this statement _must_ be
|
||||
// changed to "continue".
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* Model (vector) function to be optimized.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class ModelFunction implements OptimizationData {
|
||||
/** Function to be optimized. */
|
||||
private final MultivariateVectorFunction model;
|
||||
|
||||
/**
|
||||
* @param m Model function to be optimized.
|
||||
*/
|
||||
public ModelFunction(MultivariateVectorFunction m) {
|
||||
model = m;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the model function to be optimized.
|
||||
*
|
||||
* @return the model function.
|
||||
*/
|
||||
public MultivariateVectorFunction getModelFunction() {
|
||||
return model;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* Jacobian of the model (vector) function to be optimized.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class ModelFunctionJacobian implements OptimizationData {
|
||||
/** Function to be optimized. */
|
||||
private final MultivariateMatrixFunction jacobian;
|
||||
|
||||
/**
|
||||
* @param j Jacobian of the model function to be optimized.
|
||||
*/
|
||||
public ModelFunctionJacobian(MultivariateMatrixFunction j) {
|
||||
jacobian = j;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the Jacobian of the model function to be optimized.
|
||||
*
|
||||
* @return the model function Jacobian.
|
||||
*/
|
||||
public MultivariateMatrixFunction getModelFunctionJacobian() {
|
||||
return jacobian;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.random.RandomVectorGenerator;
|
||||
import org.apache.commons.math3.optim.BaseMultiStartMultivariateOptimizer;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
|
||||
/**
|
||||
* Multi-start optimizer for a (vector) model function.
|
||||
*
|
||||
* This class wraps an optimizer in order to use it several times in
|
||||
* turn with different starting points (trying to avoid being trapped
|
||||
* in a local extremum when looking for a global one).
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.0
|
||||
*/
|
||||
public class MultiStartMultivariateVectorOptimizer
|
||||
extends BaseMultiStartMultivariateOptimizer<PointVectorValuePair> {
|
||||
/** Underlying optimizer. */
|
||||
private final MultivariateVectorOptimizer optimizer;
|
||||
/** Found optima. */
|
||||
private final List<PointVectorValuePair> optima = new ArrayList<PointVectorValuePair>();
|
||||
|
||||
/**
|
||||
* Create a multi-start optimizer from a single-start optimizer.
|
||||
*
|
||||
* @param optimizer Single-start optimizer to wrap.
|
||||
* @param starts Number of starts to perform.
|
||||
* If {@code starts == 1}, the result will be same as if {@code optimizer}
|
||||
* is called directly.
|
||||
* @param generator Random vector generator to use for restarts.
|
||||
* @throws NullArgumentException if {@code optimizer} or {@code generator}
|
||||
* is {@code null}.
|
||||
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
||||
*/
|
||||
public MultiStartMultivariateVectorOptimizer(final MultivariateVectorOptimizer optimizer,
|
||||
final int starts,
|
||||
final RandomVectorGenerator generator)
|
||||
throws NullArgumentException,
|
||||
NotStrictlyPositiveException {
|
||||
super(optimizer, starts, generator);
|
||||
this.optimizer = optimizer;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public PointVectorValuePair[] getOptima() {
|
||||
Collections.sort(optima, getPairComparator());
|
||||
return optima.toArray(new PointVectorValuePair[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
protected void store(PointVectorValuePair optimum) {
|
||||
optima.add(optimum);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
protected void clear() {
|
||||
optima.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a comparator for sorting the optima.
|
||||
*/
|
||||
private Comparator<PointVectorValuePair> getPairComparator() {
|
||||
return new Comparator<PointVectorValuePair>() {
|
||||
private final RealVector target = new ArrayRealVector(optimizer.getTarget(), false);
|
||||
private final RealMatrix weight = optimizer.getWeight();
|
||||
|
||||
public int compare(final PointVectorValuePair o1,
|
||||
final PointVectorValuePair o2) {
|
||||
if (o1 == null) {
|
||||
return (o2 == null) ? 0 : 1;
|
||||
} else if (o2 == null) {
|
||||
return -1;
|
||||
}
|
||||
return Double.compare(weightedResidual(o1),
|
||||
weightedResidual(o2));
|
||||
}
|
||||
|
||||
private double weightedResidual(final PointVectorValuePair pv) {
|
||||
final RealVector v = new ArrayRealVector(pv.getValueRef(), false);
|
||||
final RealVector r = target.subtract(v);
|
||||
return r.dotProduct(weight.operate(r));
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,164 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.BaseMultivariateOptimizer;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
|
||||
/**
|
||||
* Base class for a multivariate vector function optimizer.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class MultivariateVectorOptimizer
|
||||
extends BaseMultivariateOptimizer<PointVectorValuePair> {
|
||||
/** Target values for the model function at optimum. */
|
||||
private double[] target;
|
||||
/** Weight matrix. */
|
||||
private RealMatrix weightMatrix;
|
||||
/** Model function. */
|
||||
private MultivariateVectorFunction model;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected MultivariateVectorOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the objective function value.
|
||||
* This method <em>must</em> be called by subclasses to enforce the
|
||||
* evaluation counter limit.
|
||||
*
|
||||
* @param params Point at which the objective function must be evaluated.
|
||||
* @return the objective function value at the specified point.
|
||||
* @throws TooManyEvaluationsException if the maximal number of evaluations
|
||||
* (of the model vector function) is exceeded.
|
||||
*/
|
||||
protected double[] computeObjectiveValue(double[] params) {
|
||||
super.incrementEvaluationCount();
|
||||
return model.value(params);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
|
||||
* <li>{@link Target}</li>
|
||||
* <li>{@link Weight}</li>
|
||||
* <li>{@link ModelFunction}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
* @throws DimensionMismatchException if the initial guess, target, and weight
|
||||
* arguments have inconsistent dimensions.
|
||||
*/
|
||||
public PointVectorValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException,
|
||||
DimensionMismatchException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Check input consistency.
|
||||
checkParameters();
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the weight matrix of the observations.
|
||||
*
|
||||
* @return the weight matrix.
|
||||
*/
|
||||
public RealMatrix getWeight() {
|
||||
return weightMatrix.copy();
|
||||
}
|
||||
/**
|
||||
* Gets the observed values to be matched by the objective vector
|
||||
* function.
|
||||
*
|
||||
* @return the target values.
|
||||
*/
|
||||
public double[] getTarget() {
|
||||
return target.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of observed values.
|
||||
*
|
||||
* @return the length of the target vector.
|
||||
*/
|
||||
public int getTargetSize() {
|
||||
return target.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link Target}</li>
|
||||
* <li>{@link Weight}</li>
|
||||
* <li>{@link ModelFunction}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof ModelFunction) {
|
||||
model = ((ModelFunction) data).getModelFunction();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof Target) {
|
||||
target = ((Target) data).getTarget();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof Weight) {
|
||||
weightMatrix = ((Weight) data).getWeight();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check parameters consistency.
|
||||
*
|
||||
* @throws DimensionMismatchException if {@link #target} and
|
||||
* {@link #weightMatrix} have inconsistent dimensions.
|
||||
*/
|
||||
private void checkParameters() {
|
||||
if (target.length != weightMatrix.getColumnDimension()) {
|
||||
throw new DimensionMismatchException(target.length,
|
||||
weightMatrix.getColumnDimension());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* Target of the optimization procedure.
|
||||
* They are the values which the objective vector function must reproduce
|
||||
* When the parameters of the model have been optimized.
|
||||
* <br/>
|
||||
* Immutable class.
|
||||
*
|
||||
* @version $Id: Target.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.1
|
||||
*/
|
||||
public class Target implements OptimizationData {
|
||||
/** Target values (of the objective vector function). */
|
||||
private final double[] target;
|
||||
|
||||
/**
|
||||
* @param observations Target values.
|
||||
*/
|
||||
public Target(double[] observations) {
|
||||
target = observations.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the initial guess.
|
||||
*
|
||||
* @return the initial guess.
|
||||
*/
|
||||
public double[] getTarget() {
|
||||
return target.clone();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.NonSquareMatrixException;
|
||||
|
||||
/**
|
||||
* Weight matrix of the residuals between model and observations.
|
||||
* <br/>
|
||||
* Immutable class.
|
||||
*
|
||||
* @version $Id: Weight.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 3.1
|
||||
*/
|
||||
public class Weight implements OptimizationData {
|
||||
/** Weight matrix. */
|
||||
private final RealMatrix weightMatrix;
|
||||
|
||||
/**
|
||||
* Creates a diagonal weight matrix.
|
||||
*
|
||||
* @param weight List of the values of the diagonal.
|
||||
*/
|
||||
public Weight(double[] weight) {
|
||||
final int dim = weight.length;
|
||||
weightMatrix = MatrixUtils.createRealMatrix(dim, dim);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
weightMatrix.setEntry(i, i, weight[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param weight Weight matrix.
|
||||
* @throws NonSquareMatrixException if the argument is not
|
||||
* a square matrix.
|
||||
*/
|
||||
public Weight(RealMatrix weight) {
|
||||
if (weight.getColumnDimension() != weight.getRowDimension()) {
|
||||
throw new NonSquareMatrixException(weight.getColumnDimension(),
|
||||
weight.getRowDimension());
|
||||
}
|
||||
|
||||
weightMatrix = weight.copy();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the initial guess.
|
||||
*
|
||||
* @return the initial guess.
|
||||
*/
|
||||
public RealMatrix getWeight() {
|
||||
return weightMatrix.copy();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,269 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
|
||||
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.DecompositionSolver;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.QRDecomposition;
|
||||
import org.apache.commons.math3.linear.EigenDecomposition;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.Weight;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.JacobianMultivariateVectorOptimizer;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
|
||||
/**
|
||||
* Base class for implementing least-squares optimizers.
|
||||
* It provides methods for error estimation.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class AbstractLeastSquaresOptimizer
|
||||
extends JacobianMultivariateVectorOptimizer {
|
||||
/** Square-root of the weight matrix. */
|
||||
private RealMatrix weightMatrixSqrt;
|
||||
/** Cost value (square root of the sum of the residuals). */
|
||||
private double cost;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected AbstractLeastSquaresOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the weighted Jacobian matrix.
|
||||
*
|
||||
* @param params Model parameters at which to compute the Jacobian.
|
||||
* @return the weighted Jacobian: W<sup>1/2</sup> J.
|
||||
* @throws DimensionMismatchException if the Jacobian dimension does not
|
||||
* match problem dimension.
|
||||
*/
|
||||
protected RealMatrix computeWeightedJacobian(double[] params) {
|
||||
return weightMatrixSqrt.multiply(MatrixUtils.createRealMatrix(computeJacobian(params)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the cost.
|
||||
*
|
||||
* @param residuals Residuals.
|
||||
* @return the cost.
|
||||
* @see #computeResiduals(double[])
|
||||
*/
|
||||
protected double computeCost(double[] residuals) {
|
||||
final ArrayRealVector r = new ArrayRealVector(residuals);
|
||||
return FastMath.sqrt(r.dotProduct(getWeight().operate(r)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the root-mean-square (RMS) value.
|
||||
*
|
||||
* The RMS the root of the arithmetic mean of the square of all weighted
|
||||
* residuals.
|
||||
* This is related to the criterion that is minimized by the optimizer
|
||||
* as follows: If <em>c</em> if the criterion, and <em>n</em> is the
|
||||
* number of measurements, then the RMS is <em>sqrt (c/n)</em>.
|
||||
*
|
||||
* @return the RMS value.
|
||||
*/
|
||||
public double getRMS() {
|
||||
return FastMath.sqrt(getChiSquare() / getTargetSize());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a Chi-Square-like value assuming the N residuals follow N
|
||||
* distinct normal distributions centered on 0 and whose variances are
|
||||
* the reciprocal of the weights.
|
||||
* @return chi-square value
|
||||
*/
|
||||
public double getChiSquare() {
|
||||
return cost * cost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the square-root of the weight matrix.
|
||||
*
|
||||
* @return the square-root of the weight matrix.
|
||||
*/
|
||||
public RealMatrix getWeightSquareRoot() {
|
||||
return weightMatrixSqrt.copy();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the cost.
|
||||
*
|
||||
* @param cost Cost value.
|
||||
*/
|
||||
protected void setCost(double cost) {
|
||||
this.cost = cost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the covariance matrix of the optimized parameters.
|
||||
* <br/>
|
||||
* Note that this operation involves the inversion of the
|
||||
* <code>J<sup>T</sup>J</code> matrix, where {@code J} is the
|
||||
* Jacobian matrix.
|
||||
* The {@code threshold} parameter is a way for the caller to specify
|
||||
* that the result of this computation should be considered meaningless,
|
||||
* and thus trigger an exception.
|
||||
*
|
||||
* @param params Model parameters.
|
||||
* @param threshold Singularity threshold.
|
||||
* @return the covariance matrix.
|
||||
* @throws org.apache.commons.math3.linear.SingularMatrixException
|
||||
* if the covariance matrix cannot be computed (singular problem).
|
||||
*/
|
||||
public double[][] computeCovariances(double[] params,
|
||||
double threshold) {
|
||||
// Set up the Jacobian.
|
||||
final RealMatrix j = computeWeightedJacobian(params);
|
||||
|
||||
// Compute transpose(J)J.
|
||||
final RealMatrix jTj = j.transpose().multiply(j);
|
||||
|
||||
// Compute the covariances matrix.
|
||||
final DecompositionSolver solver
|
||||
= new QRDecomposition(jTj, threshold).getSolver();
|
||||
return solver.getInverse().getData();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes an estimate of the standard deviation of the parameters. The
|
||||
* returned values are the square root of the diagonal coefficients of the
|
||||
* covariance matrix, {@code sd(a[i]) ~= sqrt(C[i][i])}, where {@code a[i]}
|
||||
* is the optimized value of the {@code i}-th parameter, and {@code C} is
|
||||
* the covariance matrix.
|
||||
*
|
||||
* @param params Model parameters.
|
||||
* @param covarianceSingularityThreshold Singularity threshold (see
|
||||
* {@link #computeCovariances(double[],double) computeCovariances}).
|
||||
* @return an estimate of the standard deviation of the optimized parameters
|
||||
* @throws org.apache.commons.math3.linear.SingularMatrixException
|
||||
* if the covariance matrix cannot be computed.
|
||||
*/
|
||||
public double[] computeSigma(double[] params,
|
||||
double covarianceSingularityThreshold) {
|
||||
final int nC = params.length;
|
||||
final double[] sig = new double[nC];
|
||||
final double[][] cov = computeCovariances(params, covarianceSingularityThreshold);
|
||||
for (int i = 0; i < nC; ++i) {
|
||||
sig[i] = FastMath.sqrt(cov[i][i]);
|
||||
}
|
||||
return sig;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.Target}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.Weight}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunction}</li>
|
||||
* <li>{@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
* @throws DimensionMismatchException if the initial guess, target, and weight
|
||||
* arguments have inconsistent dimensions.
|
||||
*/
|
||||
@Override
|
||||
public PointVectorValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Set up base class and perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the residuals.
|
||||
* The residual is the difference between the observed (target)
|
||||
* values and the model (objective function) value.
|
||||
* There is one residual for each element of the vector-valued
|
||||
* function.
|
||||
*
|
||||
* @param objectiveValue Value of the the objective function. This is
|
||||
* the value returned from a call to
|
||||
* {@link #computeObjectiveValue(double[]) computeObjectiveValue}
|
||||
* (whose array argument contains the model parameters).
|
||||
* @return the residuals.
|
||||
* @throws DimensionMismatchException if {@code params} has a wrong
|
||||
* length.
|
||||
*/
|
||||
protected double[] computeResiduals(double[] objectiveValue) {
|
||||
final double[] target = getTarget();
|
||||
if (objectiveValue.length != target.length) {
|
||||
throw new DimensionMismatchException(target.length,
|
||||
objectiveValue.length);
|
||||
}
|
||||
|
||||
final double[] residuals = new double[target.length];
|
||||
for (int i = 0; i < target.length; i++) {
|
||||
residuals[i] = target[i] - objectiveValue[i];
|
||||
}
|
||||
|
||||
return residuals;
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
* If the weight matrix is specified, the {@link #weightMatrixSqrt}
|
||||
* field is recomputed.
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link Weight}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof Weight) {
|
||||
weightMatrixSqrt = squareRoot(((Weight) data).getWeight());
|
||||
// If more data must be parsed, this statement _must_ be
|
||||
// changed to "continue".
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the square-root of the weight matrix.
|
||||
*
|
||||
* @param m Symmetric, positive-definite (weight) matrix.
|
||||
* @return the square-root of the weight matrix.
|
||||
*/
|
||||
private RealMatrix squareRoot(RealMatrix m) {
|
||||
final EigenDecomposition dec = new EigenDecomposition(m);
|
||||
return dec.getSquareRoot();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
|
||||
|
||||
import org.apache.commons.math3.exception.ConvergenceException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.exception.MathInternalError;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.linear.BlockRealMatrix;
|
||||
import org.apache.commons.math3.linear.DecompositionSolver;
|
||||
import org.apache.commons.math3.linear.LUDecomposition;
|
||||
import org.apache.commons.math3.linear.QRDecomposition;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.SingularMatrixException;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
|
||||
/**
|
||||
* Gauss-Newton least-squares solver.
|
||||
* <p>
|
||||
* This class solve a least-square problem by solving the normal equations
|
||||
* of the linearized problem at each iteration. Either LU decomposition or
|
||||
* QR decomposition can be used to solve the normal equations. LU decomposition
|
||||
* is faster but QR decomposition is more robust for difficult problems.
|
||||
* </p>
|
||||
*
|
||||
* @version $Id: GaussNewtonOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*
|
||||
*/
|
||||
public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
|
||||
/** Indicator for using LU decomposition. */
|
||||
private final boolean useLU;
|
||||
|
||||
/**
|
||||
* Simple constructor with default settings.
|
||||
* The normal equations will be solved using LU decomposition.
|
||||
*
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
public GaussNewtonOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
|
||||
this(true, checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param useLU If {@code true}, the normal equations will be solved
|
||||
* using LU decomposition, otherwise they will be solved using QR
|
||||
* decomposition.
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
public GaussNewtonOptimizer(final boolean useLU,
|
||||
ConvergenceChecker<PointVectorValuePair> checker) {
|
||||
super(checker);
|
||||
this.useLU = useLU;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public PointVectorValuePair doOptimize() {
|
||||
final ConvergenceChecker<PointVectorValuePair> checker
|
||||
= getConvergenceChecker();
|
||||
|
||||
// Computation will be useless without a checker (see "for-loop").
|
||||
if (checker == null) {
|
||||
throw new NullArgumentException();
|
||||
}
|
||||
|
||||
final double[] targetValues = getTarget();
|
||||
final int nR = targetValues.length; // Number of observed data.
|
||||
|
||||
final RealMatrix weightMatrix = getWeight();
|
||||
// Diagonal of the weight matrix.
|
||||
final double[] residualsWeights = new double[nR];
|
||||
for (int i = 0; i < nR; i++) {
|
||||
residualsWeights[i] = weightMatrix.getEntry(i, i);
|
||||
}
|
||||
|
||||
final double[] currentPoint = getStartPoint();
|
||||
final int nC = currentPoint.length;
|
||||
|
||||
// iterate until convergence is reached
|
||||
PointVectorValuePair current = null;
|
||||
int iter = 0;
|
||||
for (boolean converged = false; !converged;) {
|
||||
++iter;
|
||||
|
||||
// evaluate the objective function and its jacobian
|
||||
PointVectorValuePair previous = current;
|
||||
// Value of the objective function at "currentPoint".
|
||||
final double[] currentObjective = computeObjectiveValue(currentPoint);
|
||||
final double[] currentResiduals = computeResiduals(currentObjective);
|
||||
final RealMatrix weightedJacobian = computeWeightedJacobian(currentPoint);
|
||||
current = new PointVectorValuePair(currentPoint, currentObjective);
|
||||
|
||||
// build the linear problem
|
||||
final double[] b = new double[nC];
|
||||
final double[][] a = new double[nC][nC];
|
||||
for (int i = 0; i < nR; ++i) {
|
||||
|
||||
final double[] grad = weightedJacobian.getRow(i);
|
||||
final double weight = residualsWeights[i];
|
||||
final double residual = currentResiduals[i];
|
||||
|
||||
// compute the normal equation
|
||||
final double wr = weight * residual;
|
||||
for (int j = 0; j < nC; ++j) {
|
||||
b[j] += wr * grad[j];
|
||||
}
|
||||
|
||||
// build the contribution matrix for measurement i
|
||||
for (int k = 0; k < nC; ++k) {
|
||||
double[] ak = a[k];
|
||||
double wgk = weight * grad[k];
|
||||
for (int l = 0; l < nC; ++l) {
|
||||
ak[l] += wgk * grad[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// solve the linearized least squares problem
|
||||
RealMatrix mA = new BlockRealMatrix(a);
|
||||
DecompositionSolver solver = useLU ?
|
||||
new LUDecomposition(mA).getSolver() :
|
||||
new QRDecomposition(mA).getSolver();
|
||||
final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
|
||||
// update the estimated parameters
|
||||
for (int i = 0; i < nC; ++i) {
|
||||
currentPoint[i] += dX[i];
|
||||
}
|
||||
} catch (SingularMatrixException e) {
|
||||
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
|
||||
}
|
||||
|
||||
// Check convergence.
|
||||
if (previous != null) {
|
||||
converged = checker.converged(iter, previous, current);
|
||||
if (converged) {
|
||||
setCost(computeCost(currentResiduals));
|
||||
return current;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Must never happen.
|
||||
throw new MathInternalError();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,939 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
|
||||
|
||||
import java.util.Arrays;
|
||||
import org.apache.commons.math3.exception.ConvergenceException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.optim.PointVectorValuePair;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
|
||||
|
||||
/**
|
||||
* This class solves a least-squares problem using the Levenberg-Marquardt algorithm.
|
||||
*
|
||||
* <p>This implementation <em>should</em> work even for over-determined systems
|
||||
* (i.e. systems having more point than equations). Over-determined systems
|
||||
* are solved by ignoring the point which have the smallest impact according
|
||||
* to their jacobian column norm. Only the rank of the matrix and some loop bounds
|
||||
* are changed to implement this.</p>
|
||||
*
|
||||
* <p>The resolution engine is a simple translation of the MINPACK <a
|
||||
* href="http://www.netlib.org/minpack/lmder.f">lmder</a> routine with minor
|
||||
* changes. The changes include the over-determined resolution, the use of
|
||||
* inherited convergence checker and the Q.R. decomposition which has been
|
||||
* rewritten following the algorithm described in the
|
||||
* P. Lascaux and R. Theodor book <i>Analyse numérique matricielle
|
||||
* appliquée à l'art de l'ingénieur</i>, Masson 1986.</p>
|
||||
* <p>The authors of the original fortran version are:
|
||||
* <ul>
|
||||
* <li>Argonne National Laboratory. MINPACK project. March 1980</li>
|
||||
* <li>Burton S. Garbow</li>
|
||||
* <li>Kenneth E. Hillstrom</li>
|
||||
* <li>Jorge J. More</li>
|
||||
* </ul>
|
||||
* The redistribution policy for MINPACK is available <a
|
||||
* href="http://www.netlib.org/minpack/disclaimer">here</a>, for convenience, it
|
||||
* is reproduced below.</p>
|
||||
*
|
||||
* <table border="0" width="80%" cellpadding="10" align="center" bgcolor="#E0E0E0">
|
||||
* <tr><td>
|
||||
* Minpack Copyright Notice (1999) University of Chicago.
|
||||
* All rights reserved
|
||||
* </td></tr>
|
||||
* <tr><td>
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions
|
||||
* are met:
|
||||
* <ol>
|
||||
* <li>Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.</li>
|
||||
* <li>Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following
|
||||
* disclaimer in the documentation and/or other materials provided
|
||||
* with the distribution.</li>
|
||||
* <li>The end-user documentation included with the redistribution, if any,
|
||||
* must include the following acknowledgment:
|
||||
* <code>This product includes software developed by the University of
|
||||
* Chicago, as Operator of Argonne National Laboratory.</code>
|
||||
* Alternately, this acknowledgment may appear in the software itself,
|
||||
* if and wherever such third-party acknowledgments normally appear.</li>
|
||||
* <li><strong>WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS"
|
||||
* WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE
|
||||
* UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND
|
||||
* THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES
|
||||
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE
|
||||
* OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY
|
||||
* OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR
|
||||
* USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF
|
||||
* THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4)
|
||||
* DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION
|
||||
* UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL
|
||||
* BE CORRECTED.</strong></li>
|
||||
* <li><strong>LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT
|
||||
* HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF
|
||||
* ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT,
|
||||
* INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF
|
||||
* ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
* PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER
|
||||
* SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT
|
||||
* (INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE,
|
||||
* EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE
|
||||
* POSSIBILITY OF SUCH LOSS OR DAMAGES.</strong></li>
|
||||
* <ol></td></tr>
|
||||
* </table>
|
||||
*
|
||||
* @version $Id: LevenbergMarquardtOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class LevenbergMarquardtOptimizer
|
||||
extends AbstractLeastSquaresOptimizer {
|
||||
/** Number of solved point. */
|
||||
private int solvedCols;
|
||||
/** Diagonal elements of the R matrix in the Q.R. decomposition. */
|
||||
private double[] diagR;
|
||||
/** Norms of the columns of the jacobian matrix. */
|
||||
private double[] jacNorm;
|
||||
/** Coefficients of the Householder transforms vectors. */
|
||||
private double[] beta;
|
||||
/** Columns permutation array. */
|
||||
private int[] permutation;
|
||||
/** Rank of the jacobian matrix. */
|
||||
private int rank;
|
||||
/** Levenberg-Marquardt parameter. */
|
||||
private double lmPar;
|
||||
/** Parameters evolution direction associated with lmPar. */
|
||||
private double[] lmDir;
|
||||
/** Positive input variable used in determining the initial step bound. */
|
||||
private final double initialStepBoundFactor;
|
||||
/** Desired relative error in the sum of squares. */
|
||||
private final double costRelativeTolerance;
|
||||
/** Desired relative error in the approximate solution parameters. */
|
||||
private final double parRelativeTolerance;
|
||||
/** Desired max cosine on the orthogonality between the function vector
|
||||
* and the columns of the jacobian. */
|
||||
private final double orthoTolerance;
|
||||
/** Threshold for QR ranking. */
|
||||
private final double qrRankingThreshold;
|
||||
/** Weighted residuals. */
|
||||
private double[] weightedResidual;
|
||||
/** Weighted Jacobian. */
|
||||
private double[][] weightedJacobian;
|
||||
|
||||
/**
|
||||
* Build an optimizer for least squares problems with default values
|
||||
* for all the tuning parameters (see the {@link
|
||||
* #LevenbergMarquardtOptimizer(double,double,double,double,double)
|
||||
* other contructor}.
|
||||
* The default values for the algorithm settings are:
|
||||
* <ul>
|
||||
* <li>Initial step bound factor: 100</li>
|
||||
* <li>Cost relative tolerance: 1e-10</li>
|
||||
* <li>Parameters relative tolerance: 1e-10</li>
|
||||
* <li>Orthogonality tolerance: 1e-10</li>
|
||||
* <li>QR ranking threshold: {@link Precision#SAFE_MIN}</li>
|
||||
* </ul>
|
||||
*/
|
||||
public LevenbergMarquardtOptimizer() {
|
||||
this(100, 1e-10, 1e-10, 1e-10, Precision.SAFE_MIN);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor that allows the specification of a custom convergence
|
||||
* checker.
|
||||
* Note that all the usual convergence checks will be <em>disabled</em>.
|
||||
* The default values for the algorithm settings are:
|
||||
* <ul>
|
||||
* <li>Initial step bound factor: 100</li>
|
||||
* <li>Cost relative tolerance: 1e-10</li>
|
||||
* <li>Parameters relative tolerance: 1e-10</li>
|
||||
* <li>Orthogonality tolerance: 1e-10</li>
|
||||
* <li>QR ranking threshold: {@link Precision#SAFE_MIN}</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
public LevenbergMarquardtOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
|
||||
this(100, checker, 1e-10, 1e-10, 1e-10, Precision.SAFE_MIN);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor that allows the specification of a custom convergence
|
||||
* checker, in addition to the standard ones.
|
||||
*
|
||||
* @param initialStepBoundFactor Positive input variable used in
|
||||
* determining the initial step bound. This bound is set to the
|
||||
* product of initialStepBoundFactor and the euclidean norm of
|
||||
* {@code diag * x} if non-zero, or else to {@code initialStepBoundFactor}
|
||||
* itself. In most cases factor should lie in the interval
|
||||
* {@code (0.1, 100.0)}. {@code 100} is a generally recommended value.
|
||||
* @param checker Convergence checker.
|
||||
* @param costRelativeTolerance Desired relative error in the sum of
|
||||
* squares.
|
||||
* @param parRelativeTolerance Desired relative error in the approximate
|
||||
* solution parameters.
|
||||
* @param orthoTolerance Desired max cosine on the orthogonality between
|
||||
* the function vector and the columns of the Jacobian.
|
||||
* @param threshold Desired threshold for QR ranking. If the squared norm
|
||||
* of a column vector is smaller or equal to this threshold during QR
|
||||
* decomposition, it is considered to be a zero vector and hence the rank
|
||||
* of the matrix is reduced.
|
||||
*/
|
||||
public LevenbergMarquardtOptimizer(double initialStepBoundFactor,
|
||||
ConvergenceChecker<PointVectorValuePair> checker,
|
||||
double costRelativeTolerance,
|
||||
double parRelativeTolerance,
|
||||
double orthoTolerance,
|
||||
double threshold) {
|
||||
super(checker);
|
||||
this.initialStepBoundFactor = initialStepBoundFactor;
|
||||
this.costRelativeTolerance = costRelativeTolerance;
|
||||
this.parRelativeTolerance = parRelativeTolerance;
|
||||
this.orthoTolerance = orthoTolerance;
|
||||
this.qrRankingThreshold = threshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build an optimizer for least squares problems with default values
|
||||
* for some of the tuning parameters (see the {@link
|
||||
* #LevenbergMarquardtOptimizer(double,double,double,double,double)
|
||||
* other contructor}.
|
||||
* The default values for the algorithm settings are:
|
||||
* <ul>
|
||||
* <li>Initial step bound factor}: 100</li>
|
||||
* <li>QR ranking threshold}: {@link Precision#SAFE_MIN}</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param costRelativeTolerance Desired relative error in the sum of
|
||||
* squares.
|
||||
* @param parRelativeTolerance Desired relative error in the approximate
|
||||
* solution parameters.
|
||||
* @param orthoTolerance Desired max cosine on the orthogonality between
|
||||
* the function vector and the columns of the Jacobian.
|
||||
*/
|
||||
public LevenbergMarquardtOptimizer(double costRelativeTolerance,
|
||||
double parRelativeTolerance,
|
||||
double orthoTolerance) {
|
||||
this(100,
|
||||
costRelativeTolerance, parRelativeTolerance, orthoTolerance,
|
||||
Precision.SAFE_MIN);
|
||||
}
|
||||
|
||||
/**
|
||||
* The arguments control the behaviour of the default convergence checking
|
||||
* procedure.
|
||||
* Additional criteria can defined through the setting of a {@link
|
||||
* ConvergenceChecker}.
|
||||
*
|
||||
* @param initialStepBoundFactor Positive input variable used in
|
||||
* determining the initial step bound. This bound is set to the
|
||||
* product of initialStepBoundFactor and the euclidean norm of
|
||||
* {@code diag * x} if non-zero, or else to {@code initialStepBoundFactor}
|
||||
* itself. In most cases factor should lie in the interval
|
||||
* {@code (0.1, 100.0)}. {@code 100} is a generally recommended value.
|
||||
* @param costRelativeTolerance Desired relative error in the sum of
|
||||
* squares.
|
||||
* @param parRelativeTolerance Desired relative error in the approximate
|
||||
* solution parameters.
|
||||
* @param orthoTolerance Desired max cosine on the orthogonality between
|
||||
* the function vector and the columns of the Jacobian.
|
||||
* @param threshold Desired threshold for QR ranking. If the squared norm
|
||||
* of a column vector is smaller or equal to this threshold during QR
|
||||
* decomposition, it is considered to be a zero vector and hence the rank
|
||||
* of the matrix is reduced.
|
||||
*/
|
||||
public LevenbergMarquardtOptimizer(double initialStepBoundFactor,
|
||||
double costRelativeTolerance,
|
||||
double parRelativeTolerance,
|
||||
double orthoTolerance,
|
||||
double threshold) {
|
||||
super(null); // No custom convergence criterion.
|
||||
this.initialStepBoundFactor = initialStepBoundFactor;
|
||||
this.costRelativeTolerance = costRelativeTolerance;
|
||||
this.parRelativeTolerance = parRelativeTolerance;
|
||||
this.orthoTolerance = orthoTolerance;
|
||||
this.qrRankingThreshold = threshold;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected PointVectorValuePair doOptimize() {
|
||||
final int nR = getTarget().length; // Number of observed data.
|
||||
final double[] currentPoint = getStartPoint();
|
||||
final int nC = currentPoint.length; // Number of parameters.
|
||||
|
||||
// arrays shared with the other private methods
|
||||
solvedCols = FastMath.min(nR, nC);
|
||||
diagR = new double[nC];
|
||||
jacNorm = new double[nC];
|
||||
beta = new double[nC];
|
||||
permutation = new int[nC];
|
||||
lmDir = new double[nC];
|
||||
|
||||
// local point
|
||||
double delta = 0;
|
||||
double xNorm = 0;
|
||||
double[] diag = new double[nC];
|
||||
double[] oldX = new double[nC];
|
||||
double[] oldRes = new double[nR];
|
||||
double[] oldObj = new double[nR];
|
||||
double[] qtf = new double[nR];
|
||||
double[] work1 = new double[nC];
|
||||
double[] work2 = new double[nC];
|
||||
double[] work3 = new double[nC];
|
||||
|
||||
final RealMatrix weightMatrixSqrt = getWeightSquareRoot();
|
||||
|
||||
// Evaluate the function at the starting point and calculate its norm.
|
||||
double[] currentObjective = computeObjectiveValue(currentPoint);
|
||||
double[] currentResiduals = computeResiduals(currentObjective);
|
||||
PointVectorValuePair current = new PointVectorValuePair(currentPoint, currentObjective);
|
||||
double currentCost = computeCost(currentResiduals);
|
||||
|
||||
// Outer loop.
|
||||
lmPar = 0;
|
||||
boolean firstIteration = true;
|
||||
int iter = 0;
|
||||
final ConvergenceChecker<PointVectorValuePair> checker = getConvergenceChecker();
|
||||
while (true) {
|
||||
++iter;
|
||||
final PointVectorValuePair previous = current;
|
||||
|
||||
// QR decomposition of the jacobian matrix
|
||||
qrDecomposition(computeWeightedJacobian(currentPoint));
|
||||
|
||||
weightedResidual = weightMatrixSqrt.operate(currentResiduals);
|
||||
for (int i = 0; i < nR; i++) {
|
||||
qtf[i] = weightedResidual[i];
|
||||
}
|
||||
|
||||
// compute Qt.res
|
||||
qTy(qtf);
|
||||
|
||||
// now we don't need Q anymore,
|
||||
// so let jacobian contain the R matrix with its diagonal elements
|
||||
for (int k = 0; k < solvedCols; ++k) {
|
||||
int pk = permutation[k];
|
||||
weightedJacobian[k][pk] = diagR[pk];
|
||||
}
|
||||
|
||||
if (firstIteration) {
|
||||
// scale the point according to the norms of the columns
|
||||
// of the initial jacobian
|
||||
xNorm = 0;
|
||||
for (int k = 0; k < nC; ++k) {
|
||||
double dk = jacNorm[k];
|
||||
if (dk == 0) {
|
||||
dk = 1.0;
|
||||
}
|
||||
double xk = dk * currentPoint[k];
|
||||
xNorm += xk * xk;
|
||||
diag[k] = dk;
|
||||
}
|
||||
xNorm = FastMath.sqrt(xNorm);
|
||||
|
||||
// initialize the step bound delta
|
||||
delta = (xNorm == 0) ? initialStepBoundFactor : (initialStepBoundFactor * xNorm);
|
||||
}
|
||||
|
||||
// check orthogonality between function vector and jacobian columns
|
||||
double maxCosine = 0;
|
||||
if (currentCost != 0) {
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
double s = jacNorm[pj];
|
||||
if (s != 0) {
|
||||
double sum = 0;
|
||||
for (int i = 0; i <= j; ++i) {
|
||||
sum += weightedJacobian[i][pj] * qtf[i];
|
||||
}
|
||||
maxCosine = FastMath.max(maxCosine, FastMath.abs(sum) / (s * currentCost));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (maxCosine <= orthoTolerance) {
|
||||
// Convergence has been reached.
|
||||
setCost(currentCost);
|
||||
return current;
|
||||
}
|
||||
|
||||
// rescale if necessary
|
||||
for (int j = 0; j < nC; ++j) {
|
||||
diag[j] = FastMath.max(diag[j], jacNorm[j]);
|
||||
}
|
||||
|
||||
// Inner loop.
|
||||
for (double ratio = 0; ratio < 1.0e-4;) {
|
||||
|
||||
// save the state
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
oldX[pj] = currentPoint[pj];
|
||||
}
|
||||
final double previousCost = currentCost;
|
||||
double[] tmpVec = weightedResidual;
|
||||
weightedResidual = oldRes;
|
||||
oldRes = tmpVec;
|
||||
tmpVec = currentObjective;
|
||||
currentObjective = oldObj;
|
||||
oldObj = tmpVec;
|
||||
|
||||
// determine the Levenberg-Marquardt parameter
|
||||
determineLMParameter(qtf, delta, diag, work1, work2, work3);
|
||||
|
||||
// compute the new point and the norm of the evolution direction
|
||||
double lmNorm = 0;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
lmDir[pj] = -lmDir[pj];
|
||||
currentPoint[pj] = oldX[pj] + lmDir[pj];
|
||||
double s = diag[pj] * lmDir[pj];
|
||||
lmNorm += s * s;
|
||||
}
|
||||
lmNorm = FastMath.sqrt(lmNorm);
|
||||
// on the first iteration, adjust the initial step bound.
|
||||
if (firstIteration) {
|
||||
delta = FastMath.min(delta, lmNorm);
|
||||
}
|
||||
|
||||
// Evaluate the function at x + p and calculate its norm.
|
||||
currentObjective = computeObjectiveValue(currentPoint);
|
||||
currentResiduals = computeResiduals(currentObjective);
|
||||
current = new PointVectorValuePair(currentPoint, currentObjective);
|
||||
currentCost = computeCost(currentResiduals);
|
||||
|
||||
// compute the scaled actual reduction
|
||||
double actRed = -1.0;
|
||||
if (0.1 * currentCost < previousCost) {
|
||||
double r = currentCost / previousCost;
|
||||
actRed = 1.0 - r * r;
|
||||
}
|
||||
|
||||
// compute the scaled predicted reduction
|
||||
// and the scaled directional derivative
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
double dirJ = lmDir[pj];
|
||||
work1[j] = 0;
|
||||
for (int i = 0; i <= j; ++i) {
|
||||
work1[i] += weightedJacobian[i][pj] * dirJ;
|
||||
}
|
||||
}
|
||||
double coeff1 = 0;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
coeff1 += work1[j] * work1[j];
|
||||
}
|
||||
double pc2 = previousCost * previousCost;
|
||||
coeff1 = coeff1 / pc2;
|
||||
double coeff2 = lmPar * lmNorm * lmNorm / pc2;
|
||||
double preRed = coeff1 + 2 * coeff2;
|
||||
double dirDer = -(coeff1 + coeff2);
|
||||
|
||||
// ratio of the actual to the predicted reduction
|
||||
ratio = (preRed == 0) ? 0 : (actRed / preRed);
|
||||
|
||||
// update the step bound
|
||||
if (ratio <= 0.25) {
|
||||
double tmp =
|
||||
(actRed < 0) ? (0.5 * dirDer / (dirDer + 0.5 * actRed)) : 0.5;
|
||||
if ((0.1 * currentCost >= previousCost) || (tmp < 0.1)) {
|
||||
tmp = 0.1;
|
||||
}
|
||||
delta = tmp * FastMath.min(delta, 10.0 * lmNorm);
|
||||
lmPar /= tmp;
|
||||
} else if ((lmPar == 0) || (ratio >= 0.75)) {
|
||||
delta = 2 * lmNorm;
|
||||
lmPar *= 0.5;
|
||||
}
|
||||
|
||||
// test for successful iteration.
|
||||
if (ratio >= 1.0e-4) {
|
||||
// successful iteration, update the norm
|
||||
firstIteration = false;
|
||||
xNorm = 0;
|
||||
for (int k = 0; k < nC; ++k) {
|
||||
double xK = diag[k] * currentPoint[k];
|
||||
xNorm += xK * xK;
|
||||
}
|
||||
xNorm = FastMath.sqrt(xNorm);
|
||||
|
||||
// tests for convergence.
|
||||
if (checker != null) {
|
||||
// we use the vectorial convergence checker
|
||||
if (checker.converged(iter, previous, current)) {
|
||||
setCost(currentCost);
|
||||
return current;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// failed iteration, reset the previous values
|
||||
currentCost = previousCost;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
currentPoint[pj] = oldX[pj];
|
||||
}
|
||||
tmpVec = weightedResidual;
|
||||
weightedResidual = oldRes;
|
||||
oldRes = tmpVec;
|
||||
tmpVec = currentObjective;
|
||||
currentObjective = oldObj;
|
||||
oldObj = tmpVec;
|
||||
// Reset "current" to previous values.
|
||||
current = new PointVectorValuePair(currentPoint, currentObjective);
|
||||
}
|
||||
|
||||
// Default convergence criteria.
|
||||
if ((FastMath.abs(actRed) <= costRelativeTolerance &&
|
||||
preRed <= costRelativeTolerance &&
|
||||
ratio <= 2.0) ||
|
||||
delta <= parRelativeTolerance * xNorm) {
|
||||
setCost(currentCost);
|
||||
return current;
|
||||
}
|
||||
|
||||
// tests for termination and stringent tolerances
|
||||
// (2.2204e-16 is the machine epsilon for IEEE754)
|
||||
if ((FastMath.abs(actRed) <= 2.2204e-16) && (preRed <= 2.2204e-16) && (ratio <= 2.0)) {
|
||||
throw new ConvergenceException(LocalizedFormats.TOO_SMALL_COST_RELATIVE_TOLERANCE,
|
||||
costRelativeTolerance);
|
||||
} else if (delta <= 2.2204e-16 * xNorm) {
|
||||
throw new ConvergenceException(LocalizedFormats.TOO_SMALL_PARAMETERS_RELATIVE_TOLERANCE,
|
||||
parRelativeTolerance);
|
||||
} else if (maxCosine <= 2.2204e-16) {
|
||||
throw new ConvergenceException(LocalizedFormats.TOO_SMALL_ORTHOGONALITY_TOLERANCE,
|
||||
orthoTolerance);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine the Levenberg-Marquardt parameter.
|
||||
* <p>This implementation is a translation in Java of the MINPACK
|
||||
* <a href="http://www.netlib.org/minpack/lmpar.f">lmpar</a>
|
||||
* routine.</p>
|
||||
* <p>This method sets the lmPar and lmDir attributes.</p>
|
||||
* <p>The authors of the original fortran function are:</p>
|
||||
* <ul>
|
||||
* <li>Argonne National Laboratory. MINPACK project. March 1980</li>
|
||||
* <li>Burton S. Garbow</li>
|
||||
* <li>Kenneth E. Hillstrom</li>
|
||||
* <li>Jorge J. More</li>
|
||||
* </ul>
|
||||
* <p>Luc Maisonobe did the Java translation.</p>
|
||||
*
|
||||
* @param qy array containing qTy
|
||||
* @param delta upper bound on the euclidean norm of diagR * lmDir
|
||||
* @param diag diagonal matrix
|
||||
* @param work1 work array
|
||||
* @param work2 work array
|
||||
* @param work3 work array
|
||||
*/
|
||||
private void determineLMParameter(double[] qy, double delta, double[] diag,
|
||||
double[] work1, double[] work2, double[] work3) {
|
||||
final int nC = weightedJacobian[0].length;
|
||||
|
||||
// compute and store in x the gauss-newton direction, if the
|
||||
// jacobian is rank-deficient, obtain a least squares solution
|
||||
for (int j = 0; j < rank; ++j) {
|
||||
lmDir[permutation[j]] = qy[j];
|
||||
}
|
||||
for (int j = rank; j < nC; ++j) {
|
||||
lmDir[permutation[j]] = 0;
|
||||
}
|
||||
for (int k = rank - 1; k >= 0; --k) {
|
||||
int pk = permutation[k];
|
||||
double ypk = lmDir[pk] / diagR[pk];
|
||||
for (int i = 0; i < k; ++i) {
|
||||
lmDir[permutation[i]] -= ypk * weightedJacobian[i][pk];
|
||||
}
|
||||
lmDir[pk] = ypk;
|
||||
}
|
||||
|
||||
// evaluate the function at the origin, and test
|
||||
// for acceptance of the Gauss-Newton direction
|
||||
double dxNorm = 0;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
double s = diag[pj] * lmDir[pj];
|
||||
work1[pj] = s;
|
||||
dxNorm += s * s;
|
||||
}
|
||||
dxNorm = FastMath.sqrt(dxNorm);
|
||||
double fp = dxNorm - delta;
|
||||
if (fp <= 0.1 * delta) {
|
||||
lmPar = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
// if the jacobian is not rank deficient, the Newton step provides
|
||||
// a lower bound, parl, for the zero of the function,
|
||||
// otherwise set this bound to zero
|
||||
double sum2;
|
||||
double parl = 0;
|
||||
if (rank == solvedCols) {
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
work1[pj] *= diag[pj] / dxNorm;
|
||||
}
|
||||
sum2 = 0;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
double sum = 0;
|
||||
for (int i = 0; i < j; ++i) {
|
||||
sum += weightedJacobian[i][pj] * work1[permutation[i]];
|
||||
}
|
||||
double s = (work1[pj] - sum) / diagR[pj];
|
||||
work1[pj] = s;
|
||||
sum2 += s * s;
|
||||
}
|
||||
parl = fp / (delta * sum2);
|
||||
}
|
||||
|
||||
// calculate an upper bound, paru, for the zero of the function
|
||||
sum2 = 0;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
double sum = 0;
|
||||
for (int i = 0; i <= j; ++i) {
|
||||
sum += weightedJacobian[i][pj] * qy[i];
|
||||
}
|
||||
sum /= diag[pj];
|
||||
sum2 += sum * sum;
|
||||
}
|
||||
double gNorm = FastMath.sqrt(sum2);
|
||||
double paru = gNorm / delta;
|
||||
if (paru == 0) {
|
||||
// 2.2251e-308 is the smallest positive real for IEE754
|
||||
paru = 2.2251e-308 / FastMath.min(delta, 0.1);
|
||||
}
|
||||
|
||||
// if the input par lies outside of the interval (parl,paru),
|
||||
// set par to the closer endpoint
|
||||
lmPar = FastMath.min(paru, FastMath.max(lmPar, parl));
|
||||
if (lmPar == 0) {
|
||||
lmPar = gNorm / dxNorm;
|
||||
}
|
||||
|
||||
for (int countdown = 10; countdown >= 0; --countdown) {
|
||||
|
||||
// evaluate the function at the current value of lmPar
|
||||
if (lmPar == 0) {
|
||||
lmPar = FastMath.max(2.2251e-308, 0.001 * paru);
|
||||
}
|
||||
double sPar = FastMath.sqrt(lmPar);
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
work1[pj] = sPar * diag[pj];
|
||||
}
|
||||
determineLMDirection(qy, work1, work2, work3);
|
||||
|
||||
dxNorm = 0;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
double s = diag[pj] * lmDir[pj];
|
||||
work3[pj] = s;
|
||||
dxNorm += s * s;
|
||||
}
|
||||
dxNorm = FastMath.sqrt(dxNorm);
|
||||
double previousFP = fp;
|
||||
fp = dxNorm - delta;
|
||||
|
||||
// if the function is small enough, accept the current value
|
||||
// of lmPar, also test for the exceptional cases where parl is zero
|
||||
if ((FastMath.abs(fp) <= 0.1 * delta) ||
|
||||
((parl == 0) && (fp <= previousFP) && (previousFP < 0))) {
|
||||
return;
|
||||
}
|
||||
|
||||
// compute the Newton correction
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
work1[pj] = work3[pj] * diag[pj] / dxNorm;
|
||||
}
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
work1[pj] /= work2[j];
|
||||
double tmp = work1[pj];
|
||||
for (int i = j + 1; i < solvedCols; ++i) {
|
||||
work1[permutation[i]] -= weightedJacobian[i][pj] * tmp;
|
||||
}
|
||||
}
|
||||
sum2 = 0;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
double s = work1[permutation[j]];
|
||||
sum2 += s * s;
|
||||
}
|
||||
double correction = fp / (delta * sum2);
|
||||
|
||||
// depending on the sign of the function, update parl or paru.
|
||||
if (fp > 0) {
|
||||
parl = FastMath.max(parl, lmPar);
|
||||
} else if (fp < 0) {
|
||||
paru = FastMath.min(paru, lmPar);
|
||||
}
|
||||
|
||||
// compute an improved estimate for lmPar
|
||||
lmPar = FastMath.max(parl, lmPar + correction);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Solve a*x = b and d*x = 0 in the least squares sense.
|
||||
* <p>This implementation is a translation in Java of the MINPACK
|
||||
* <a href="http://www.netlib.org/minpack/qrsolv.f">qrsolv</a>
|
||||
* routine.</p>
|
||||
* <p>This method sets the lmDir and lmDiag attributes.</p>
|
||||
* <p>The authors of the original fortran function are:</p>
|
||||
* <ul>
|
||||
* <li>Argonne National Laboratory. MINPACK project. March 1980</li>
|
||||
* <li>Burton S. Garbow</li>
|
||||
* <li>Kenneth E. Hillstrom</li>
|
||||
* <li>Jorge J. More</li>
|
||||
* </ul>
|
||||
* <p>Luc Maisonobe did the Java translation.</p>
|
||||
*
|
||||
* @param qy array containing qTy
|
||||
* @param diag diagonal matrix
|
||||
* @param lmDiag diagonal elements associated with lmDir
|
||||
* @param work work array
|
||||
*/
|
||||
private void determineLMDirection(double[] qy, double[] diag,
|
||||
double[] lmDiag, double[] work) {
|
||||
|
||||
// copy R and Qty to preserve input and initialize s
|
||||
// in particular, save the diagonal elements of R in lmDir
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
int pj = permutation[j];
|
||||
for (int i = j + 1; i < solvedCols; ++i) {
|
||||
weightedJacobian[i][pj] = weightedJacobian[j][permutation[i]];
|
||||
}
|
||||
lmDir[j] = diagR[pj];
|
||||
work[j] = qy[j];
|
||||
}
|
||||
|
||||
// eliminate the diagonal matrix d using a Givens rotation
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
|
||||
// prepare the row of d to be eliminated, locating the
|
||||
// diagonal element using p from the Q.R. factorization
|
||||
int pj = permutation[j];
|
||||
double dpj = diag[pj];
|
||||
if (dpj != 0) {
|
||||
Arrays.fill(lmDiag, j + 1, lmDiag.length, 0);
|
||||
}
|
||||
lmDiag[j] = dpj;
|
||||
|
||||
// the transformations to eliminate the row of d
|
||||
// modify only a single element of Qty
|
||||
// beyond the first n, which is initially zero.
|
||||
double qtbpj = 0;
|
||||
for (int k = j; k < solvedCols; ++k) {
|
||||
int pk = permutation[k];
|
||||
|
||||
// determine a Givens rotation which eliminates the
|
||||
// appropriate element in the current row of d
|
||||
if (lmDiag[k] != 0) {
|
||||
|
||||
final double sin;
|
||||
final double cos;
|
||||
double rkk = weightedJacobian[k][pk];
|
||||
if (FastMath.abs(rkk) < FastMath.abs(lmDiag[k])) {
|
||||
final double cotan = rkk / lmDiag[k];
|
||||
sin = 1.0 / FastMath.sqrt(1.0 + cotan * cotan);
|
||||
cos = sin * cotan;
|
||||
} else {
|
||||
final double tan = lmDiag[k] / rkk;
|
||||
cos = 1.0 / FastMath.sqrt(1.0 + tan * tan);
|
||||
sin = cos * tan;
|
||||
}
|
||||
|
||||
// compute the modified diagonal element of R and
|
||||
// the modified element of (Qty,0)
|
||||
weightedJacobian[k][pk] = cos * rkk + sin * lmDiag[k];
|
||||
final double temp = cos * work[k] + sin * qtbpj;
|
||||
qtbpj = -sin * work[k] + cos * qtbpj;
|
||||
work[k] = temp;
|
||||
|
||||
// accumulate the tranformation in the row of s
|
||||
for (int i = k + 1; i < solvedCols; ++i) {
|
||||
double rik = weightedJacobian[i][pk];
|
||||
final double temp2 = cos * rik + sin * lmDiag[i];
|
||||
lmDiag[i] = -sin * rik + cos * lmDiag[i];
|
||||
weightedJacobian[i][pk] = temp2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// store the diagonal element of s and restore
|
||||
// the corresponding diagonal element of R
|
||||
lmDiag[j] = weightedJacobian[j][permutation[j]];
|
||||
weightedJacobian[j][permutation[j]] = lmDir[j];
|
||||
}
|
||||
|
||||
// solve the triangular system for z, if the system is
|
||||
// singular, then obtain a least squares solution
|
||||
int nSing = solvedCols;
|
||||
for (int j = 0; j < solvedCols; ++j) {
|
||||
if ((lmDiag[j] == 0) && (nSing == solvedCols)) {
|
||||
nSing = j;
|
||||
}
|
||||
if (nSing < solvedCols) {
|
||||
work[j] = 0;
|
||||
}
|
||||
}
|
||||
if (nSing > 0) {
|
||||
for (int j = nSing - 1; j >= 0; --j) {
|
||||
int pj = permutation[j];
|
||||
double sum = 0;
|
||||
for (int i = j + 1; i < nSing; ++i) {
|
||||
sum += weightedJacobian[i][pj] * work[i];
|
||||
}
|
||||
work[j] = (work[j] - sum) / lmDiag[j];
|
||||
}
|
||||
}
|
||||
|
||||
// permute the components of z back to components of lmDir
|
||||
for (int j = 0; j < lmDir.length; ++j) {
|
||||
lmDir[permutation[j]] = work[j];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Decompose a matrix A as A.P = Q.R using Householder transforms.
|
||||
* <p>As suggested in the P. Lascaux and R. Theodor book
|
||||
* <i>Analyse numérique matricielle appliquée à
|
||||
* l'art de l'ingénieur</i> (Masson, 1986), instead of representing
|
||||
* the Householder transforms with u<sub>k</sub> unit vectors such that:
|
||||
* <pre>
|
||||
* H<sub>k</sub> = I - 2u<sub>k</sub>.u<sub>k</sub><sup>t</sup>
|
||||
* </pre>
|
||||
* we use <sub>k</sub> non-unit vectors such that:
|
||||
* <pre>
|
||||
* H<sub>k</sub> = I - beta<sub>k</sub>v<sub>k</sub>.v<sub>k</sub><sup>t</sup>
|
||||
* </pre>
|
||||
* where v<sub>k</sub> = a<sub>k</sub> - alpha<sub>k</sub> e<sub>k</sub>.
|
||||
* The beta<sub>k</sub> coefficients are provided upon exit as recomputing
|
||||
* them from the v<sub>k</sub> vectors would be costly.</p>
|
||||
* <p>This decomposition handles rank deficient cases since the tranformations
|
||||
* are performed in non-increasing columns norms order thanks to columns
|
||||
* pivoting. The diagonal elements of the R matrix are therefore also in
|
||||
* non-increasing absolute values order.</p>
|
||||
*
|
||||
* @param jacobian Weighted Jacobian matrix at the current point.
|
||||
* @exception ConvergenceException if the decomposition cannot be performed
|
||||
*/
|
||||
private void qrDecomposition(RealMatrix jacobian) throws ConvergenceException {
|
||||
// Code in this class assumes that the weighted Jacobian is -(W^(1/2) J),
|
||||
// hence the multiplication by -1.
|
||||
weightedJacobian = jacobian.scalarMultiply(-1).getData();
|
||||
|
||||
final int nR = weightedJacobian.length;
|
||||
final int nC = weightedJacobian[0].length;
|
||||
|
||||
// initializations
|
||||
for (int k = 0; k < nC; ++k) {
|
||||
permutation[k] = k;
|
||||
double norm2 = 0;
|
||||
for (int i = 0; i < nR; ++i) {
|
||||
double akk = weightedJacobian[i][k];
|
||||
norm2 += akk * akk;
|
||||
}
|
||||
jacNorm[k] = FastMath.sqrt(norm2);
|
||||
}
|
||||
|
||||
// transform the matrix column after column
|
||||
for (int k = 0; k < nC; ++k) {
|
||||
|
||||
// select the column with the greatest norm on active components
|
||||
int nextColumn = -1;
|
||||
double ak2 = Double.NEGATIVE_INFINITY;
|
||||
for (int i = k; i < nC; ++i) {
|
||||
double norm2 = 0;
|
||||
for (int j = k; j < nR; ++j) {
|
||||
double aki = weightedJacobian[j][permutation[i]];
|
||||
norm2 += aki * aki;
|
||||
}
|
||||
if (Double.isInfinite(norm2) || Double.isNaN(norm2)) {
|
||||
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_PERFORM_QR_DECOMPOSITION_ON_JACOBIAN,
|
||||
nR, nC);
|
||||
}
|
||||
if (norm2 > ak2) {
|
||||
nextColumn = i;
|
||||
ak2 = norm2;
|
||||
}
|
||||
}
|
||||
if (ak2 <= qrRankingThreshold) {
|
||||
rank = k;
|
||||
return;
|
||||
}
|
||||
int pk = permutation[nextColumn];
|
||||
permutation[nextColumn] = permutation[k];
|
||||
permutation[k] = pk;
|
||||
|
||||
// choose alpha such that Hk.u = alpha ek
|
||||
double akk = weightedJacobian[k][pk];
|
||||
double alpha = (akk > 0) ? -FastMath.sqrt(ak2) : FastMath.sqrt(ak2);
|
||||
double betak = 1.0 / (ak2 - akk * alpha);
|
||||
beta[pk] = betak;
|
||||
|
||||
// transform the current column
|
||||
diagR[pk] = alpha;
|
||||
weightedJacobian[k][pk] -= alpha;
|
||||
|
||||
// transform the remaining columns
|
||||
for (int dk = nC - 1 - k; dk > 0; --dk) {
|
||||
double gamma = 0;
|
||||
for (int j = k; j < nR; ++j) {
|
||||
gamma += weightedJacobian[j][pk] * weightedJacobian[j][permutation[k + dk]];
|
||||
}
|
||||
gamma *= betak;
|
||||
for (int j = k; j < nR; ++j) {
|
||||
weightedJacobian[j][permutation[k + dk]] -= gamma * weightedJacobian[j][pk];
|
||||
}
|
||||
}
|
||||
}
|
||||
rank = solvedCols;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the product Qt.y for some Q.R. decomposition.
|
||||
*
|
||||
* @param y vector to multiply (will be overwritten with the result)
|
||||
*/
|
||||
private void qTy(double[] y) {
|
||||
final int nR = weightedJacobian.length;
|
||||
final int nC = weightedJacobian[0].length;
|
||||
|
||||
for (int k = 0; k < nC; ++k) {
|
||||
int pk = permutation[k];
|
||||
double gamma = 0;
|
||||
for (int i = k; i < nR; ++i) {
|
||||
gamma += weightedJacobian[i][pk] * y[i];
|
||||
}
|
||||
gamma *= beta[pk];
|
||||
for (int i = k; i < nR; ++i) {
|
||||
y[i] -= gamma * weightedJacobian[i][pk];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
|
||||
|
||||
/**
|
||||
* This package provides optimization algorithms that require derivatives.
|
||||
*/
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.vector;
|
||||
|
||||
/**
|
||||
* Algorithms for optimizing a vector function.
|
||||
*/
|
|
@ -0,0 +1,73 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* Generally, optimizers are algorithms that will either
|
||||
* {@link GoalType#MINIMIZE minimize} or {@link GoalType#MAXIMIZE maximize}
|
||||
* a scalar function, called the {@link ObjectiveFunction <em>objective
|
||||
* function</em>}.
|
||||
* <br/>
|
||||
* For some scalar objective functions the gradient can be computed (analytically
|
||||
* or numerically). Algorithms that use this knowledge are defined in the
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.gradient} package.
|
||||
* The algorithms that do not need this additional information are located in
|
||||
* the {@link org.apache.commons.math3.optim.nonlinear.scalar.noderiv} package.
|
||||
* </p>
|
||||
*
|
||||
* <p>
|
||||
* Some problems are solved more efficiently by algorithms that, instead of an
|
||||
* objective function, need access to a
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunction
|
||||
* <em>model function</em>}: such a model predicts a set of values which the
|
||||
* algorithm tries to match with a set of given
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.vector.Target target values}.
|
||||
* Those algorithms are located in the
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.vector} package.
|
||||
* <br/>
|
||||
* Algorithms that also require the
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian
|
||||
* Jacobian matrix of the model} are located in the
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.vector.jacobian} package.
|
||||
* <br/>
|
||||
* The {@link org.apache.commons.math3.optim.nonlinear.vector.jacobian.AbstractLeastSquaresOptimizer
|
||||
* non-linear least-squares optimizers} are a specialization of the the latter,
|
||||
* that minimize the distance (called <em>cost</em> or <em>χ<sup>2</sup></em>)
|
||||
* between model and observations.
|
||||
* <br/>
|
||||
* For cases where the Jacobian cannot be provided, a utility class will
|
||||
* {@link org.apache.commons.math3.optim.nonlinear.scalar.LeastSquaresConverter
|
||||
* convert} a (vector) model into a (scalar) objective function.
|
||||
* </p>
|
||||
*
|
||||
* <p>
|
||||
* This package provides common functionality for the optimization algorithms.
|
||||
* Abstract classes ({@link BaseOptimizer} and {@link BaseMultivariateOptimizer})
|
||||
* define boiler-plate code for storing {@link MaxEval evaluations} and
|
||||
* {@link MaxIter iterations} counters and a user-defined
|
||||
* {@link ConvergenceChecker convergence checker}.
|
||||
* </p>
|
||||
*
|
||||
* <p>
|
||||
* For each of the optimizer types, there is a special implementation that
|
||||
* wraps an optimizer instance and provides a "multi-start" feature: it calls
|
||||
* the underlying optimizer several times with different starting points and
|
||||
* returns the best optimum found, or all optima if so desired.
|
||||
* This could be useful to avoid being trapped in a local extremum.
|
||||
* </p>
|
||||
*/
|
|
@ -0,0 +1,287 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import org.apache.commons.math3.util.Incrementor;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.exception.MaxCountExceededException;
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
|
||||
/**
|
||||
* Provide an interval that brackets a local optimum of a function.
|
||||
* This code is based on a Python implementation (from <em>SciPy</em>,
|
||||
* module {@code optimize.py} v0.5).
|
||||
*
|
||||
* @version $Id: BracketFinder.java 1413186 2012-11-24 13:47:59Z erans $
|
||||
* @since 2.2
|
||||
*/
|
||||
public class BracketFinder {
|
||||
/** Tolerance to avoid division by zero. */
|
||||
private static final double EPS_MIN = 1e-21;
|
||||
/**
|
||||
* Golden section.
|
||||
*/
|
||||
private static final double GOLD = 1.618034;
|
||||
/**
|
||||
* Factor for expanding the interval.
|
||||
*/
|
||||
private final double growLimit;
|
||||
/**
|
||||
* Counter for function evaluations.
|
||||
*/
|
||||
private final Incrementor evaluations = new Incrementor();
|
||||
/**
|
||||
* Lower bound of the bracket.
|
||||
*/
|
||||
private double lo;
|
||||
/**
|
||||
* Higher bound of the bracket.
|
||||
*/
|
||||
private double hi;
|
||||
/**
|
||||
* Point inside the bracket.
|
||||
*/
|
||||
private double mid;
|
||||
/**
|
||||
* Function value at {@link #lo}.
|
||||
*/
|
||||
private double fLo;
|
||||
/**
|
||||
* Function value at {@link #hi}.
|
||||
*/
|
||||
private double fHi;
|
||||
/**
|
||||
* Function value at {@link #mid}.
|
||||
*/
|
||||
private double fMid;
|
||||
|
||||
/**
|
||||
* Constructor with default values {@code 100, 50} (see the
|
||||
* {@link #BracketFinder(double,int) other constructor}).
|
||||
*/
|
||||
public BracketFinder() {
|
||||
this(100, 50);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a bracketing interval finder.
|
||||
*
|
||||
* @param growLimit Expanding factor.
|
||||
* @param maxEvaluations Maximum number of evaluations allowed for finding
|
||||
* a bracketing interval.
|
||||
*/
|
||||
public BracketFinder(double growLimit,
|
||||
int maxEvaluations) {
|
||||
if (growLimit <= 0) {
|
||||
throw new NotStrictlyPositiveException(growLimit);
|
||||
}
|
||||
if (maxEvaluations <= 0) {
|
||||
throw new NotStrictlyPositiveException(maxEvaluations);
|
||||
}
|
||||
|
||||
this.growLimit = growLimit;
|
||||
evaluations.setMaximalCount(maxEvaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Search new points that bracket a local optimum of the function.
|
||||
*
|
||||
* @param func Function whose optimum should be bracketed.
|
||||
* @param goal {@link GoalType Goal type}.
|
||||
* @param xA Initial point.
|
||||
* @param xB Initial point.
|
||||
* @throws TooManyEvaluationsException if the maximum number of evaluations
|
||||
* is exceeded.
|
||||
*/
|
||||
public void search(UnivariateFunction func, GoalType goal, double xA, double xB) {
|
||||
evaluations.resetCount();
|
||||
final boolean isMinim = goal == GoalType.MINIMIZE;
|
||||
|
||||
double fA = eval(func, xA);
|
||||
double fB = eval(func, xB);
|
||||
if (isMinim ?
|
||||
fA < fB :
|
||||
fA > fB) {
|
||||
|
||||
double tmp = xA;
|
||||
xA = xB;
|
||||
xB = tmp;
|
||||
|
||||
tmp = fA;
|
||||
fA = fB;
|
||||
fB = tmp;
|
||||
}
|
||||
|
||||
double xC = xB + GOLD * (xB - xA);
|
||||
double fC = eval(func, xC);
|
||||
|
||||
while (isMinim ? fC < fB : fC > fB) {
|
||||
double tmp1 = (xB - xA) * (fB - fC);
|
||||
double tmp2 = (xB - xC) * (fB - fA);
|
||||
|
||||
double val = tmp2 - tmp1;
|
||||
double denom = Math.abs(val) < EPS_MIN ? 2 * EPS_MIN : 2 * val;
|
||||
|
||||
double w = xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom;
|
||||
double wLim = xB + growLimit * (xC - xB);
|
||||
|
||||
double fW;
|
||||
if ((w - xC) * (xB - w) > 0) {
|
||||
fW = eval(func, w);
|
||||
if (isMinim ?
|
||||
fW < fC :
|
||||
fW > fC) {
|
||||
xA = xB;
|
||||
xB = w;
|
||||
fA = fB;
|
||||
fB = fW;
|
||||
break;
|
||||
} else if (isMinim ?
|
||||
fW > fB :
|
||||
fW < fB) {
|
||||
xC = w;
|
||||
fC = fW;
|
||||
break;
|
||||
}
|
||||
w = xC + GOLD * (xC - xB);
|
||||
fW = eval(func, w);
|
||||
} else if ((w - wLim) * (wLim - xC) >= 0) {
|
||||
w = wLim;
|
||||
fW = eval(func, w);
|
||||
} else if ((w - wLim) * (xC - w) > 0) {
|
||||
fW = eval(func, w);
|
||||
if (isMinim ?
|
||||
fW < fC :
|
||||
fW > fC) {
|
||||
xB = xC;
|
||||
xC = w;
|
||||
w = xC + GOLD * (xC - xB);
|
||||
fB = fC;
|
||||
fC =fW;
|
||||
fW = eval(func, w);
|
||||
}
|
||||
} else {
|
||||
w = xC + GOLD * (xC - xB);
|
||||
fW = eval(func, w);
|
||||
}
|
||||
|
||||
xA = xB;
|
||||
fA = fB;
|
||||
xB = xC;
|
||||
fB = fC;
|
||||
xC = w;
|
||||
fC = fW;
|
||||
}
|
||||
|
||||
lo = xA;
|
||||
fLo = fA;
|
||||
mid = xB;
|
||||
fMid = fB;
|
||||
hi = xC;
|
||||
fHi = fC;
|
||||
|
||||
if (lo > hi) {
|
||||
double tmp = lo;
|
||||
lo = hi;
|
||||
hi = tmp;
|
||||
|
||||
tmp = fLo;
|
||||
fLo = fHi;
|
||||
fHi = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the number of evalutations.
|
||||
*/
|
||||
public int getMaxEvaluations() {
|
||||
return evaluations.getMaximalCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the number of evalutations.
|
||||
*/
|
||||
public int getEvaluations() {
|
||||
return evaluations.getCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the lower bound of the bracket.
|
||||
* @see #getFLo()
|
||||
*/
|
||||
public double getLo() {
|
||||
return lo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get function value at {@link #getLo()}.
|
||||
* @return function value at {@link #getLo()}
|
||||
*/
|
||||
public double getFLo() {
|
||||
return fLo;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the higher bound of the bracket.
|
||||
* @see #getFHi()
|
||||
*/
|
||||
public double getHi() {
|
||||
return hi;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get function value at {@link #getHi()}.
|
||||
* @return function value at {@link #getHi()}
|
||||
*/
|
||||
public double getFHi() {
|
||||
return fHi;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a point in the middle of the bracket.
|
||||
* @see #getFMid()
|
||||
*/
|
||||
public double getMid() {
|
||||
return mid;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get function value at {@link #getMid()}.
|
||||
* @return function value at {@link #getMid()}
|
||||
*/
|
||||
public double getFMid() {
|
||||
return fMid;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param f Function.
|
||||
* @param x Argument.
|
||||
* @return {@code f(x)}
|
||||
* @throws TooManyEvaluationsException if the maximal number of evaluations is
|
||||
* exceeded.
|
||||
*/
|
||||
private double eval(UnivariateFunction f, double x) {
|
||||
try {
|
||||
evaluations.incrementCount();
|
||||
} catch (MaxCountExceededException e) {
|
||||
throw new TooManyEvaluationsException(e.getMax());
|
||||
}
|
||||
return f.value(x);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,317 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
|
||||
/**
|
||||
* For a function defined on some interval {@code (lo, hi)}, this class
|
||||
* finds an approximation {@code x} to the point at which the function
|
||||
* attains its minimum.
|
||||
* It implements Richard Brent's algorithm (from his book "Algorithms for
|
||||
* Minimization without Derivatives", p. 79) for finding minima of real
|
||||
* univariate functions.
|
||||
* <br/>
|
||||
* This code is an adaptation, partly based on the Python code from SciPy
|
||||
* (module "optimize.py" v0.5); the original algorithm is also modified
|
||||
* <ul>
|
||||
* <li>to use an initial guess provided by the user,</li>
|
||||
* <li>to ensure that the best point encountered is the one returned.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @version $Id: BrentOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
|
||||
* @since 2.0
|
||||
*/
|
||||
public class BrentOptimizer extends UnivariateOptimizer {
|
||||
/**
|
||||
* Golden section.
|
||||
*/
|
||||
private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
|
||||
/**
|
||||
* Minimum relative tolerance.
|
||||
*/
|
||||
private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
|
||||
/**
|
||||
* Relative threshold.
|
||||
*/
|
||||
private final double relativeThreshold;
|
||||
/**
|
||||
* Absolute threshold.
|
||||
*/
|
||||
private final double absoluteThreshold;
|
||||
|
||||
/**
|
||||
* The arguments are used implement the original stopping criterion
|
||||
* of Brent's algorithm.
|
||||
* {@code abs} and {@code rel} define a tolerance
|
||||
* {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
|
||||
* <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
|
||||
* where <em>macheps</em> is the relative machine precision. {@code abs} must
|
||||
* be positive.
|
||||
*
|
||||
* @param rel Relative threshold.
|
||||
* @param abs Absolute threshold.
|
||||
* @param checker Additional, user-defined, convergence checking
|
||||
* procedure.
|
||||
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
||||
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
||||
*/
|
||||
public BrentOptimizer(double rel,
|
||||
double abs,
|
||||
ConvergenceChecker<UnivariatePointValuePair> checker) {
|
||||
super(checker);
|
||||
|
||||
if (rel < MIN_RELATIVE_TOLERANCE) {
|
||||
throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
|
||||
}
|
||||
if (abs <= 0) {
|
||||
throw new NotStrictlyPositiveException(abs);
|
||||
}
|
||||
|
||||
relativeThreshold = rel;
|
||||
absoluteThreshold = abs;
|
||||
}
|
||||
|
||||
/**
|
||||
* The arguments are used for implementing the original stopping criterion
|
||||
* of Brent's algorithm.
|
||||
* {@code abs} and {@code rel} define a tolerance
|
||||
* {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
|
||||
* <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
|
||||
* where <em>macheps</em> is the relative machine precision. {@code abs} must
|
||||
* be positive.
|
||||
*
|
||||
* @param rel Relative threshold.
|
||||
* @param abs Absolute threshold.
|
||||
* @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
||||
* @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
||||
*/
|
||||
public BrentOptimizer(double rel,
|
||||
double abs) {
|
||||
this(rel, abs, null);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected UnivariatePointValuePair doOptimize() {
|
||||
final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
|
||||
final double lo = getMin();
|
||||
final double mid = getStartValue();
|
||||
final double hi = getMax();
|
||||
|
||||
// Optional additional convergence criteria.
|
||||
final ConvergenceChecker<UnivariatePointValuePair> checker
|
||||
= getConvergenceChecker();
|
||||
|
||||
double a;
|
||||
double b;
|
||||
if (lo < hi) {
|
||||
a = lo;
|
||||
b = hi;
|
||||
} else {
|
||||
a = hi;
|
||||
b = lo;
|
||||
}
|
||||
|
||||
double x = mid;
|
||||
double v = x;
|
||||
double w = x;
|
||||
double d = 0;
|
||||
double e = 0;
|
||||
double fx = computeObjectiveValue(x);
|
||||
if (!isMinim) {
|
||||
fx = -fx;
|
||||
}
|
||||
double fv = fx;
|
||||
double fw = fx;
|
||||
|
||||
UnivariatePointValuePair previous = null;
|
||||
UnivariatePointValuePair current
|
||||
= new UnivariatePointValuePair(x, isMinim ? fx : -fx);
|
||||
// Best point encountered so far (which is the initial guess).
|
||||
UnivariatePointValuePair best = current;
|
||||
|
||||
int iter = 0;
|
||||
while (true) {
|
||||
final double m = 0.5 * (a + b);
|
||||
final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
|
||||
final double tol2 = 2 * tol1;
|
||||
|
||||
// Default stopping criterion.
|
||||
final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
|
||||
if (!stop) {
|
||||
double p = 0;
|
||||
double q = 0;
|
||||
double r = 0;
|
||||
double u = 0;
|
||||
|
||||
if (FastMath.abs(e) > tol1) { // Fit parabola.
|
||||
r = (x - w) * (fx - fv);
|
||||
q = (x - v) * (fx - fw);
|
||||
p = (x - v) * q - (x - w) * r;
|
||||
q = 2 * (q - r);
|
||||
|
||||
if (q > 0) {
|
||||
p = -p;
|
||||
} else {
|
||||
q = -q;
|
||||
}
|
||||
|
||||
r = e;
|
||||
e = d;
|
||||
|
||||
if (p > q * (a - x) &&
|
||||
p < q * (b - x) &&
|
||||
FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
|
||||
// Parabolic interpolation step.
|
||||
d = p / q;
|
||||
u = x + d;
|
||||
|
||||
// f must not be evaluated too close to a or b.
|
||||
if (u - a < tol2 || b - u < tol2) {
|
||||
if (x <= m) {
|
||||
d = tol1;
|
||||
} else {
|
||||
d = -tol1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Golden section step.
|
||||
if (x < m) {
|
||||
e = b - x;
|
||||
} else {
|
||||
e = a - x;
|
||||
}
|
||||
d = GOLDEN_SECTION * e;
|
||||
}
|
||||
} else {
|
||||
// Golden section step.
|
||||
if (x < m) {
|
||||
e = b - x;
|
||||
} else {
|
||||
e = a - x;
|
||||
}
|
||||
d = GOLDEN_SECTION * e;
|
||||
}
|
||||
|
||||
// Update by at least "tol1".
|
||||
if (FastMath.abs(d) < tol1) {
|
||||
if (d >= 0) {
|
||||
u = x + tol1;
|
||||
} else {
|
||||
u = x - tol1;
|
||||
}
|
||||
} else {
|
||||
u = x + d;
|
||||
}
|
||||
|
||||
double fu = computeObjectiveValue(u);
|
||||
if (!isMinim) {
|
||||
fu = -fu;
|
||||
}
|
||||
|
||||
// User-defined convergence checker.
|
||||
previous = current;
|
||||
current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
|
||||
best = best(best,
|
||||
best(previous,
|
||||
current,
|
||||
isMinim),
|
||||
isMinim);
|
||||
|
||||
if (checker != null) {
|
||||
if (checker.converged(iter, previous, current)) {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
// Update a, b, v, w and x.
|
||||
if (fu <= fx) {
|
||||
if (u < x) {
|
||||
b = x;
|
||||
} else {
|
||||
a = x;
|
||||
}
|
||||
v = w;
|
||||
fv = fw;
|
||||
w = x;
|
||||
fw = fx;
|
||||
x = u;
|
||||
fx = fu;
|
||||
} else {
|
||||
if (u < x) {
|
||||
a = u;
|
||||
} else {
|
||||
b = u;
|
||||
}
|
||||
if (fu <= fw ||
|
||||
Precision.equals(w, x)) {
|
||||
v = w;
|
||||
fv = fw;
|
||||
w = u;
|
||||
fw = fu;
|
||||
} else if (fu <= fv ||
|
||||
Precision.equals(v, x) ||
|
||||
Precision.equals(v, w)) {
|
||||
v = u;
|
||||
fv = fu;
|
||||
}
|
||||
}
|
||||
} else { // Default termination (Brent's criterion).
|
||||
return best(best,
|
||||
best(previous,
|
||||
current,
|
||||
isMinim),
|
||||
isMinim);
|
||||
}
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects the best of two points.
|
||||
*
|
||||
* @param a Point and value.
|
||||
* @param b Point and value.
|
||||
* @param isMinim {@code true} if the selected point must be the one with
|
||||
* the lowest value.
|
||||
* @return the best point, or {@code null} if {@code a} and {@code b} are
|
||||
* both {@code null}. When {@code a} and {@code b} have the same function
|
||||
* value, {@code a} is returned.
|
||||
*/
|
||||
private UnivariatePointValuePair best(UnivariatePointValuePair a,
|
||||
UnivariatePointValuePair b,
|
||||
boolean isMinim) {
|
||||
if (a == null) {
|
||||
return b;
|
||||
}
|
||||
if (b == null) {
|
||||
return a;
|
||||
}
|
||||
|
||||
if (isMinim) {
|
||||
return a.getValue() <= b.getValue() ? a : b;
|
||||
} else {
|
||||
return a.getValue() >= b.getValue() ? a : b;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* Special implementation of the {@link UnivariateOptimizer} interface
|
||||
* adding multi-start features to an existing optimizer.
|
||||
* <br/>
|
||||
* This class wraps an optimizer in order to use it several times in
|
||||
* turn with different starting points (trying to avoid being trapped
|
||||
* in a local extremum when looking for a global one).
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.0
|
||||
*/
|
||||
public class MultiStartUnivariateOptimizer
|
||||
extends UnivariateOptimizer {
|
||||
/** Underlying classical optimizer. */
|
||||
private final UnivariateOptimizer optimizer;
|
||||
/** Number of evaluations already performed for all starts. */
|
||||
private int totalEvaluations;
|
||||
/** Number of starts to go. */
|
||||
private int starts;
|
||||
/** Random generator for multi-start. */
|
||||
private RandomGenerator generator;
|
||||
/** Found optima. */
|
||||
private UnivariatePointValuePair[] optima;
|
||||
/** Optimization data. */
|
||||
private OptimizationData[] optimData;
|
||||
/**
|
||||
* Location in {@link #optimData} where the updated maximum
|
||||
* number of evaluations will be stored.
|
||||
*/
|
||||
private int maxEvalIndex = -1;
|
||||
/**
|
||||
* Location in {@link #optimData} where the updated start value
|
||||
* will be stored.
|
||||
*/
|
||||
private int searchIntervalIndex = -1;
|
||||
|
||||
/**
|
||||
* Create a multi-start optimizer from a single-start optimizer.
|
||||
*
|
||||
* @param optimizer Single-start optimizer to wrap.
|
||||
* @param starts Number of starts to perform. If {@code starts == 1},
|
||||
* the {@code optimize} methods will return the same solution as
|
||||
* {@code optimizer} would.
|
||||
* @param generator Random generator to use for restarts.
|
||||
* @throws NullArgumentException if {@code optimizer} or {@code generator}
|
||||
* is {@code null}.
|
||||
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
||||
*/
|
||||
public MultiStartUnivariateOptimizer(final UnivariateOptimizer optimizer,
|
||||
final int starts,
|
||||
final RandomGenerator generator) {
|
||||
super(optimizer.getConvergenceChecker());
|
||||
|
||||
if (optimizer == null ||
|
||||
generator == null) {
|
||||
throw new NullArgumentException();
|
||||
}
|
||||
if (starts < 1) {
|
||||
throw new NotStrictlyPositiveException(starts);
|
||||
}
|
||||
|
||||
this.optimizer = optimizer;
|
||||
this.starts = starts;
|
||||
this.generator = generator;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int getEvaluations() {
|
||||
return totalEvaluations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all the optima found during the last call to {@code optimize}.
|
||||
* The optimizer stores all the optima found during a set of
|
||||
* restarts. The {@code optimize} method returns the best point only.
|
||||
* This method returns all the points found at the end of each starts,
|
||||
* including the best one already returned by the {@code optimize} method.
|
||||
* <br/>
|
||||
* The returned array as one element for each start as specified
|
||||
* in the constructor. It is ordered with the results from the
|
||||
* runs that did converge first, sorted from best to worst
|
||||
* objective value (i.e in ascending order if minimizing and in
|
||||
* descending order if maximizing), followed by {@code null} elements
|
||||
* corresponding to the runs that did not converge. This means all
|
||||
* elements will be {@code null} if the {@code optimize} method did throw
|
||||
* an exception.
|
||||
* This also means that if the first element is not {@code null}, it is
|
||||
* the best point found across all starts.
|
||||
*
|
||||
* @return an array containing the optima.
|
||||
* @throws MathIllegalStateException if {@link #optimize(OptimizationData[])
|
||||
* optimize} has not been called.
|
||||
*/
|
||||
public UnivariatePointValuePair[] getOptima() {
|
||||
if (optima == null) {
|
||||
throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET);
|
||||
}
|
||||
return optima.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @throws MathIllegalStateException if {@code optData} does not contain an
|
||||
* instance of {@link MaxEval} or {@link SearchInterval}.
|
||||
*/
|
||||
@Override
|
||||
public UnivariatePointValuePair optimize(OptimizationData... optData) {
|
||||
// Store arguments in order to pass them to the internal optimizer.
|
||||
optimData = optData;
|
||||
// Set up base class and perform computations.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected UnivariatePointValuePair doOptimize() {
|
||||
// Remove all instances of "MaxEval" and "SearchInterval" from the
|
||||
// array that will be passed to the internal optimizer.
|
||||
// The former is to enforce smaller numbers of allowed evaluations
|
||||
// (according to how many have been used up already), and the latter
|
||||
// to impose a different start value for each start.
|
||||
for (int i = 0; i < optimData.length; i++) {
|
||||
if (optimData[i] instanceof MaxEval) {
|
||||
optimData[i] = null;
|
||||
maxEvalIndex = i;
|
||||
continue;
|
||||
}
|
||||
if (optimData[i] instanceof SearchInterval) {
|
||||
optimData[i] = null;
|
||||
searchIntervalIndex = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (maxEvalIndex == -1) {
|
||||
throw new MathIllegalStateException();
|
||||
}
|
||||
if (searchIntervalIndex == -1) {
|
||||
throw new MathIllegalStateException();
|
||||
}
|
||||
|
||||
RuntimeException lastException = null;
|
||||
optima = new UnivariatePointValuePair[starts];
|
||||
totalEvaluations = 0;
|
||||
|
||||
final int maxEval = getMaxEvaluations();
|
||||
final double min = getMin();
|
||||
final double max = getMax();
|
||||
final double startValue = getStartValue();
|
||||
|
||||
// Multi-start loop.
|
||||
for (int i = 0; i < starts; i++) {
|
||||
// CHECKSTYLE: stop IllegalCatch
|
||||
try {
|
||||
// Decrease number of allowed evaluations.
|
||||
optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
|
||||
// New start value.
|
||||
final double s = (i == 0) ?
|
||||
startValue :
|
||||
min + generator.nextDouble() * (max - min);
|
||||
optimData[searchIntervalIndex] = new SearchInterval(min, max, s);
|
||||
// Optimize.
|
||||
optima[i] = optimizer.optimize(optimData);
|
||||
} catch (RuntimeException mue) {
|
||||
lastException = mue;
|
||||
optima[i] = null;
|
||||
}
|
||||
// CHECKSTYLE: resume IllegalCatch
|
||||
|
||||
totalEvaluations += optimizer.getEvaluations();
|
||||
}
|
||||
|
||||
sortPairs(getGoalType());
|
||||
|
||||
if (optima[0] == null) {
|
||||
throw lastException; // Cannot be null if starts >= 1.
|
||||
}
|
||||
|
||||
// Return the point with the best objective function value.
|
||||
return optima[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort the optima from best to worst, followed by {@code null} elements.
|
||||
*
|
||||
* @param goal Goal type.
|
||||
*/
|
||||
private void sortPairs(final GoalType goal) {
|
||||
Arrays.sort(optima, new Comparator<UnivariatePointValuePair>() {
|
||||
public int compare(final UnivariatePointValuePair o1,
|
||||
final UnivariatePointValuePair o2) {
|
||||
if (o1 == null) {
|
||||
return (o2 == null) ? 0 : 1;
|
||||
} else if (o2 == null) {
|
||||
return -1;
|
||||
}
|
||||
final double v1 = o1.getValue();
|
||||
final double v2 = o2.getValue();
|
||||
return (goal == GoalType.MINIMIZE) ?
|
||||
Double.compare(v1, v2) : Double.compare(v2, v1);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
|
||||
/**
|
||||
* Search interval and (optional) start value.
|
||||
* <br/>
|
||||
* Immutable class.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class SearchInterval implements OptimizationData {
|
||||
/** Lower bound. */
|
||||
private final double lower;
|
||||
/** Upper bound. */
|
||||
private final double upper;
|
||||
/** Start value. */
|
||||
private final double start;
|
||||
|
||||
/**
|
||||
* @param lo Lower bound.
|
||||
* @param hi Upper bound.
|
||||
* @param init Start value.
|
||||
* @throws NumberIsTooLargeException if {@code lo >= hi}.
|
||||
* @throws OutOfRangeException if {@code init < lo} or {@code init > hi}.
|
||||
*/
|
||||
public SearchInterval(double lo,
|
||||
double hi,
|
||||
double init) {
|
||||
if (lo >= hi) {
|
||||
throw new NumberIsTooLargeException(lo, hi, false);
|
||||
}
|
||||
if (init < lo ||
|
||||
init > hi) {
|
||||
throw new OutOfRangeException(init, lo, hi);
|
||||
}
|
||||
|
||||
lower = lo;
|
||||
upper = hi;
|
||||
start = init;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param lo Lower bound.
|
||||
* @param hi Upper bound.
|
||||
* @throws NumberIsTooLargeException if {@code lo >= hi}.
|
||||
*/
|
||||
public SearchInterval(double lo,
|
||||
double hi) {
|
||||
this(lo, hi, 0.5 * (lo + hi));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the lower bound.
|
||||
*
|
||||
* @return the lower bound.
|
||||
*/
|
||||
public double getMin() {
|
||||
return lower;
|
||||
}
|
||||
/**
|
||||
* Gets the upper bound.
|
||||
*
|
||||
* @return the upper bound.
|
||||
*/
|
||||
public double getMax() {
|
||||
return upper;
|
||||
}
|
||||
/**
|
||||
* Gets the start value.
|
||||
*
|
||||
* @return the start value.
|
||||
*/
|
||||
public double getStartValue() {
|
||||
return start;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.optim.AbstractConvergenceChecker;
|
||||
|
||||
/**
|
||||
* Simple implementation of the
|
||||
* {@link org.apache.commons.math3.optimization.ConvergenceChecker} interface
|
||||
* that uses only objective function values.
|
||||
*
|
||||
* Convergence is considered to have been reached if either the relative
|
||||
* difference between the objective function values is smaller than a
|
||||
* threshold or if either the absolute difference between the objective
|
||||
* function values is smaller than another threshold.
|
||||
* <br/>
|
||||
* The {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)
|
||||
* converged} method will also return {@code true} if the number of iterations
|
||||
* has been set (see {@link #SimpleUnivariateValueChecker(double,double,int)
|
||||
* this constructor}).
|
||||
*
|
||||
* @version $Id: SimpleUnivariateValueChecker.java 1413171 2012-11-24 11:11:10Z erans $
|
||||
* @since 3.1
|
||||
*/
|
||||
public class SimpleUnivariateValueChecker
|
||||
extends AbstractConvergenceChecker<UnivariatePointValuePair> {
|
||||
/**
|
||||
* If {@link #maxIterationCount} is set to this value, the number of
|
||||
* iterations will never cause
|
||||
* {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)}
|
||||
* to return {@code true}.
|
||||
*/
|
||||
private static final int ITERATION_CHECK_DISABLED = -1;
|
||||
/**
|
||||
* Number of iterations after which the
|
||||
* {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)}
|
||||
* method will return true (unless the check is disabled).
|
||||
*/
|
||||
private final int maxIterationCount;
|
||||
|
||||
/** Build an instance with specified thresholds.
|
||||
*
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold relative tolerance threshold
|
||||
* @param absoluteThreshold absolute tolerance threshold
|
||||
*/
|
||||
public SimpleUnivariateValueChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an instance with specified thresholds.
|
||||
*
|
||||
* In order to perform only relative checks, the absolute tolerance
|
||||
* must be set to a negative value. In order to perform only absolute
|
||||
* checks, the relative tolerance must be set to a negative value.
|
||||
*
|
||||
* @param relativeThreshold relative tolerance threshold
|
||||
* @param absoluteThreshold absolute tolerance threshold
|
||||
* @param maxIter Maximum iteration count.
|
||||
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
public SimpleUnivariateValueChecker(final double relativeThreshold,
|
||||
final double absoluteThreshold,
|
||||
final int maxIter) {
|
||||
super(relativeThreshold, absoluteThreshold);
|
||||
|
||||
if (maxIter <= 0) {
|
||||
throw new NotStrictlyPositiveException(maxIter);
|
||||
}
|
||||
maxIterationCount = maxIter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the optimization algorithm has converged considering the
|
||||
* last two points.
|
||||
* This method may be called several time from the same algorithm
|
||||
* iteration with different points. This can be detected by checking the
|
||||
* iteration number at each call if needed. Each time this method is
|
||||
* called, the previous and current point correspond to points with the
|
||||
* same role at each iteration, so they can be compared. As an example,
|
||||
* simplex-based algorithms call this method for all points of the simplex,
|
||||
* not only for the best or worst ones.
|
||||
*
|
||||
* @param iteration Index of current iteration
|
||||
* @param previous Best point in the previous iteration.
|
||||
* @param current Best point in the current iteration.
|
||||
* @return {@code true} if the algorithm has converged.
|
||||
*/
|
||||
@Override
|
||||
public boolean converged(final int iteration,
|
||||
final UnivariatePointValuePair previous,
|
||||
final UnivariatePointValuePair current) {
|
||||
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
|
||||
if (iteration >= maxIterationCount) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
final double p = previous.getValue();
|
||||
final double c = current.getValue();
|
||||
final double difference = FastMath.abs(p - c);
|
||||
final double size = FastMath.max(FastMath.abs(p), FastMath.abs(c));
|
||||
return difference <= size * getRelativeThreshold() ||
|
||||
difference <= getAbsoluteThreshold();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
|
||||
/**
|
||||
* Scalar function to be optimized.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public class UnivariateObjectiveFunction implements OptimizationData {
|
||||
/** Function to be optimized. */
|
||||
private final UnivariateFunction function;
|
||||
|
||||
/**
|
||||
* @param f Function to be optimized.
|
||||
*/
|
||||
public UnivariateObjectiveFunction(UnivariateFunction f) {
|
||||
function = f;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the function to be optimized.
|
||||
*
|
||||
* @return the objective function.
|
||||
*/
|
||||
public UnivariateFunction getObjectiveFunction() {
|
||||
return function;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,148 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.optim.BaseOptimizer;
|
||||
import org.apache.commons.math3.optim.OptimizationData;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.ConvergenceChecker;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
|
||||
/**
|
||||
* Base class for a univariate scalar function optimizer.
|
||||
*
|
||||
* @version $Id$
|
||||
* @since 3.1
|
||||
*/
|
||||
public abstract class UnivariateOptimizer
|
||||
extends BaseOptimizer<UnivariatePointValuePair> {
|
||||
/** Objective function. */
|
||||
private UnivariateFunction function;
|
||||
/** Type of optimization. */
|
||||
private GoalType goal;
|
||||
/** Initial guess. */
|
||||
private double start;
|
||||
/** Lower bound. */
|
||||
private double min;
|
||||
/** Upper bound. */
|
||||
private double max;
|
||||
|
||||
/**
|
||||
* @param checker Convergence checker.
|
||||
*/
|
||||
protected UnivariateOptimizer(ConvergenceChecker<UnivariatePointValuePair> checker) {
|
||||
super(checker);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link GoalType}</li>
|
||||
* <li>{@link SearchInterval}</li>
|
||||
* <li>{@link UnivariateObjectiveFunction}</li>
|
||||
* </ul>
|
||||
* @return {@inheritDoc}
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
*/
|
||||
public UnivariatePointValuePair optimize(OptimizationData... optData)
|
||||
throws TooManyEvaluationsException {
|
||||
// Retrieve settings.
|
||||
parseOptimizationData(optData);
|
||||
// Perform computation.
|
||||
return super.optimize(optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the optimization type.
|
||||
*/
|
||||
public GoalType getGoalType() {
|
||||
return goal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data.
|
||||
* The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link GoalType}</li>
|
||||
* <li>{@link SearchInterval}</li>
|
||||
* <li>{@link UnivariateObjectiveFunction}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof SearchInterval) {
|
||||
final SearchInterval interval = (SearchInterval) data;
|
||||
min = interval.getMin();
|
||||
max = interval.getMax();
|
||||
start = interval.getStartValue();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof UnivariateObjectiveFunction) {
|
||||
function = ((UnivariateObjectiveFunction) data).getObjectiveFunction();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof GoalType) {
|
||||
goal = (GoalType) data;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the initial guess.
|
||||
*/
|
||||
public double getStartValue() {
|
||||
return start;
|
||||
}
|
||||
/**
|
||||
* @return the lower bounds.
|
||||
*/
|
||||
public double getMin() {
|
||||
return min;
|
||||
}
|
||||
/**
|
||||
* @return the upper bounds.
|
||||
*/
|
||||
public double getMax() {
|
||||
return max;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the objective function value.
|
||||
* This method <em>must</em> be called by subclasses to enforce the
|
||||
* evaluation counter limit.
|
||||
*
|
||||
* @param x Point at which the objective function must be evaluated.
|
||||
* @return the objective function value at the specified point.
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
*/
|
||||
protected double computeObjectiveValue(double x) {
|
||||
super.incrementEvaluationCount();
|
||||
return function.value(x);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* This class holds a point and the value of an objective function at this
|
||||
* point.
|
||||
* This is a simple immutable container.
|
||||
*
|
||||
* @version $Id: UnivariatePointValuePair.java 1364392 2012-07-22 18:27:12Z tn $
|
||||
* @since 3.0
|
||||
*/
|
||||
public class UnivariatePointValuePair implements Serializable {
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = 1003888396256744753L;
|
||||
/** Point. */
|
||||
private final double point;
|
||||
/** Value of the objective function at the point. */
|
||||
private final double value;
|
||||
|
||||
/**
|
||||
* Build a point/objective function value pair.
|
||||
*
|
||||
* @param point Point.
|
||||
* @param value Value of an objective function at the point
|
||||
*/
|
||||
public UnivariatePointValuePair(final double point,
|
||||
final double value) {
|
||||
this.point = point;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the point.
|
||||
*
|
||||
* @return the point.
|
||||
*/
|
||||
public double getPoint() {
|
||||
return point;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the value of the objective function.
|
||||
*
|
||||
* @return the stored value of the objective function.
|
||||
*/
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.univariate;
|
||||
|
||||
/**
|
||||
* One-dimensional optimization algorithms.
|
||||
*/
|
|
@ -119,6 +119,7 @@ INVALID_REGRESSION_ARRAY= la longueur du tableau de donn\u00e9es = {0} ne corres
|
|||
INVALID_REGRESSION_OBSERVATION = la longueur du tableau de variables explicatives ({0}) ne correspond pas au nombre de variables dans le mod\u00e8le ({1})
|
||||
INVALID_ROUNDING_METHOD = m\u00e9thode d''arondi {0} invalide, m\u00e9thodes valides : {1} ({2}), {3} ({4}), {5} ({6}), {7} ({8}), {9} ({10}), {11} ({12}), {13} ({14}), {15} ({16})
|
||||
ITERATOR_EXHAUSTED = it\u00e9ration achev\u00e9e
|
||||
ITERATIONS = it\u00e9rations
|
||||
LCM_OVERFLOW_32_BITS = d\u00e9passement de capacit\u00e9 : le MCM de {0} et {1} vaut 2^31
|
||||
LCM_OVERFLOW_64_BITS = d\u00e9passement de capacit\u00e9 : le MCM de {0} et {1} vaut 2^63
|
||||
LIST_OF_CHROMOSOMES_BIGGER_THAN_POPULATION_SIZE = la liste des chromosomes d\u00e9passe maxPopulationSize
|
||||
|
|
|
@ -30,7 +30,7 @@ public class LocalizedFormatsTest {
|
|||
|
||||
@Test
|
||||
public void testMessageNumber() {
|
||||
Assert.assertEquals(311, LocalizedFormats.values().length);
|
||||
Assert.assertEquals(312, LocalizedFormats.values().length);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
|
||||
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class CurveFitterTest {
|
||||
@Test
|
||||
public void testMath303() {
|
||||
LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
|
||||
CurveFitter<ParametricUnivariateFunction> fitter = new CurveFitter<ParametricUnivariateFunction>(optimizer);
|
||||
fitter.addObservedPoint(2.805d, 0.6934785852953367d);
|
||||
fitter.addObservedPoint(2.74333333333333d, 0.6306772025518496d);
|
||||
fitter.addObservedPoint(1.655d, 0.9474675497289684);
|
||||
fitter.addObservedPoint(1.725d, 0.9013594835804194d);
|
||||
|
||||
ParametricUnivariateFunction sif = new SimpleInverseFunction();
|
||||
|
||||
double[] initialguess1 = new double[1];
|
||||
initialguess1[0] = 1.0d;
|
||||
Assert.assertEquals(1, fitter.fit(sif, initialguess1).length);
|
||||
|
||||
double[] initialguess2 = new double[2];
|
||||
initialguess2[0] = 1.0d;
|
||||
initialguess2[1] = .5d;
|
||||
Assert.assertEquals(2, fitter.fit(sif, initialguess2).length);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath304() {
|
||||
LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
|
||||
CurveFitter<ParametricUnivariateFunction> fitter = new CurveFitter<ParametricUnivariateFunction>(optimizer);
|
||||
fitter.addObservedPoint(2.805d, 0.6934785852953367d);
|
||||
fitter.addObservedPoint(2.74333333333333d, 0.6306772025518496d);
|
||||
fitter.addObservedPoint(1.655d, 0.9474675497289684);
|
||||
fitter.addObservedPoint(1.725d, 0.9013594835804194d);
|
||||
|
||||
ParametricUnivariateFunction sif = new SimpleInverseFunction();
|
||||
|
||||
double[] initialguess1 = new double[1];
|
||||
initialguess1[0] = 1.0d;
|
||||
Assert.assertEquals(1.6357215104109237, fitter.fit(sif, initialguess1)[0], 1.0e-14);
|
||||
|
||||
double[] initialguess2 = new double[1];
|
||||
initialguess2[0] = 10.0d;
|
||||
Assert.assertEquals(1.6357215104109237, fitter.fit(sif, initialguess1)[0], 1.0e-14);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath372() {
|
||||
LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
|
||||
CurveFitter<ParametricUnivariateFunction> curveFitter = new CurveFitter<ParametricUnivariateFunction>(optimizer);
|
||||
|
||||
curveFitter.addObservedPoint( 15, 4443);
|
||||
curveFitter.addObservedPoint( 31, 8493);
|
||||
curveFitter.addObservedPoint( 62, 17586);
|
||||
curveFitter.addObservedPoint(125, 30582);
|
||||
curveFitter.addObservedPoint(250, 45087);
|
||||
curveFitter.addObservedPoint(500, 50683);
|
||||
|
||||
ParametricUnivariateFunction f = new ParametricUnivariateFunction() {
|
||||
public double value(double x, double ... parameters) {
|
||||
double a = parameters[0];
|
||||
double b = parameters[1];
|
||||
double c = parameters[2];
|
||||
double d = parameters[3];
|
||||
|
||||
return d + ((a - d) / (1 + FastMath.pow(x / c, b)));
|
||||
}
|
||||
|
||||
public double[] gradient(double x, double ... parameters) {
|
||||
double a = parameters[0];
|
||||
double b = parameters[1];
|
||||
double c = parameters[2];
|
||||
double d = parameters[3];
|
||||
|
||||
double[] gradients = new double[4];
|
||||
double den = 1 + FastMath.pow(x / c, b);
|
||||
|
||||
// derivative with respect to a
|
||||
gradients[0] = 1 / den;
|
||||
|
||||
// derivative with respect to b
|
||||
// in the reported (invalid) issue, there was a sign error here
|
||||
gradients[1] = -((a - d) * FastMath.pow(x / c, b) * FastMath.log(x / c)) / (den * den);
|
||||
|
||||
// derivative with respect to c
|
||||
gradients[2] = (b * FastMath.pow(x / c, b - 1) * (x / (c * c)) * (a - d)) / (den * den);
|
||||
|
||||
// derivative with respect to d
|
||||
gradients[3] = 1 - (1 / den);
|
||||
|
||||
return gradients;
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
double[] initialGuess = new double[] { 1500, 0.95, 65, 35000 };
|
||||
double[] estimatedParameters = curveFitter.fit(f, initialGuess);
|
||||
|
||||
Assert.assertEquals( 2411.00, estimatedParameters[0], 500.00);
|
||||
Assert.assertEquals( 1.62, estimatedParameters[1], 0.04);
|
||||
Assert.assertEquals( 111.22, estimatedParameters[2], 0.30);
|
||||
Assert.assertEquals(55347.47, estimatedParameters[3], 300.00);
|
||||
Assert.assertTrue(optimizer.getRMS() < 600.0);
|
||||
}
|
||||
|
||||
private static class SimpleInverseFunction implements ParametricUnivariateFunction {
|
||||
|
||||
public double value(double x, double ... parameters) {
|
||||
return parameters[0] / x + (parameters.length < 2 ? 0 : parameters[1]);
|
||||
}
|
||||
|
||||
public double[] gradient(double x, double ... doubles) {
|
||||
double[] gradientVector = new double[doubles.length];
|
||||
gradientVector[0] = 1 / x;
|
||||
if (doubles.length >= 2) {
|
||||
gradientVector[1] = 1;
|
||||
}
|
||||
return gradientVector;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,363 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Tests {@link GaussianFitter}.
|
||||
*
|
||||
* @since 2.2
|
||||
* @version $Id: GaussianFitterTest.java 1349707 2012-06-13 09:30:56Z erans $
|
||||
*/
|
||||
public class GaussianFitterTest {
|
||||
/** Good data. */
|
||||
protected static final double[][] DATASET1 = new double[][] {
|
||||
{4.0254623, 531026.0},
|
||||
{4.02804905, 664002.0},
|
||||
{4.02934242, 787079.0},
|
||||
{4.03128248, 984167.0},
|
||||
{4.03386923, 1294546.0},
|
||||
{4.03580929, 1560230.0},
|
||||
{4.03839603, 1887233.0},
|
||||
{4.0396894, 2113240.0},
|
||||
{4.04162946, 2375211.0},
|
||||
{4.04421621, 2687152.0},
|
||||
{4.04550958, 2862644.0},
|
||||
{4.04744964, 3078898.0},
|
||||
{4.05003639, 3327238.0},
|
||||
{4.05132976, 3461228.0},
|
||||
{4.05326982, 3580526.0},
|
||||
{4.05585657, 3576946.0},
|
||||
{4.05779662, 3439750.0},
|
||||
{4.06038337, 3220296.0},
|
||||
{4.06167674, 3070073.0},
|
||||
{4.0636168, 2877648.0},
|
||||
{4.06620355, 2595848.0},
|
||||
{4.06749692, 2390157.0},
|
||||
{4.06943698, 2175960.0},
|
||||
{4.07202373, 1895104.0},
|
||||
{4.0733171, 1687576.0},
|
||||
{4.07525716, 1447024.0},
|
||||
{4.0778439, 1130879.0},
|
||||
{4.07978396, 904900.0},
|
||||
{4.08237071, 717104.0},
|
||||
{4.08366408, 620014.0}
|
||||
};
|
||||
/** Poor data: right of peak not symmetric with left of peak. */
|
||||
protected static final double[][] DATASET2 = new double[][] {
|
||||
{-20.15, 1523.0},
|
||||
{-19.65, 1566.0},
|
||||
{-19.15, 1592.0},
|
||||
{-18.65, 1927.0},
|
||||
{-18.15, 3089.0},
|
||||
{-17.65, 6068.0},
|
||||
{-17.15, 14239.0},
|
||||
{-16.65, 34124.0},
|
||||
{-16.15, 64097.0},
|
||||
{-15.65, 110352.0},
|
||||
{-15.15, 164742.0},
|
||||
{-14.65, 209499.0},
|
||||
{-14.15, 267274.0},
|
||||
{-13.65, 283290.0},
|
||||
{-13.15, 275363.0},
|
||||
{-12.65, 258014.0},
|
||||
{-12.15, 225000.0},
|
||||
{-11.65, 200000.0},
|
||||
{-11.15, 190000.0},
|
||||
{-10.65, 185000.0},
|
||||
{-10.15, 180000.0},
|
||||
{ -9.65, 179000.0},
|
||||
{ -9.15, 178000.0},
|
||||
{ -8.65, 177000.0},
|
||||
{ -8.15, 176000.0},
|
||||
{ -7.65, 175000.0},
|
||||
{ -7.15, 174000.0},
|
||||
{ -6.65, 173000.0},
|
||||
{ -6.15, 172000.0},
|
||||
{ -5.65, 171000.0},
|
||||
{ -5.15, 170000.0}
|
||||
};
|
||||
/** Poor data: long tails. */
|
||||
protected static final double[][] DATASET3 = new double[][] {
|
||||
{-90.15, 1513.0},
|
||||
{-80.15, 1514.0},
|
||||
{-70.15, 1513.0},
|
||||
{-60.15, 1514.0},
|
||||
{-50.15, 1513.0},
|
||||
{-40.15, 1514.0},
|
||||
{-30.15, 1513.0},
|
||||
{-20.15, 1523.0},
|
||||
{-19.65, 1566.0},
|
||||
{-19.15, 1592.0},
|
||||
{-18.65, 1927.0},
|
||||
{-18.15, 3089.0},
|
||||
{-17.65, 6068.0},
|
||||
{-17.15, 14239.0},
|
||||
{-16.65, 34124.0},
|
||||
{-16.15, 64097.0},
|
||||
{-15.65, 110352.0},
|
||||
{-15.15, 164742.0},
|
||||
{-14.65, 209499.0},
|
||||
{-14.15, 267274.0},
|
||||
{-13.65, 283290.0},
|
||||
{-13.15, 275363.0},
|
||||
{-12.65, 258014.0},
|
||||
{-12.15, 214073.0},
|
||||
{-11.65, 182244.0},
|
||||
{-11.15, 136419.0},
|
||||
{-10.65, 97823.0},
|
||||
{-10.15, 58930.0},
|
||||
{ -9.65, 35404.0},
|
||||
{ -9.15, 16120.0},
|
||||
{ -8.65, 9823.0},
|
||||
{ -8.15, 5064.0},
|
||||
{ -7.65, 2575.0},
|
||||
{ -7.15, 1642.0},
|
||||
{ -6.65, 1101.0},
|
||||
{ -6.15, 812.0},
|
||||
{ -5.65, 690.0},
|
||||
{ -5.15, 565.0},
|
||||
{ 5.15, 564.0},
|
||||
{ 15.15, 565.0},
|
||||
{ 25.15, 564.0},
|
||||
{ 35.15, 565.0},
|
||||
{ 45.15, 564.0},
|
||||
{ 55.15, 565.0},
|
||||
{ 65.15, 564.0},
|
||||
{ 75.15, 565.0}
|
||||
};
|
||||
/** Poor data: right of peak is missing. */
|
||||
protected static final double[][] DATASET4 = new double[][] {
|
||||
{-20.15, 1523.0},
|
||||
{-19.65, 1566.0},
|
||||
{-19.15, 1592.0},
|
||||
{-18.65, 1927.0},
|
||||
{-18.15, 3089.0},
|
||||
{-17.65, 6068.0},
|
||||
{-17.15, 14239.0},
|
||||
{-16.65, 34124.0},
|
||||
{-16.15, 64097.0},
|
||||
{-15.65, 110352.0},
|
||||
{-15.15, 164742.0},
|
||||
{-14.65, 209499.0},
|
||||
{-14.15, 267274.0},
|
||||
{-13.65, 283290.0}
|
||||
};
|
||||
/** Good data, but few points. */
|
||||
protected static final double[][] DATASET5 = new double[][] {
|
||||
{4.0254623, 531026.0},
|
||||
{4.03128248, 984167.0},
|
||||
{4.03839603, 1887233.0},
|
||||
{4.04421621, 2687152.0},
|
||||
{4.05132976, 3461228.0},
|
||||
{4.05326982, 3580526.0},
|
||||
{4.05779662, 3439750.0},
|
||||
{4.0636168, 2877648.0},
|
||||
{4.06943698, 2175960.0},
|
||||
{4.07525716, 1447024.0},
|
||||
{4.08237071, 717104.0},
|
||||
{4.08366408, 620014.0}
|
||||
};
|
||||
|
||||
/**
|
||||
* Basic.
|
||||
*/
|
||||
@Test
|
||||
public void testFit01() {
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
addDatasetToGaussianFitter(DATASET1, fitter);
|
||||
double[] parameters = fitter.fit();
|
||||
|
||||
Assert.assertEquals(3496978.1837704973, parameters[0], 1e-4);
|
||||
Assert.assertEquals(4.054933085999146, parameters[1], 1e-4);
|
||||
Assert.assertEquals(0.015039355620304326, parameters[2], 1e-4);
|
||||
}
|
||||
|
||||
/**
|
||||
* Zero points is not enough observed points.
|
||||
*/
|
||||
@Test(expected=MathIllegalArgumentException.class)
|
||||
public void testFit02() {
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
fitter.fit();
|
||||
}
|
||||
|
||||
/**
|
||||
* Two points is not enough observed points.
|
||||
*/
|
||||
@Test(expected=MathIllegalArgumentException.class)
|
||||
public void testFit03() {
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
addDatasetToGaussianFitter(new double[][] {
|
||||
{4.0254623, 531026.0},
|
||||
{4.02804905, 664002.0}},
|
||||
fitter);
|
||||
fitter.fit();
|
||||
}
|
||||
|
||||
/**
|
||||
* Poor data: right of peak not symmetric with left of peak.
|
||||
*/
|
||||
@Test
|
||||
public void testFit04() {
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
addDatasetToGaussianFitter(DATASET2, fitter);
|
||||
double[] parameters = fitter.fit();
|
||||
|
||||
Assert.assertEquals(233003.2967252038, parameters[0], 1e-4);
|
||||
Assert.assertEquals(-10.654887521095983, parameters[1], 1e-4);
|
||||
Assert.assertEquals(4.335937353196641, parameters[2], 1e-4);
|
||||
}
|
||||
|
||||
/**
|
||||
* Poor data: long tails.
|
||||
*/
|
||||
@Test
|
||||
public void testFit05() {
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
addDatasetToGaussianFitter(DATASET3, fitter);
|
||||
double[] parameters = fitter.fit();
|
||||
|
||||
Assert.assertEquals(283863.81929180305, parameters[0], 1e-4);
|
||||
Assert.assertEquals(-13.29641995105174, parameters[1], 1e-4);
|
||||
Assert.assertEquals(1.7297330293549908, parameters[2], 1e-4);
|
||||
}
|
||||
|
||||
/**
|
||||
* Poor data: right of peak is missing.
|
||||
*/
|
||||
@Test
|
||||
public void testFit06() {
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
addDatasetToGaussianFitter(DATASET4, fitter);
|
||||
double[] parameters = fitter.fit();
|
||||
|
||||
Assert.assertEquals(285250.66754309234, parameters[0], 1e-4);
|
||||
Assert.assertEquals(-13.528375695228455, parameters[1], 1e-4);
|
||||
Assert.assertEquals(1.5204344894331614, parameters[2], 1e-4);
|
||||
}
|
||||
|
||||
/**
|
||||
* Basic with smaller dataset.
|
||||
*/
|
||||
@Test
|
||||
public void testFit07() {
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
addDatasetToGaussianFitter(DATASET5, fitter);
|
||||
double[] parameters = fitter.fit();
|
||||
|
||||
Assert.assertEquals(3514384.729342235, parameters[0], 1e-4);
|
||||
Assert.assertEquals(4.054970307455625, parameters[1], 1e-4);
|
||||
Assert.assertEquals(0.015029412832160017, parameters[2], 1e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath519() {
|
||||
// The optimizer will try negative sigma values but "GaussianFitter"
|
||||
// will catch the raised exceptions and return NaN values instead.
|
||||
|
||||
final double[] data = {
|
||||
1.1143831578403364E-29,
|
||||
4.95281403484594E-28,
|
||||
1.1171347211930288E-26,
|
||||
1.7044813962636277E-25,
|
||||
1.9784716574832164E-24,
|
||||
1.8630236407866774E-23,
|
||||
1.4820532905097742E-22,
|
||||
1.0241963854632831E-21,
|
||||
6.275077366673128E-21,
|
||||
3.461808994532493E-20,
|
||||
1.7407124684715706E-19,
|
||||
8.056687953553974E-19,
|
||||
3.460193945992071E-18,
|
||||
1.3883326374011525E-17,
|
||||
5.233894983671116E-17,
|
||||
1.8630791465263745E-16,
|
||||
6.288759227922111E-16,
|
||||
2.0204433920597856E-15,
|
||||
6.198768938576155E-15,
|
||||
1.821419346860626E-14,
|
||||
5.139176445538471E-14,
|
||||
1.3956427429045787E-13,
|
||||
3.655705706448139E-13,
|
||||
9.253753324779779E-13,
|
||||
2.267636001476696E-12,
|
||||
5.3880460095836855E-12,
|
||||
1.2431632654852931E-11
|
||||
};
|
||||
|
||||
GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
for (int i = 0; i < data.length; i++) {
|
||||
fitter.addObservedPoint(i, data[i]);
|
||||
}
|
||||
final double[] p = fitter.fit();
|
||||
|
||||
Assert.assertEquals(53.1572792, p[1], 1e-7);
|
||||
Assert.assertEquals(5.75214622, p[2], 1e-8);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath798() {
|
||||
final GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
|
||||
|
||||
// When the data points are not commented out below, the fit stalls.
|
||||
// This is expected however, since the whole dataset hardly looks like
|
||||
// a Gaussian.
|
||||
// When commented out, the fit proceeds fine.
|
||||
|
||||
fitter.addObservedPoint(0.23, 395.0);
|
||||
//fitter.addObservedPoint(0.68, 0.0);
|
||||
fitter.addObservedPoint(1.14, 376.0);
|
||||
//fitter.addObservedPoint(1.59, 0.0);
|
||||
fitter.addObservedPoint(2.05, 163.0);
|
||||
//fitter.addObservedPoint(2.50, 0.0);
|
||||
fitter.addObservedPoint(2.95, 49.0);
|
||||
//fitter.addObservedPoint(3.41, 0.0);
|
||||
fitter.addObservedPoint(3.86, 16.0);
|
||||
//fitter.addObservedPoint(4.32, 0.0);
|
||||
fitter.addObservedPoint(4.77, 1.0);
|
||||
|
||||
final double[] p = fitter.fit();
|
||||
|
||||
// Values are copied from a previous run of this test.
|
||||
Assert.assertEquals(420.8397296167364, p[0], 1e-12);
|
||||
Assert.assertEquals(0.603770729862231, p[1], 1e-15);
|
||||
Assert.assertEquals(1.0786447936766612, p[2], 1e-14);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds the specified points to specified <code>GaussianFitter</code>
|
||||
* instance.
|
||||
*
|
||||
* @param points data points where first dimension is a point index and
|
||||
* second dimension is an array of length two representing the point
|
||||
* with the first value corresponding to X and the second value
|
||||
* corresponding to Y
|
||||
* @param fitter fitter to which the points in <code>points</code> should be
|
||||
* added as observed points
|
||||
*/
|
||||
protected static void addDatasetToGaussianFitter(double[][] points,
|
||||
GaussianFitter fitter) {
|
||||
for (int i = 0; i < points.length; i++) {
|
||||
fitter.addObservedPoint(points[i][0], points[i][1]);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import java.util.Random;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
|
||||
import org.apache.commons.math3.analysis.function.HarmonicOscillator;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.util.MathUtils;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
public class HarmonicFitterTest {
|
||||
@Test(expected=NumberIsTooSmallException.class)
|
||||
public void testPreconditions1() {
|
||||
HarmonicFitter fitter =
|
||||
new HarmonicFitter(new LevenbergMarquardtOptimizer());
|
||||
|
||||
fitter.fit();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoError() {
|
||||
final double a = 0.2;
|
||||
final double w = 3.4;
|
||||
final double p = 4.1;
|
||||
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
|
||||
|
||||
HarmonicFitter fitter =
|
||||
new HarmonicFitter(new LevenbergMarquardtOptimizer());
|
||||
for (double x = 0.0; x < 1.3; x += 0.01) {
|
||||
fitter.addObservedPoint(1, x, f.value(x));
|
||||
}
|
||||
|
||||
final double[] fitted = fitter.fit();
|
||||
Assert.assertEquals(a, fitted[0], 1.0e-13);
|
||||
Assert.assertEquals(w, fitted[1], 1.0e-13);
|
||||
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1e-13);
|
||||
|
||||
HarmonicOscillator ff = new HarmonicOscillator(fitted[0], fitted[1], fitted[2]);
|
||||
|
||||
for (double x = -1.0; x < 1.0; x += 0.01) {
|
||||
Assert.assertTrue(FastMath.abs(f.value(x) - ff.value(x)) < 1e-13);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test1PercentError() {
|
||||
Random randomizer = new Random(64925784252l);
|
||||
final double a = 0.2;
|
||||
final double w = 3.4;
|
||||
final double p = 4.1;
|
||||
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
|
||||
|
||||
HarmonicFitter fitter =
|
||||
new HarmonicFitter(new LevenbergMarquardtOptimizer());
|
||||
for (double x = 0.0; x < 10.0; x += 0.1) {
|
||||
fitter.addObservedPoint(1, x,
|
||||
f.value(x) + 0.01 * randomizer.nextGaussian());
|
||||
}
|
||||
|
||||
final double[] fitted = fitter.fit();
|
||||
Assert.assertEquals(a, fitted[0], 7.6e-4);
|
||||
Assert.assertEquals(w, fitted[1], 2.7e-3);
|
||||
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1.3e-2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTinyVariationsData() {
|
||||
Random randomizer = new Random(64925784252l);
|
||||
|
||||
HarmonicFitter fitter =
|
||||
new HarmonicFitter(new LevenbergMarquardtOptimizer());
|
||||
for (double x = 0.0; x < 10.0; x += 0.1) {
|
||||
fitter.addObservedPoint(1, x, 1e-7 * randomizer.nextGaussian());
|
||||
}
|
||||
|
||||
fitter.fit();
|
||||
// This test serves to cover the part of the code of "guessAOmega"
|
||||
// when the algorithm using integrals fails.
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitialGuess() {
|
||||
Random randomizer = new Random(45314242l);
|
||||
final double a = 0.2;
|
||||
final double w = 3.4;
|
||||
final double p = 4.1;
|
||||
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
|
||||
|
||||
HarmonicFitter fitter =
|
||||
new HarmonicFitter(new LevenbergMarquardtOptimizer());
|
||||
for (double x = 0.0; x < 10.0; x += 0.1) {
|
||||
fitter.addObservedPoint(1, x,
|
||||
f.value(x) + 0.01 * randomizer.nextGaussian());
|
||||
}
|
||||
|
||||
final double[] fitted = fitter.fit(new double[] { 0.15, 3.6, 4.5 });
|
||||
Assert.assertEquals(a, fitted[0], 1.2e-3);
|
||||
Assert.assertEquals(w, fitted[1], 3.3e-3);
|
||||
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1.7e-2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUnsorted() {
|
||||
Random randomizer = new Random(64925784252l);
|
||||
final double a = 0.2;
|
||||
final double w = 3.4;
|
||||
final double p = 4.1;
|
||||
HarmonicOscillator f = new HarmonicOscillator(a, w, p);
|
||||
|
||||
HarmonicFitter fitter =
|
||||
new HarmonicFitter(new LevenbergMarquardtOptimizer());
|
||||
|
||||
// build a regularly spaced array of measurements
|
||||
int size = 100;
|
||||
double[] xTab = new double[size];
|
||||
double[] yTab = new double[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
xTab[i] = 0.1 * i;
|
||||
yTab[i] = f.value(xTab[i]) + 0.01 * randomizer.nextGaussian();
|
||||
}
|
||||
|
||||
// shake it
|
||||
for (int i = 0; i < size; ++i) {
|
||||
int i1 = randomizer.nextInt(size);
|
||||
int i2 = randomizer.nextInt(size);
|
||||
double xTmp = xTab[i1];
|
||||
double yTmp = yTab[i1];
|
||||
xTab[i1] = xTab[i2];
|
||||
yTab[i1] = yTab[i2];
|
||||
xTab[i2] = xTmp;
|
||||
yTab[i2] = yTmp;
|
||||
}
|
||||
|
||||
// pass it to the fitter
|
||||
for (int i = 0; i < size; ++i) {
|
||||
fitter.addObservedPoint(1, xTab[i], yTab[i]);
|
||||
}
|
||||
|
||||
final double[] fitted = fitter.fit();
|
||||
Assert.assertEquals(a, fitted[0], 7.6e-4);
|
||||
Assert.assertEquals(w, fitted[1], 3.5e-3);
|
||||
Assert.assertEquals(p, MathUtils.normalizeAngle(fitted[2], p), 1.5e-2);
|
||||
}
|
||||
|
||||
@Test(expected=MathIllegalStateException.class)
|
||||
public void testMath844() {
|
||||
final double[] y = { 0, 1, 2, 3, 2, 1,
|
||||
0, -1, -2, -3, -2, -1,
|
||||
0, 1, 2, 3, 2, 1,
|
||||
0, -1, -2, -3, -2, -1,
|
||||
0, 1, 2, 3, 2, 1, 0 };
|
||||
final int len = y.length;
|
||||
final WeightedObservedPoint[] points = new WeightedObservedPoint[len];
|
||||
for (int i = 0; i < len; i++) {
|
||||
points[i] = new WeightedObservedPoint(1, i, y[i]);
|
||||
}
|
||||
|
||||
// The guesser fails because the function is far from an harmonic
|
||||
// function: It is a triangular periodic function with amplitude 3
|
||||
// and period 12, and all sample points are taken at integer abscissae
|
||||
// so function values all belong to the integer subset {-3, -2, -1, 0,
|
||||
// 1, 2, 3}.
|
||||
final HarmonicFitter.ParameterGuesser guesser
|
||||
= new HarmonicFitter.ParameterGuesser(points);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,256 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.fitting;
|
||||
|
||||
import java.util.Random;
|
||||
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
|
||||
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction.Parametric;
|
||||
import org.apache.commons.math3.exception.ConvergenceException;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
|
||||
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.GaussNewtonOptimizer;
|
||||
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformRealDistribution;
|
||||
import org.apache.commons.math3.TestUtils;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Test for class {@link CurveFitter} where the function to fit is a
|
||||
* polynomial.
|
||||
*/
|
||||
public class PolynomialFitterTest {
|
||||
@Test
|
||||
public void testFit() {
|
||||
final RealDistribution rng = new UniformRealDistribution(-100, 100);
|
||||
rng.reseedRandomGenerator(64925784252L);
|
||||
|
||||
final LevenbergMarquardtOptimizer optim = new LevenbergMarquardtOptimizer();
|
||||
final PolynomialFitter fitter = new PolynomialFitter(optim);
|
||||
final double[] coeff = { 12.9, -3.4, 2.1 }; // 12.9 - 3.4 x + 2.1 x^2
|
||||
final PolynomialFunction f = new PolynomialFunction(coeff);
|
||||
|
||||
// Collect data from a known polynomial.
|
||||
for (int i = 0; i < 100; i++) {
|
||||
final double x = rng.sample();
|
||||
fitter.addObservedPoint(x, f.value(x));
|
||||
}
|
||||
|
||||
// Start fit from initial guesses that are far from the optimal values.
|
||||
final double[] best = fitter.fit(new double[] { -1e-20, 3e15, -5e25 });
|
||||
|
||||
TestUtils.assertEquals("best != coeff", coeff, best, 1e-12);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoError() {
|
||||
Random randomizer = new Random(64925784252l);
|
||||
for (int degree = 1; degree < 10; ++degree) {
|
||||
PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
|
||||
|
||||
PolynomialFitter fitter = new PolynomialFitter(new LevenbergMarquardtOptimizer());
|
||||
for (int i = 0; i <= degree; ++i) {
|
||||
fitter.addObservedPoint(1.0, i, p.value(i));
|
||||
}
|
||||
|
||||
final double[] init = new double[degree + 1];
|
||||
PolynomialFunction fitted = new PolynomialFunction(fitter.fit(init));
|
||||
|
||||
for (double x = -1.0; x < 1.0; x += 0.01) {
|
||||
double error = FastMath.abs(p.value(x) - fitted.value(x)) /
|
||||
(1.0 + FastMath.abs(p.value(x)));
|
||||
Assert.assertEquals(0.0, error, 1.0e-6);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSmallError() {
|
||||
Random randomizer = new Random(53882150042l);
|
||||
double maxError = 0;
|
||||
for (int degree = 0; degree < 10; ++degree) {
|
||||
PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
|
||||
|
||||
PolynomialFitter fitter = new PolynomialFitter(new LevenbergMarquardtOptimizer());
|
||||
for (double x = -1.0; x < 1.0; x += 0.01) {
|
||||
fitter.addObservedPoint(1.0, x,
|
||||
p.value(x) + 0.1 * randomizer.nextGaussian());
|
||||
}
|
||||
|
||||
final double[] init = new double[degree + 1];
|
||||
PolynomialFunction fitted = new PolynomialFunction(fitter.fit(init));
|
||||
|
||||
for (double x = -1.0; x < 1.0; x += 0.01) {
|
||||
double error = FastMath.abs(p.value(x) - fitted.value(x)) /
|
||||
(1.0 + FastMath.abs(p.value(x)));
|
||||
maxError = FastMath.max(maxError, error);
|
||||
Assert.assertTrue(FastMath.abs(error) < 0.1);
|
||||
}
|
||||
}
|
||||
Assert.assertTrue(maxError > 0.01);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath798() {
|
||||
final double tol = 1e-14;
|
||||
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol);
|
||||
final double[] init = new double[] { 0, 0 };
|
||||
final int maxEval = 3;
|
||||
|
||||
final double[] lm = doMath798(new LevenbergMarquardtOptimizer(checker), maxEval, init);
|
||||
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
|
||||
|
||||
for (int i = 0; i <= 1; i++) {
|
||||
Assert.assertEquals(lm[i], gn[i], tol);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This test shows that the user can set the maximum number of iterations
|
||||
* to avoid running for too long.
|
||||
* But in the test case, the real problem is that the tolerance is way too
|
||||
* stringent.
|
||||
*/
|
||||
@Test(expected=TooManyEvaluationsException.class)
|
||||
public void testMath798WithToleranceTooLow() {
|
||||
final double tol = 1e-100;
|
||||
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol);
|
||||
final double[] init = new double[] { 0, 0 };
|
||||
final int maxEval = 10000; // Trying hard to fit.
|
||||
|
||||
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
|
||||
}
|
||||
|
||||
/**
|
||||
* This test shows that the user can set the maximum number of iterations
|
||||
* to avoid running for too long.
|
||||
* Even if the real problem is that the tolerance is way too stringent, it
|
||||
* is possible to get the best solution so far, i.e. a checker will return
|
||||
* the point when the maximum iteration count has been reached.
|
||||
*/
|
||||
@Test
|
||||
public void testMath798WithToleranceTooLowButNoException() {
|
||||
final double tol = 1e-100;
|
||||
final double[] init = new double[] { 0, 0 };
|
||||
final int maxEval = 10000; // Trying hard to fit.
|
||||
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol, maxEval);
|
||||
|
||||
final double[] lm = doMath798(new LevenbergMarquardtOptimizer(checker), maxEval, init);
|
||||
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
|
||||
|
||||
for (int i = 0; i <= 1; i++) {
|
||||
Assert.assertEquals(lm[i], gn[i], 1e-15);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param optimizer Optimizer.
|
||||
* @param maxEval Maximum number of function evaluations.
|
||||
* @param init First guess.
|
||||
* @return the solution found by the given optimizer.
|
||||
*/
|
||||
private double[] doMath798(MultivariateVectorOptimizer optimizer,
|
||||
int maxEval,
|
||||
double[] init) {
|
||||
final CurveFitter<Parametric> fitter = new CurveFitter<Parametric>(optimizer);
|
||||
|
||||
fitter.addObservedPoint(-0.2, -7.12442E-13);
|
||||
fitter.addObservedPoint(-0.199, -4.33397E-13);
|
||||
fitter.addObservedPoint(-0.198, -2.823E-13);
|
||||
fitter.addObservedPoint(-0.197, -1.40405E-13);
|
||||
fitter.addObservedPoint(-0.196, -7.80821E-15);
|
||||
fitter.addObservedPoint(-0.195, 6.20484E-14);
|
||||
fitter.addObservedPoint(-0.194, 7.24673E-14);
|
||||
fitter.addObservedPoint(-0.193, 1.47152E-13);
|
||||
fitter.addObservedPoint(-0.192, 1.9629E-13);
|
||||
fitter.addObservedPoint(-0.191, 2.12038E-13);
|
||||
fitter.addObservedPoint(-0.19, 2.46906E-13);
|
||||
fitter.addObservedPoint(-0.189, 2.77495E-13);
|
||||
fitter.addObservedPoint(-0.188, 2.51281E-13);
|
||||
fitter.addObservedPoint(-0.187, 2.64001E-13);
|
||||
fitter.addObservedPoint(-0.186, 2.8882E-13);
|
||||
fitter.addObservedPoint(-0.185, 3.13604E-13);
|
||||
fitter.addObservedPoint(-0.184, 3.14248E-13);
|
||||
fitter.addObservedPoint(-0.183, 3.1172E-13);
|
||||
fitter.addObservedPoint(-0.182, 3.12912E-13);
|
||||
fitter.addObservedPoint(-0.181, 3.06761E-13);
|
||||
fitter.addObservedPoint(-0.18, 2.8559E-13);
|
||||
fitter.addObservedPoint(-0.179, 2.86806E-13);
|
||||
fitter.addObservedPoint(-0.178, 2.985E-13);
|
||||
fitter.addObservedPoint(-0.177, 2.67148E-13);
|
||||
fitter.addObservedPoint(-0.176, 2.94173E-13);
|
||||
fitter.addObservedPoint(-0.175, 3.27528E-13);
|
||||
fitter.addObservedPoint(-0.174, 3.33858E-13);
|
||||
fitter.addObservedPoint(-0.173, 2.97511E-13);
|
||||
fitter.addObservedPoint(-0.172, 2.8615E-13);
|
||||
fitter.addObservedPoint(-0.171, 2.84624E-13);
|
||||
|
||||
final double[] coeff = fitter.fit(maxEval,
|
||||
new PolynomialFunction.Parametric(),
|
||||
init);
|
||||
return coeff;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRedundantSolvable() {
|
||||
// Levenberg-Marquardt should handle redundant information gracefully
|
||||
checkUnsolvableProblem(new LevenbergMarquardtOptimizer(), true);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRedundantUnsolvable() {
|
||||
// Gauss-Newton should not be able to solve redundant information
|
||||
checkUnsolvableProblem(new GaussNewtonOptimizer(true, new SimpleVectorValueChecker(1e-15, 1e-15)), false);
|
||||
}
|
||||
|
||||
private void checkUnsolvableProblem(MultivariateVectorOptimizer optimizer,
|
||||
boolean solvable) {
|
||||
Random randomizer = new Random(1248788532l);
|
||||
for (int degree = 0; degree < 10; ++degree) {
|
||||
PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
|
||||
|
||||
PolynomialFitter fitter = new PolynomialFitter(optimizer);
|
||||
|
||||
// reusing the same point over and over again does not bring
|
||||
// information, the problem cannot be solved in this case for
|
||||
// degrees greater than 1 (but one point is sufficient for
|
||||
// degree 0)
|
||||
for (double x = -1.0; x < 1.0; x += 0.01) {
|
||||
fitter.addObservedPoint(1.0, 0.0, p.value(0.0));
|
||||
}
|
||||
|
||||
try {
|
||||
final double[] init = new double[degree + 1];
|
||||
fitter.fit(init);
|
||||
Assert.assertTrue(solvable || (degree == 0));
|
||||
} catch(ConvergenceException e) {
|
||||
Assert.assertTrue((! solvable) && (degree > 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private PolynomialFunction buildRandomPolynomial(int degree, Random randomizer) {
|
||||
final double[] coefficients = new double[degree + 1];
|
||||
for (int i = 0; i <= degree; ++i) {
|
||||
coefficients[i] = randomizer.nextGaussian();
|
||||
}
|
||||
return new PolynomialFunction(coefficients);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.TestUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class PointValuePairTest {
|
||||
@Test
|
||||
public void testSerial() {
|
||||
PointValuePair pv1 = new PointValuePair(new double[] { 1.0, 2.0, 3.0 }, 4.0);
|
||||
PointValuePair pv2 = (PointValuePair) TestUtils.serializeAndRecover(pv1);
|
||||
Assert.assertEquals(pv1.getKey().length, pv2.getKey().length);
|
||||
for (int i = 0; i < pv1.getKey().length; ++i) {
|
||||
Assert.assertEquals(pv1.getKey()[i], pv2.getKey()[i], 1.0e-15);
|
||||
}
|
||||
Assert.assertEquals(pv1.getValue(), pv2.getValue(), 1.0e-15);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.TestUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class PointVectorValuePairTest {
|
||||
@Test
|
||||
public void testSerial() {
|
||||
PointVectorValuePair pv1 = new PointVectorValuePair(new double[] { 1.0, 2.0, 3.0 },
|
||||
new double[] { 4.0, 5.0 });
|
||||
PointVectorValuePair pv2 = (PointVectorValuePair) TestUtils.serializeAndRecover(pv1);
|
||||
Assert.assertEquals(pv1.getKey().length, pv2.getKey().length);
|
||||
for (int i = 0; i < pv1.getKey().length; ++i) {
|
||||
Assert.assertEquals(pv1.getKey()[i], pv2.getKey()[i], 1.0e-15);
|
||||
}
|
||||
Assert.assertEquals(pv1.getValue().length, pv2.getValue().length);
|
||||
for (int i = 0; i < pv1.getValue().length; ++i) {
|
||||
Assert.assertEquals(pv1.getValue()[i], pv2.getValue()[i], 1.0e-15);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
public class SimplePointCheckerTest {
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testIterationCheckPrecondition() {
|
||||
new SimplePointChecker<PointValuePair>(1e-1, 1e-2, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterationCheck() {
|
||||
final int max = 10;
|
||||
final SimplePointChecker<PointValuePair> checker
|
||||
= new SimplePointChecker<PointValuePair>(1e-1, 1e-2, max);
|
||||
Assert.assertTrue(checker.converged(max, null, null));
|
||||
Assert.assertTrue(checker.converged(max + 1, null, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterationCheckDisabled() {
|
||||
final SimplePointChecker<PointValuePair> checker
|
||||
= new SimplePointChecker<PointValuePair>(1e-8, 1e-8);
|
||||
|
||||
final PointValuePair a = new PointValuePair(new double[] { 1d }, 1d);
|
||||
final PointValuePair b = new PointValuePair(new double[] { 10d }, 10d);
|
||||
|
||||
Assert.assertFalse(checker.converged(-1, a, b));
|
||||
Assert.assertFalse(checker.converged(0, a, b));
|
||||
Assert.assertFalse(checker.converged(1000000, a, b));
|
||||
|
||||
Assert.assertTrue(checker.converged(-1, a, a));
|
||||
Assert.assertTrue(checker.converged(-1, b, b));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
public class SimpleValueCheckerTest {
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testIterationCheckPrecondition() {
|
||||
new SimpleValueChecker(1e-1, 1e-2, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterationCheck() {
|
||||
final int max = 10;
|
||||
final SimpleValueChecker checker = new SimpleValueChecker(1e-1, 1e-2, max);
|
||||
Assert.assertTrue(checker.converged(max, null, null));
|
||||
Assert.assertTrue(checker.converged(max + 1, null, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterationCheckDisabled() {
|
||||
final SimpleValueChecker checker = new SimpleValueChecker(1e-8, 1e-8);
|
||||
|
||||
final PointValuePair a = new PointValuePair(new double[] { 1d }, 1d);
|
||||
final PointValuePair b = new PointValuePair(new double[] { 10d }, 10d);
|
||||
|
||||
Assert.assertFalse(checker.converged(-1, a, b));
|
||||
Assert.assertFalse(checker.converged(0, a, b));
|
||||
Assert.assertFalse(checker.converged(1000000, a, b));
|
||||
|
||||
Assert.assertTrue(checker.converged(-1, a, a));
|
||||
Assert.assertTrue(checker.converged(-1, b, b));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
public class SimpleVectorValueCheckerTest {
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testIterationCheckPrecondition() {
|
||||
new SimpleVectorValueChecker(1e-1, 1e-2, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterationCheck() {
|
||||
final int max = 10;
|
||||
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(1e-1, 1e-2, max);
|
||||
Assert.assertTrue(checker.converged(max, null, null));
|
||||
Assert.assertTrue(checker.converged(max + 1, null, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterationCheckDisabled() {
|
||||
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(1e-8, 1e-8);
|
||||
|
||||
final PointVectorValuePair a = new PointVectorValuePair(new double[] { 1d },
|
||||
new double[] { 1d });
|
||||
final PointVectorValuePair b = new PointVectorValuePair(new double[] { 10d },
|
||||
new double[] { 10d });
|
||||
|
||||
Assert.assertFalse(checker.converged(-1, a, b));
|
||||
Assert.assertFalse(checker.converged(0, a, b));
|
||||
Assert.assertFalse(checker.converged(1000000, a, b));
|
||||
|
||||
Assert.assertTrue(checker.converged(-1, a, a));
|
||||
Assert.assertTrue(checker.converged(-1, b, b));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,664 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import org.apache.commons.math3.optim.MaxIter;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
public class SimplexSolverTest {
|
||||
private static final MaxIter DEFAULT_MAX_ITER = new MaxIter(100);
|
||||
|
||||
@Test
|
||||
public void testMath828() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(
|
||||
new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 0.0);
|
||||
|
||||
ArrayList <LinearConstraint>constraints = new ArrayList<LinearConstraint>();
|
||||
|
||||
constraints.add(new LinearConstraint(new double[] {0.0, 39.0, 23.0, 96.0, 15.0, 48.0, 9.0, 21.0, 48.0, 36.0, 76.0, 19.0, 88.0, 17.0, 16.0, 36.0,}, Relationship.GEQ, 15.0));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0, 59.0, 93.0, 12.0, 29.0, 78.0, 73.0, 87.0, 32.0, 70.0, 68.0, 24.0, 11.0, 26.0, 65.0, 25.0,}, Relationship.GEQ, 29.0));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0, 74.0, 5.0, 82.0, 6.0, 97.0, 55.0, 44.0, 52.0, 54.0, 5.0, 93.0, 91.0, 8.0, 20.0, 97.0,}, Relationship.GEQ, 6.0));
|
||||
constraints.add(new LinearConstraint(new double[] {8.0, -3.0, -28.0, -72.0, -8.0, -31.0, -31.0, -74.0, -47.0, -59.0, -24.0, -57.0, -56.0, -16.0, -92.0, -59.0,}, Relationship.GEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] {25.0, -7.0, -99.0, -78.0, -25.0, -14.0, -16.0, -89.0, -39.0, -56.0, -53.0, -9.0, -18.0, -26.0, -11.0, -61.0,}, Relationship.GEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] {33.0, -95.0, -15.0, -4.0, -33.0, -3.0, -20.0, -96.0, -27.0, -13.0, -80.0, -24.0, -3.0, -13.0, -57.0, -76.0,}, Relationship.GEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] {7.0, -95.0, -39.0, -93.0, -7.0, -94.0, -94.0, -62.0, -76.0, -26.0, -53.0, -57.0, -31.0, -76.0, -53.0, -52.0,}, Relationship.GEQ, 0.0));
|
||||
|
||||
double epsilon = 1e-6;
|
||||
PointValuePair solution = new SimplexSolver().optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(1.0d, solution.getValue(), epsilon);
|
||||
Assert.assertTrue(validSolution(solution, constraints, epsilon));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath828Cycle() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(
|
||||
new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 0.0);
|
||||
|
||||
ArrayList <LinearConstraint>constraints = new ArrayList<LinearConstraint>();
|
||||
|
||||
constraints.add(new LinearConstraint(new double[] {0.0, 16.0, 14.0, 69.0, 1.0, 85.0, 52.0, 43.0, 64.0, 97.0, 14.0, 74.0, 89.0, 28.0, 94.0, 58.0, 13.0, 22.0, 21.0, 17.0, 30.0, 25.0, 1.0, 59.0, 91.0, 78.0, 12.0, 74.0, 56.0, 3.0, 88.0,}, Relationship.GEQ, 91.0));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0, 60.0, 40.0, 81.0, 71.0, 72.0, 46.0, 45.0, 38.0, 48.0, 40.0, 17.0, 33.0, 85.0, 64.0, 32.0, 84.0, 3.0, 54.0, 44.0, 71.0, 67.0, 90.0, 95.0, 54.0, 99.0, 99.0, 29.0, 52.0, 98.0, 9.0,}, Relationship.GEQ, 54.0));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0, 41.0, 12.0, 86.0, 90.0, 61.0, 31.0, 41.0, 23.0, 89.0, 17.0, 74.0, 44.0, 27.0, 16.0, 47.0, 80.0, 32.0, 11.0, 56.0, 68.0, 82.0, 11.0, 62.0, 62.0, 53.0, 39.0, 16.0, 48.0, 1.0, 63.0,}, Relationship.GEQ, 62.0));
|
||||
constraints.add(new LinearConstraint(new double[] {83.0, -76.0, -94.0, -19.0, -15.0, -70.0, -72.0, -57.0, -63.0, -65.0, -22.0, -94.0, -22.0, -88.0, -86.0, -89.0, -72.0, -16.0, -80.0, -49.0, -70.0, -93.0, -95.0, -17.0, -83.0, -97.0, -31.0, -47.0, -31.0, -13.0, -23.0,}, Relationship.GEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] {41.0, -96.0, -41.0, -48.0, -70.0, -43.0, -43.0, -43.0, -97.0, -37.0, -85.0, -70.0, -45.0, -67.0, -87.0, -69.0, -94.0, -54.0, -54.0, -92.0, -79.0, -10.0, -35.0, -20.0, -41.0, -41.0, -65.0, -25.0, -12.0, -8.0, -46.0,}, Relationship.GEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] {27.0, -42.0, -65.0, -49.0, -53.0, -42.0, -17.0, -2.0, -61.0, -31.0, -76.0, -47.0, -8.0, -93.0, -86.0, -62.0, -65.0, -63.0, -22.0, -43.0, -27.0, -23.0, -32.0, -74.0, -27.0, -63.0, -47.0, -78.0, -29.0, -95.0, -73.0,}, Relationship.GEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] {15.0, -46.0, -41.0, -83.0, -98.0, -99.0, -21.0, -35.0, -7.0, -14.0, -80.0, -63.0, -18.0, -42.0, -5.0, -34.0, -56.0, -70.0, -16.0, -18.0, -74.0, -61.0, -47.0, -41.0, -15.0, -79.0, -18.0, -47.0, -88.0, -68.0, -55.0,}, Relationship.GEQ, 0.0));
|
||||
|
||||
double epsilon = 1e-6;
|
||||
PointValuePair solution = new SimplexSolver().optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(1.0d, solution.getValue(), epsilon);
|
||||
Assert.assertTrue(validSolution(solution, constraints, epsilon));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath781() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 2, 6, 7 }, 0);
|
||||
|
||||
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 2, 1 }, Relationship.LEQ, 2));
|
||||
constraints.add(new LinearConstraint(new double[] { -1, 1, 1 }, Relationship.LEQ, -1));
|
||||
constraints.add(new LinearConstraint(new double[] { 2, -3, 1 }, Relationship.LEQ, -1));
|
||||
|
||||
double epsilon = 1e-6;
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
|
||||
|
||||
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0], 0.0d, epsilon) > 0);
|
||||
Assert.assertTrue(Precision.compareTo(solution.getPoint()[1], 0.0d, epsilon) > 0);
|
||||
Assert.assertTrue(Precision.compareTo(solution.getPoint()[2], 0.0d, epsilon) < 0);
|
||||
Assert.assertEquals(2.0d, solution.getValue(), epsilon);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath713NegativeVariable() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {1.0, 1.0}, 0.0d);
|
||||
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] {1, 0}, Relationship.EQ, 1));
|
||||
|
||||
double epsilon = 1e-6;
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
|
||||
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0], 0.0d, epsilon) >= 0);
|
||||
Assert.assertTrue(Precision.compareTo(solution.getPoint()[1], 0.0d, epsilon) >= 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath434NegativeVariable() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {0.0, 0.0, 1.0}, 0.0d);
|
||||
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] {1, 1, 0}, Relationship.EQ, 5));
|
||||
constraints.add(new LinearConstraint(new double[] {0, 0, 1}, Relationship.GEQ, -10));
|
||||
|
||||
double epsilon = 1e-6;
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(false));
|
||||
|
||||
Assert.assertEquals(5.0, solution.getPoint()[0] + solution.getPoint()[1], epsilon);
|
||||
Assert.assertEquals(-10.0, solution.getPoint()[2], epsilon);
|
||||
Assert.assertEquals(-10.0, solution.getValue(), epsilon);
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = NoFeasibleSolutionException.class)
|
||||
public void testMath434UnfeasibleSolution() {
|
||||
double epsilon = 1e-6;
|
||||
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {1.0, 0.0}, 0.0);
|
||||
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] {epsilon/2, 0.5}, Relationship.EQ, 0));
|
||||
constraints.add(new LinearConstraint(new double[] {1e-3, 0.1}, Relationship.EQ, 10));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
// allowing only non-negative values, no feasible solution shall be found
|
||||
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath434PivotRowSelection() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {1.0}, 0.0);
|
||||
|
||||
double epsilon = 1e-6;
|
||||
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] {200}, Relationship.GEQ, 1));
|
||||
constraints.add(new LinearConstraint(new double[] {100}, Relationship.GEQ, 0.499900001));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(false));
|
||||
|
||||
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0] * 200.d, 1.d, epsilon) >= 0);
|
||||
Assert.assertEquals(0.0050, solution.getValue(), epsilon);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath434PivotRowSelection2() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {0.0d, 1.0d, 1.0d, 0.0d, 0.0d, 0.0d, 0.0d}, 0.0d);
|
||||
|
||||
ArrayList<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] {1.0d, -0.1d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d}, Relationship.EQ, -0.1d));
|
||||
constraints.add(new LinearConstraint(new double[] {1.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, -1e-18d));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0d, 1.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 0.0d, 1.0d, 0.0d, -0.0128588d, 1e-5d}, Relationship.EQ, 0.0d));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 0.0d, 0.0d, 1.0d, 1e-5d, -0.0128586d}, Relationship.EQ, 1e-10d));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, -1.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, 1.0d, 0.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, 0.0d, -1.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
|
||||
constraints.add(new LinearConstraint(new double[] {0.0d, 0.0d, 1.0d, 0.0d, 1.0d, 0.0d, 0.0d}, Relationship.GEQ, 0.0d));
|
||||
|
||||
double epsilon = 1e-7;
|
||||
SimplexSolver simplex = new SimplexSolver();
|
||||
PointValuePair solution = simplex.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(false));
|
||||
|
||||
Assert.assertTrue(Precision.compareTo(solution.getPoint()[0], -1e-18d, epsilon) >= 0);
|
||||
Assert.assertEquals(1.0d, solution.getPoint()[1], epsilon);
|
||||
Assert.assertEquals(0.0d, solution.getPoint()[2], epsilon);
|
||||
Assert.assertEquals(1.0d, solution.getValue(), epsilon);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath272() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 2, 2, 1 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 1, 0 }, Relationship.GEQ, 1));
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0, 1 }, Relationship.GEQ, 1));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1, 0 }, Relationship.GEQ, 1));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
|
||||
Assert.assertEquals(0.0, solution.getPoint()[0], .0000001);
|
||||
Assert.assertEquals(1.0, solution.getPoint()[1], .0000001);
|
||||
Assert.assertEquals(1.0, solution.getPoint()[2], .0000001);
|
||||
Assert.assertEquals(3.0, solution.getValue(), .0000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath286() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 0.8, 0.2, 0.7, 0.3, 0.6, 0.4 }, 0 );
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0, 1, 0, 1, 0 }, Relationship.EQ, 23.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 1, 0, 1 }, Relationship.EQ, 23.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0, 0, 0, 0, 0 }, Relationship.GEQ, 10.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 0, 1, 0, 0, 0 }, Relationship.GEQ, 8.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 0, 0, 0, 1, 0 }, Relationship.GEQ, 5.0));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
|
||||
|
||||
Assert.assertEquals(25.8, solution.getValue(), .0000001);
|
||||
Assert.assertEquals(23.0, solution.getPoint()[0] + solution.getPoint()[2] + solution.getPoint()[4], 0.0000001);
|
||||
Assert.assertEquals(23.0, solution.getPoint()[1] + solution.getPoint()[3] + solution.getPoint()[5], 0.0000001);
|
||||
Assert.assertTrue(solution.getPoint()[0] >= 10.0 - 0.0000001);
|
||||
Assert.assertTrue(solution.getPoint()[2] >= 8.0 - 0.0000001);
|
||||
Assert.assertTrue(solution.getPoint()[4] >= 5.0 - 0.0000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDegeneracy() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 0.8, 0.7 }, 0 );
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.LEQ, 18.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.GEQ, 10.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.GEQ, 8.0));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(13.6, solution.getValue(), .0000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath288() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 7, 3, 0, 0 }, 0 );
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 3, 0, -5, 0 }, Relationship.LEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 2, 0, 0, -5 }, Relationship.LEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 3, 0, -5 }, Relationship.LEQ, 0.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0, 0, 0 }, Relationship.LEQ, 1.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 0 }, Relationship.LEQ, 1.0));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(10.0, solution.getValue(), .0000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath290GEQ() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 1, 5 }, 0 );
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 2, 0 }, Relationship.GEQ, -1.0));
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(0, solution.getValue(), .0000001);
|
||||
Assert.assertEquals(0, solution.getPoint()[0], .0000001);
|
||||
Assert.assertEquals(0, solution.getPoint()[1], .0000001);
|
||||
}
|
||||
|
||||
@Test(expected=NoFeasibleSolutionException.class)
|
||||
public void testMath290LEQ() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 1, 5 }, 0 );
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 2, 0 }, Relationship.LEQ, -1.0));
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath293() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 0.8, 0.2, 0.7, 0.3, 0.4, 0.6}, 0 );
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0, 1, 0, 1, 0 }, Relationship.EQ, 30.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 1, 0, 1 }, Relationship.EQ, 30.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0.8, 0.2, 0.0, 0.0, 0.0, 0.0 }, Relationship.GEQ, 10.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.7, 0.3, 0.0, 0.0 }, Relationship.GEQ, 10.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.0, 0.0, 0.4, 0.6 }, Relationship.GEQ, 10.0));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution1 = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
|
||||
|
||||
Assert.assertEquals(15.7143, solution1.getPoint()[0], .0001);
|
||||
Assert.assertEquals(0.0, solution1.getPoint()[1], .0001);
|
||||
Assert.assertEquals(14.2857, solution1.getPoint()[2], .0001);
|
||||
Assert.assertEquals(0.0, solution1.getPoint()[3], .0001);
|
||||
Assert.assertEquals(0.0, solution1.getPoint()[4], .0001);
|
||||
Assert.assertEquals(30.0, solution1.getPoint()[5], .0001);
|
||||
Assert.assertEquals(40.57143, solution1.getValue(), .0001);
|
||||
|
||||
double valA = 0.8 * solution1.getPoint()[0] + 0.2 * solution1.getPoint()[1];
|
||||
double valB = 0.7 * solution1.getPoint()[2] + 0.3 * solution1.getPoint()[3];
|
||||
double valC = 0.4 * solution1.getPoint()[4] + 0.6 * solution1.getPoint()[5];
|
||||
|
||||
f = new LinearObjectiveFunction(new double[] { 0.8, 0.2, 0.7, 0.3, 0.4, 0.6}, 0 );
|
||||
constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0, 1, 0, 1, 0 }, Relationship.EQ, 30.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1, 0, 1, 0, 1 }, Relationship.EQ, 30.0));
|
||||
constraints.add(new LinearConstraint(new double[] { 0.8, 0.2, 0.0, 0.0, 0.0, 0.0 }, Relationship.GEQ, valA));
|
||||
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.7, 0.3, 0.0, 0.0 }, Relationship.GEQ, valB));
|
||||
constraints.add(new LinearConstraint(new double[] { 0.0, 0.0, 0.0, 0.0, 0.4, 0.6 }, Relationship.GEQ, valC));
|
||||
|
||||
PointValuePair solution2 = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(40.57143, solution2.getValue(), .0001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimplexSolver() {
|
||||
LinearObjectiveFunction f =
|
||||
new LinearObjectiveFunction(new double[] { 15, 10 }, 7);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.LEQ, 2));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.LEQ, 3));
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.EQ, 4));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(2.0, solution.getPoint()[0], 0.0);
|
||||
Assert.assertEquals(2.0, solution.getPoint()[1], 0.0);
|
||||
Assert.assertEquals(57.0, solution.getValue(), 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSingleVariableAndConstraint() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 3 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1 }, Relationship.LEQ, 10));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
|
||||
Assert.assertEquals(10.0, solution.getPoint()[0], 0.0);
|
||||
Assert.assertEquals(30.0, solution.getValue(), 0.0);
|
||||
}
|
||||
|
||||
/**
|
||||
* With no artificial variables needed (no equals and no greater than
|
||||
* constraints) we can go straight to Phase 2.
|
||||
*/
|
||||
@Test
|
||||
public void testModelWithNoArtificialVars() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 15, 10 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.LEQ, 2));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.LEQ, 3));
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.LEQ, 4));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
|
||||
Assert.assertEquals(2.0, solution.getPoint()[0], 0.0);
|
||||
Assert.assertEquals(2.0, solution.getPoint()[1], 0.0);
|
||||
Assert.assertEquals(50.0, solution.getValue(), 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMinimization() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { -2, 1 }, -5);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 2 }, Relationship.LEQ, 6));
|
||||
constraints.add(new LinearConstraint(new double[] { 3, 2 }, Relationship.LEQ, 12));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 1 }, Relationship.GEQ, 0));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(false));
|
||||
Assert.assertEquals(4.0, solution.getPoint()[0], 0.0);
|
||||
Assert.assertEquals(0.0, solution.getPoint()[1], 0.0);
|
||||
Assert.assertEquals(-13.0, solution.getValue(), 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSolutionWithNegativeDecisionVariable() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { -2, 1 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.GEQ, 6));
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 2 }, Relationship.LEQ, 14));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
|
||||
Assert.assertEquals(-2.0, solution.getPoint()[0], 0.0);
|
||||
Assert.assertEquals(8.0, solution.getPoint()[1], 0.0);
|
||||
Assert.assertEquals(12.0, solution.getValue(), 0.0);
|
||||
}
|
||||
|
||||
@Test(expected = NoFeasibleSolutionException.class)
|
||||
public void testInfeasibleSolution() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 15 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1 }, Relationship.LEQ, 1));
|
||||
constraints.add(new LinearConstraint(new double[] { 1 }, Relationship.GEQ, 3));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
|
||||
}
|
||||
|
||||
@Test(expected = UnboundedSolutionException.class)
|
||||
public void testUnboundedSolution() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 15, 10 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 0 }, Relationship.EQ, 2));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRestrictVariablesToNonNegative() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 409, 523, 70, 204, 339 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 43, 56, 345, 56, 5 }, Relationship.LEQ, 4567456));
|
||||
constraints.add(new LinearConstraint(new double[] { 12, 45, 7, 56, 23 }, Relationship.LEQ, 56454));
|
||||
constraints.add(new LinearConstraint(new double[] { 8, 768, 0, 34, 7456 }, Relationship.LEQ, 1923421));
|
||||
constraints.add(new LinearConstraint(new double[] { 12342, 2342, 34, 678, 2342 }, Relationship.GEQ, 4356));
|
||||
constraints.add(new LinearConstraint(new double[] { 45, 678, 76, 52, 23 }, Relationship.EQ, 456356));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(2902.92783505155, solution.getPoint()[0], .0000001);
|
||||
Assert.assertEquals(480.419243986254, solution.getPoint()[1], .0000001);
|
||||
Assert.assertEquals(0.0, solution.getPoint()[2], .0000001);
|
||||
Assert.assertEquals(0.0, solution.getPoint()[3], .0000001);
|
||||
Assert.assertEquals(0.0, solution.getPoint()[4], .0000001);
|
||||
Assert.assertEquals(1438556.7491409, solution.getValue(), .0000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEpsilon() {
|
||||
LinearObjectiveFunction f =
|
||||
new LinearObjectiveFunction(new double[] { 10, 5, 1 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 9, 8, 0 }, Relationship.EQ, 17));
|
||||
constraints.add(new LinearConstraint(new double[] { 0, 7, 8 }, Relationship.LEQ, 7));
|
||||
constraints.add(new LinearConstraint(new double[] { 10, 0, 2 }, Relationship.LEQ, 10));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(false));
|
||||
Assert.assertEquals(1.0, solution.getPoint()[0], 0.0);
|
||||
Assert.assertEquals(1.0, solution.getPoint()[1], 0.0);
|
||||
Assert.assertEquals(0.0, solution.getPoint()[2], 0.0);
|
||||
Assert.assertEquals(15.0, solution.getValue(), 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrivialModel() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] { 1, 1 }, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] { 1, 1 }, Relationship.EQ, 0));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MAXIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(0, solution.getValue(), .0000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLargeModel() {
|
||||
double[] objective = new double[] {
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 12, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
12, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 12, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 12, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 12, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 12, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1};
|
||||
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(objective, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(equationFromString(objective.length, "x0 + x1 + x2 + x3 - x12 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 - x13 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 >= 49"));
|
||||
constraints.add(equationFromString(objective.length, "x0 + x1 + x2 + x3 >= 42"));
|
||||
constraints.add(equationFromString(objective.length, "x14 + x15 + x16 + x17 - x26 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x18 + x19 + x20 + x21 + x22 + x23 + x24 + x25 - x27 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x14 + x15 + x16 + x17 - x12 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x18 + x19 + x20 + x21 + x22 + x23 + x24 + x25 - x13 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x28 + x29 + x30 + x31 - x40 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x32 + x33 + x34 + x35 + x36 + x37 + x38 + x39 - x41 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x32 + x33 + x34 + x35 + x36 + x37 + x38 + x39 >= 49"));
|
||||
constraints.add(equationFromString(objective.length, "x28 + x29 + x30 + x31 >= 42"));
|
||||
constraints.add(equationFromString(objective.length, "x42 + x43 + x44 + x45 - x54 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x46 + x47 + x48 + x49 + x50 + x51 + x52 + x53 - x55 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x42 + x43 + x44 + x45 - x40 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x46 + x47 + x48 + x49 + x50 + x51 + x52 + x53 - x41 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x56 + x57 + x58 + x59 - x68 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x60 + x61 + x62 + x63 + x64 + x65 + x66 + x67 - x69 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x60 + x61 + x62 + x63 + x64 + x65 + x66 + x67 >= 51"));
|
||||
constraints.add(equationFromString(objective.length, "x56 + x57 + x58 + x59 >= 44"));
|
||||
constraints.add(equationFromString(objective.length, "x70 + x71 + x72 + x73 - x82 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x74 + x75 + x76 + x77 + x78 + x79 + x80 + x81 - x83 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x70 + x71 + x72 + x73 - x68 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x74 + x75 + x76 + x77 + x78 + x79 + x80 + x81 - x69 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x84 + x85 + x86 + x87 - x96 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x88 + x89 + x90 + x91 + x92 + x93 + x94 + x95 - x97 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x88 + x89 + x90 + x91 + x92 + x93 + x94 + x95 >= 51"));
|
||||
constraints.add(equationFromString(objective.length, "x84 + x85 + x86 + x87 >= 44"));
|
||||
constraints.add(equationFromString(objective.length, "x98 + x99 + x100 + x101 - x110 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x102 + x103 + x104 + x105 + x106 + x107 + x108 + x109 - x111 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x98 + x99 + x100 + x101 - x96 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x102 + x103 + x104 + x105 + x106 + x107 + x108 + x109 - x97 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x112 + x113 + x114 + x115 - x124 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x116 + x117 + x118 + x119 + x120 + x121 + x122 + x123 - x125 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x116 + x117 + x118 + x119 + x120 + x121 + x122 + x123 >= 49"));
|
||||
constraints.add(equationFromString(objective.length, "x112 + x113 + x114 + x115 >= 42"));
|
||||
constraints.add(equationFromString(objective.length, "x126 + x127 + x128 + x129 - x138 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x130 + x131 + x132 + x133 + x134 + x135 + x136 + x137 - x139 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x126 + x127 + x128 + x129 - x124 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x130 + x131 + x132 + x133 + x134 + x135 + x136 + x137 - x125 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x140 + x141 + x142 + x143 - x152 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x144 + x145 + x146 + x147 + x148 + x149 + x150 + x151 - x153 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x144 + x145 + x146 + x147 + x148 + x149 + x150 + x151 >= 59"));
|
||||
constraints.add(equationFromString(objective.length, "x140 + x141 + x142 + x143 >= 42"));
|
||||
constraints.add(equationFromString(objective.length, "x154 + x155 + x156 + x157 - x166 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x158 + x159 + x160 + x161 + x162 + x163 + x164 + x165 - x167 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x154 + x155 + x156 + x157 - x152 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x158 + x159 + x160 + x161 + x162 + x163 + x164 + x165 - x153 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x83 + x82 - x168 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x111 + x110 - x169 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x170 - x182 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x171 - x183 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x172 - x184 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x173 - x185 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x174 - x186 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x175 + x176 - x187 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x177 - x188 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x178 - x189 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x179 - x190 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x180 - x191 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x181 - x192 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x170 - x26 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x171 - x27 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x172 - x54 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x173 - x55 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x174 - x168 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x177 - x169 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x178 - x138 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x179 - x139 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x180 - x166 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x181 - x167 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x193 - x205 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x194 - x206 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x195 - x207 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x196 - x208 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x197 - x209 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x198 + x199 - x210 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x200 - x211 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x201 - x212 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x202 - x213 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x203 - x214 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x204 - x215 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x193 - x182 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x194 - x183 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x195 - x184 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x196 - x185 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x197 - x186 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x198 + x199 - x187 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x200 - x188 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x201 - x189 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x202 - x190 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x203 - x191 = 0"));
|
||||
constraints.add(equationFromString(objective.length, "x204 - x192 = 0"));
|
||||
|
||||
SimplexSolver solver = new SimplexSolver();
|
||||
PointValuePair solution = solver.optimize(DEFAULT_MAX_ITER, f, new LinearConstraintSet(constraints),
|
||||
GoalType.MINIMIZE, new NonNegativeConstraint(true));
|
||||
Assert.assertEquals(7518.0, solution.getValue(), .0000001);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a test string to a {@link LinearConstraint}.
|
||||
* Ex: x0 + x1 + x2 + x3 - x12 = 0
|
||||
*/
|
||||
private LinearConstraint equationFromString(int numCoefficients, String s) {
|
||||
Relationship relationship;
|
||||
if (s.contains(">=")) {
|
||||
relationship = Relationship.GEQ;
|
||||
} else if (s.contains("<=")) {
|
||||
relationship = Relationship.LEQ;
|
||||
} else if (s.contains("=")) {
|
||||
relationship = Relationship.EQ;
|
||||
} else {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
String[] equationParts = s.split("[>|<]?=");
|
||||
double rhs = Double.parseDouble(equationParts[1].trim());
|
||||
|
||||
double[] lhs = new double[numCoefficients];
|
||||
String left = equationParts[0].replaceAll(" ?x", "");
|
||||
String[] coefficients = left.split(" ");
|
||||
for (String coefficient : coefficients) {
|
||||
double value = coefficient.charAt(0) == '-' ? -1 : 1;
|
||||
int index = Integer.parseInt(coefficient.replaceFirst("[+|-]", "").trim());
|
||||
lhs[index] = value;
|
||||
}
|
||||
return new LinearConstraint(lhs, relationship, rhs);
|
||||
}
|
||||
|
||||
private static boolean validSolution(PointValuePair solution, List<LinearConstraint> constraints, double epsilon) {
|
||||
double[] vals = solution.getPoint();
|
||||
for (LinearConstraint c : constraints) {
|
||||
double[] coeffs = c.getCoefficients().toArray();
|
||||
double result = 0.0d;
|
||||
for (int i = 0; i < vals.length; i++) {
|
||||
result += vals[i] * coeffs[i];
|
||||
}
|
||||
|
||||
switch (c.getRelationship()) {
|
||||
case EQ:
|
||||
if (!Precision.equals(result, c.getValue(), epsilon)) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
|
||||
case GEQ:
|
||||
if (Precision.compareTo(result, c.getValue(), epsilon) < 0) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
|
||||
case LEQ:
|
||||
if (Precision.compareTo(result, c.getValue(), epsilon) > 0) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.linear;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import org.apache.commons.math3.TestUtils;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class SimplexTableauTest {
|
||||
|
||||
@Test
|
||||
public void testInitialization() {
|
||||
LinearObjectiveFunction f = createFunction();
|
||||
Collection<LinearConstraint> constraints = createConstraints();
|
||||
SimplexTableau tableau =
|
||||
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
|
||||
double[][] expectedInitialTableau = {
|
||||
{-1, 0, -1, -1, 2, 0, 0, 0, -4},
|
||||
{ 0, 1, -15, -10, 25, 0, 0, 0, 0},
|
||||
{ 0, 0, 1, 0, -1, 1, 0, 0, 2},
|
||||
{ 0, 0, 0, 1, -1, 0, 1, 0, 3},
|
||||
{ 0, 0, 1, 1, -2, 0, 0, 1, 4}
|
||||
};
|
||||
assertMatrixEquals(expectedInitialTableau, tableau.getData());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDropPhase1Objective() {
|
||||
LinearObjectiveFunction f = createFunction();
|
||||
Collection<LinearConstraint> constraints = createConstraints();
|
||||
SimplexTableau tableau =
|
||||
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
|
||||
double[][] expectedTableau = {
|
||||
{ 1, -15, -10, 0, 0, 0, 0},
|
||||
{ 0, 1, 0, 1, 0, 0, 2},
|
||||
{ 0, 0, 1, 0, 1, 0, 3},
|
||||
{ 0, 1, 1, 0, 0, 1, 4}
|
||||
};
|
||||
tableau.dropPhase1Objective();
|
||||
assertMatrixEquals(expectedTableau, tableau.getData());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTableauWithNoArtificialVars() {
|
||||
LinearObjectiveFunction f = new LinearObjectiveFunction(new double[] {15, 10}, 0);
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] {1, 0}, Relationship.LEQ, 2));
|
||||
constraints.add(new LinearConstraint(new double[] {0, 1}, Relationship.LEQ, 3));
|
||||
constraints.add(new LinearConstraint(new double[] {1, 1}, Relationship.LEQ, 4));
|
||||
SimplexTableau tableau =
|
||||
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
|
||||
double[][] initialTableau = {
|
||||
{1, -15, -10, 25, 0, 0, 0, 0},
|
||||
{0, 1, 0, -1, 1, 0, 0, 2},
|
||||
{0, 0, 1, -1, 0, 1, 0, 3},
|
||||
{0, 1, 1, -2, 0, 0, 1, 4}
|
||||
};
|
||||
assertMatrixEquals(initialTableau, tableau.getData());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerial() {
|
||||
LinearObjectiveFunction f = createFunction();
|
||||
Collection<LinearConstraint> constraints = createConstraints();
|
||||
SimplexTableau tableau =
|
||||
new SimplexTableau(f, constraints, GoalType.MAXIMIZE, false, 1.0e-6);
|
||||
Assert.assertEquals(tableau, TestUtils.serializeAndRecover(tableau));
|
||||
}
|
||||
|
||||
private LinearObjectiveFunction createFunction() {
|
||||
return new LinearObjectiveFunction(new double[] {15, 10}, 0);
|
||||
}
|
||||
|
||||
private Collection<LinearConstraint> createConstraints() {
|
||||
Collection<LinearConstraint> constraints = new ArrayList<LinearConstraint>();
|
||||
constraints.add(new LinearConstraint(new double[] {1, 0}, Relationship.LEQ, 2));
|
||||
constraints.add(new LinearConstraint(new double[] {0, 1}, Relationship.LEQ, 3));
|
||||
constraints.add(new LinearConstraint(new double[] {1, 1}, Relationship.EQ, 4));
|
||||
return constraints;
|
||||
}
|
||||
|
||||
private void assertMatrixEquals(double[][] expected, double[][] result) {
|
||||
Assert.assertEquals("Wrong number of rows.", expected.length, result.length);
|
||||
for (int i = 0; i < expected.length; i++) {
|
||||
Assert.assertEquals("Wrong number of columns.", expected[i].length, result[i].length);
|
||||
for (int j = 0; j < expected[i].length; j++) {
|
||||
Assert.assertEquals("Wrong value at position [" + i + "," + j + "]", expected[i][j], result[i][j], 1.0e-15);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.SimpleValueChecker;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.CircleScalar;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
|
||||
import org.apache.commons.math3.random.GaussianRandomGenerator;
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomVectorGenerator;
|
||||
import org.apache.commons.math3.random.UncorrelatedRandomVectorGenerator;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class MultiStartMultivariateOptimizerTest {
|
||||
@Test
|
||||
public void testCircleFitting() {
|
||||
CircleScalar circle = new CircleScalar();
|
||||
circle.addPoint( 30.0, 68.0);
|
||||
circle.addPoint( 50.0, -6.0);
|
||||
circle.addPoint(110.0, -20.0);
|
||||
circle.addPoint( 35.0, 15.0);
|
||||
circle.addPoint( 45.0, 97.0);
|
||||
// TODO: the wrapper around NonLinearConjugateGradientOptimizer is a temporary hack for
|
||||
// version 3.1 of the library. It should be removed when NonLinearConjugateGradientOptimizer
|
||||
// will officially be declared as implementing MultivariateDifferentiableOptimizer
|
||||
GradientMultivariateOptimizer underlying
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-10, 1e-10));
|
||||
JDKRandomGenerator g = new JDKRandomGenerator();
|
||||
g.setSeed(753289573253l);
|
||||
RandomVectorGenerator generator
|
||||
= new UncorrelatedRandomVectorGenerator(new double[] { 50, 50 },
|
||||
new double[] { 10, 10 },
|
||||
new GaussianRandomGenerator(g));
|
||||
MultiStartMultivariateOptimizer optimizer
|
||||
= new MultiStartMultivariateOptimizer(underlying, 10, generator);
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
circle.getObjectiveFunction(),
|
||||
circle.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 98.680, 47.345 }));
|
||||
Assert.assertEquals(200, optimizer.getMaxEvaluations());
|
||||
PointValuePair[] optima = optimizer.getOptima();
|
||||
for (PointValuePair o : optima) {
|
||||
Vector2D center = new Vector2D(o.getPointRef()[0], o.getPointRef()[1]);
|
||||
Assert.assertEquals(69.960161753, circle.getRadius(center), 1e-8);
|
||||
Assert.assertEquals(96.075902096, center.getX(), 1e-8);
|
||||
Assert.assertEquals(48.135167894, center.getY(), 1e-8);
|
||||
}
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 70);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 90);
|
||||
Assert.assertEquals(3.1267527, optimum.getValue(), 1e-8);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRosenbrock() {
|
||||
Rosenbrock rosenbrock = new Rosenbrock();
|
||||
SimplexOptimizer underlying
|
||||
= new SimplexOptimizer(new SimpleValueChecker(-1, 1e-3));
|
||||
NelderMeadSimplex simplex = new NelderMeadSimplex(new double[][] {
|
||||
{ -1.2, 1.0 },
|
||||
{ 0.9, 1.2 } ,
|
||||
{ 3.5, -2.3 }
|
||||
});
|
||||
JDKRandomGenerator g = new JDKRandomGenerator();
|
||||
g.setSeed(16069223052l);
|
||||
RandomVectorGenerator generator
|
||||
= new UncorrelatedRandomVectorGenerator(2, new GaussianRandomGenerator(g));
|
||||
MultiStartMultivariateOptimizer optimizer
|
||||
= new MultiStartMultivariateOptimizer(underlying, 10, generator);
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(1100),
|
||||
new ObjectiveFunction(rosenbrock),
|
||||
GoalType.MINIMIZE,
|
||||
simplex,
|
||||
new InitialGuess(new double[] { -1.2, 1.0 }));
|
||||
|
||||
Assert.assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 900);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 1200);
|
||||
Assert.assertTrue(optimum.getValue() < 8e-4);
|
||||
}
|
||||
|
||||
private static class Rosenbrock implements MultivariateFunction {
|
||||
private int count;
|
||||
|
||||
public Rosenbrock() {
|
||||
count = 0;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
++count;
|
||||
double a = x[1] - x[0] * x[0];
|
||||
double b = 1 - x[0];
|
||||
return 100 * a * a + b * b;
|
||||
}
|
||||
|
||||
public int getCount() {
|
||||
return count;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,195 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.SimplePointChecker;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class MultivariateFunctionMappingAdapterTest {
|
||||
@Test
|
||||
public void testStartSimplexInsideRange() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(2.0, 2.5, 1.0, 3.0, 2.0, 3.0);
|
||||
final MultivariateFunctionMappingAdapter wrapped
|
||||
= new MultivariateFunctionMappingAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper());
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
|
||||
});
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(300),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
|
||||
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 2e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptimumOutsideRange() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0, 1.0, 3.0, 2.0, 3.0);
|
||||
final MultivariateFunctionMappingAdapter wrapped
|
||||
= new MultivariateFunctionMappingAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper());
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
|
||||
});
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
|
||||
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 2e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUnbounded() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0,
|
||||
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
|
||||
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
|
||||
final MultivariateFunctionMappingAdapter wrapped
|
||||
= new MultivariateFunctionMappingAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper());
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
|
||||
});
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(300),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
|
||||
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 2e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHalfBounded() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 4.0,
|
||||
1.0, Double.POSITIVE_INFINITY,
|
||||
Double.NEGATIVE_INFINITY, 3.0);
|
||||
final MultivariateFunctionMappingAdapter wrapped
|
||||
= new MultivariateFunctionMappingAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper());
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-13, 1e-30);
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[][] {
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.75 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.5, 2.95 }),
|
||||
wrapped.boundedToUnbounded(new double[] { 1.7, 2.90 })
|
||||
});
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(wrapped.boundedToUnbounded(new double[] { 1.5, 2.25 })));
|
||||
final double[] bounded = wrapped.unboundedToBounded(optimum.getPoint());
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), bounded[0], 1e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), bounded[1], 1e-7);
|
||||
}
|
||||
|
||||
private static class BiQuadratic implements MultivariateFunction {
|
||||
|
||||
private final double xOptimum;
|
||||
private final double yOptimum;
|
||||
|
||||
private final double xMin;
|
||||
private final double xMax;
|
||||
private final double yMin;
|
||||
private final double yMax;
|
||||
|
||||
public BiQuadratic(final double xOptimum, final double yOptimum,
|
||||
final double xMin, final double xMax,
|
||||
final double yMin, final double yMax) {
|
||||
this.xOptimum = xOptimum;
|
||||
this.yOptimum = yOptimum;
|
||||
this.xMin = xMin;
|
||||
this.xMax = xMax;
|
||||
this.yMin = yMin;
|
||||
this.yMax = yMax;
|
||||
}
|
||||
|
||||
public double value(double[] point) {
|
||||
// the function should never be called with out of range points
|
||||
Assert.assertTrue(point[0] >= xMin);
|
||||
Assert.assertTrue(point[0] <= xMax);
|
||||
Assert.assertTrue(point[1] >= yMin);
|
||||
Assert.assertTrue(point[1] <= yMax);
|
||||
|
||||
final double dx = point[0] - xOptimum;
|
||||
final double dy = point[1] - yOptimum;
|
||||
return dx * dx + dy * dy;
|
||||
|
||||
}
|
||||
|
||||
public double[] getLower() {
|
||||
return new double[] { xMin, yMin };
|
||||
}
|
||||
|
||||
public double[] getUpper() {
|
||||
return new double[] { xMax, yMax };
|
||||
}
|
||||
|
||||
public double getBoundedXOptimum() {
|
||||
return (xOptimum < xMin) ? xMin : ((xOptimum > xMax) ? xMax : xOptimum);
|
||||
}
|
||||
|
||||
public double getBoundedYOptimum() {
|
||||
return (yOptimum < yMin) ? yMin : ((yOptimum > yMax) ? yMax : yOptimum);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,204 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.SimplePointChecker;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class MultivariateFunctionPenaltyAdapterTest {
|
||||
@Test
|
||||
public void testStartSimplexInsideRange() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(2.0, 2.5, 1.0, 3.0, 2.0, 3.0);
|
||||
final MultivariateFunctionPenaltyAdapter wrapped
|
||||
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper(),
|
||||
1000.0, new double[] { 100.0, 100.0 });
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(300),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 1.5, 2.25 }));
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStartSimplexOutsideRange() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(2.0, 2.5, 1.0, 3.0, 2.0, 3.0);
|
||||
final MultivariateFunctionPenaltyAdapter wrapped
|
||||
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper(),
|
||||
1000.0, new double[] { 100.0, 100.0 });
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(300),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -1.5, 4.0 }));
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptimumOutsideRange() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0, 1.0, 3.0, 2.0, 3.0);
|
||||
final MultivariateFunctionPenaltyAdapter wrapped
|
||||
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper(),
|
||||
1000.0, new double[] { 100.0, 100.0 });
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(new SimplePointChecker<PointValuePair>(1.0e-11, 1.0e-20));
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(600),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -1.5, 4.0 }));
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUnbounded() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 0.0,
|
||||
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
|
||||
Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
|
||||
final MultivariateFunctionPenaltyAdapter wrapped
|
||||
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper(),
|
||||
1000.0, new double[] { 100.0, 100.0 });
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(300),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -1.5, 4.0 }));
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHalfBounded() {
|
||||
final BiQuadratic biQuadratic = new BiQuadratic(4.0, 4.0,
|
||||
1.0, Double.POSITIVE_INFINITY,
|
||||
Double.NEGATIVE_INFINITY, 3.0);
|
||||
final MultivariateFunctionPenaltyAdapter wrapped
|
||||
= new MultivariateFunctionPenaltyAdapter(biQuadratic,
|
||||
biQuadratic.getLower(),
|
||||
biQuadratic.getUpper(),
|
||||
1000.0, new double[] { 100.0, 100.0 });
|
||||
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(new SimplePointChecker<PointValuePair>(1.0e-10, 1.0e-20));
|
||||
final AbstractSimplex simplex = new NelderMeadSimplex(new double[] { 1.0, 0.5 });
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(400),
|
||||
new ObjectiveFunction(wrapped),
|
||||
simplex,
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -1.5, 4.0 }));
|
||||
|
||||
Assert.assertEquals(biQuadratic.getBoundedXOptimum(), optimum.getPoint()[0], 2e-7);
|
||||
Assert.assertEquals(biQuadratic.getBoundedYOptimum(), optimum.getPoint()[1], 2e-7);
|
||||
}
|
||||
|
||||
private static class BiQuadratic implements MultivariateFunction {
|
||||
|
||||
private final double xOptimum;
|
||||
private final double yOptimum;
|
||||
|
||||
private final double xMin;
|
||||
private final double xMax;
|
||||
private final double yMin;
|
||||
private final double yMax;
|
||||
|
||||
public BiQuadratic(final double xOptimum, final double yOptimum,
|
||||
final double xMin, final double xMax,
|
||||
final double yMin, final double yMax) {
|
||||
this.xOptimum = xOptimum;
|
||||
this.yOptimum = yOptimum;
|
||||
this.xMin = xMin;
|
||||
this.xMax = xMax;
|
||||
this.yMin = yMin;
|
||||
this.yMax = yMax;
|
||||
}
|
||||
|
||||
public double value(double[] point) {
|
||||
// the function should never be called with out of range points
|
||||
Assert.assertTrue(point[0] >= xMin);
|
||||
Assert.assertTrue(point[0] <= xMax);
|
||||
Assert.assertTrue(point[1] >= yMin);
|
||||
Assert.assertTrue(point[1] <= yMax);
|
||||
|
||||
final double dx = point[0] - xOptimum;
|
||||
final double dy = point[1] - yOptimum;
|
||||
return dx * dx + dy * dy;
|
||||
|
||||
}
|
||||
|
||||
public double[] getLower() {
|
||||
return new double[] { xMin, yMin };
|
||||
}
|
||||
|
||||
public double[] getUpper() {
|
||||
return new double[] { xMax, yMax };
|
||||
}
|
||||
|
||||
public double getBoundedXOptimum() {
|
||||
return (xOptimum < xMin) ? xMin : ((xOptimum > xMax) ? xMax : xOptimum);
|
||||
}
|
||||
|
||||
public double getBoundedYOptimum() {
|
||||
return (yOptimum < yMin) ? yMin : ((yOptimum > yMax) ? yMax : yOptimum);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
|
||||
|
||||
/**
|
||||
* Class used in the tests.
|
||||
*/
|
||||
public class CircleScalar {
|
||||
private ArrayList<Vector2D> points;
|
||||
|
||||
public CircleScalar() {
|
||||
points = new ArrayList<Vector2D>();
|
||||
}
|
||||
|
||||
public void addPoint(double px, double py) {
|
||||
points.add(new Vector2D(px, py));
|
||||
}
|
||||
|
||||
public double getRadius(Vector2D center) {
|
||||
double r = 0;
|
||||
for (Vector2D point : points) {
|
||||
r += point.distance(center);
|
||||
}
|
||||
return r / points.size();
|
||||
}
|
||||
|
||||
public ObjectiveFunction getObjectiveFunction() {
|
||||
return new ObjectiveFunction(new MultivariateFunction() {
|
||||
public double value(double[] params) {
|
||||
Vector2D center = new Vector2D(params[0], params[1]);
|
||||
double radius = getRadius(center);
|
||||
double sum = 0;
|
||||
for (Vector2D point : points) {
|
||||
double di = point.distance(center) - radius;
|
||||
sum += di * di;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public ObjectiveFunctionGradient getObjectiveFunctionGradient() {
|
||||
return new ObjectiveFunctionGradient(new MultivariateVectorFunction() {
|
||||
public double[] value(double[] params) {
|
||||
Vector2D center = new Vector2D(params[0], params[1]);
|
||||
double radius = getRadius(center);
|
||||
// gradient of the sum of squared residuals
|
||||
double dJdX = 0;
|
||||
double dJdY = 0;
|
||||
for (Vector2D pk : points) {
|
||||
double dk = pk.distance(center);
|
||||
dJdX += (center.getX() - pk.getX()) * (dk - radius) / dk;
|
||||
dJdY += (center.getY() - pk.getY()) * (dk - radius) / dk;
|
||||
}
|
||||
dJdX *= 2;
|
||||
dJdY *= 2;
|
||||
|
||||
return new double[] { dJdX, dJdY };
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,447 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import org.apache.commons.math3.analysis.DifferentiableMultivariateFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
|
||||
import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableFunction;
|
||||
import org.apache.commons.math3.analysis.solvers.BrentSolver;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
|
||||
import org.apache.commons.math3.linear.BlockRealMatrix;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.SimpleValueChecker;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* <p>Some of the unit tests are re-implementations of the MINPACK <a
|
||||
* href="http://www.netlib.org/minpack/ex/file17">file17</a> and <a
|
||||
* href="http://www.netlib.org/minpack/ex/file22">file22</a> test files.
|
||||
* The redistribution policy for MINPACK is available <a
|
||||
* href="http://www.netlib.org/minpack/disclaimer">here</a>, for
|
||||
* convenience, it is reproduced below.</p>
|
||||
*
|
||||
* <table border="0" width="80%" cellpadding="10" align="center" bgcolor="#E0E0E0">
|
||||
* <tr><td>
|
||||
* Minpack Copyright Notice (1999) University of Chicago.
|
||||
* All rights reserved
|
||||
* </td></tr>
|
||||
* <tr><td>
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions
|
||||
* are met:
|
||||
* <ol>
|
||||
* <li>Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.</li>
|
||||
* <li>Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following
|
||||
* disclaimer in the documentation and/or other materials provided
|
||||
* with the distribution.</li>
|
||||
* <li>The end-user documentation included with the redistribution, if any,
|
||||
* must include the following acknowledgment:
|
||||
* <code>This product includes software developed by the University of
|
||||
* Chicago, as Operator of Argonne National Laboratory.</code>
|
||||
* Alternately, this acknowledgment may appear in the software itself,
|
||||
* if and wherever such third-party acknowledgments normally appear.</li>
|
||||
* <li><strong>WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS"
|
||||
* WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE
|
||||
* UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND
|
||||
* THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES
|
||||
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE
|
||||
* OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY
|
||||
* OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR
|
||||
* USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF
|
||||
* THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4)
|
||||
* DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION
|
||||
* UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL
|
||||
* BE CORRECTED.</strong></li>
|
||||
* <li><strong>LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT
|
||||
* HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF
|
||||
* ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT,
|
||||
* INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF
|
||||
* ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
* PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER
|
||||
* SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT
|
||||
* (INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE,
|
||||
* EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE
|
||||
* POSSIBILITY OF SUCH LOSS OR DAMAGES.</strong></li>
|
||||
* <ol></td></tr>
|
||||
* </table>
|
||||
*
|
||||
* @author Argonne National Laboratory. MINPACK project. March 1980 (original fortran minpack tests)
|
||||
* @author Burton S. Garbow (original fortran minpack tests)
|
||||
* @author Kenneth E. Hillstrom (original fortran minpack tests)
|
||||
* @author Jorge J. More (original fortran minpack tests)
|
||||
* @author Luc Maisonobe (non-minpack tests and minpack tests Java translation)
|
||||
*/
|
||||
public class NonLinearConjugateGradientOptimizerTest {
|
||||
@Test
|
||||
public void testTrivial() {
|
||||
LinearProblem problem
|
||||
= new LinearProblem(new double[][] { { 2 } }, new double[] { 3 });
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0 }));
|
||||
Assert.assertEquals(1.5, optimum.getPoint()[0], 1.0e-10);
|
||||
Assert.assertEquals(0.0, optimum.getValue(), 1.0e-10);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testColumnsPermutation() {
|
||||
LinearProblem problem
|
||||
= new LinearProblem(new double[][] { { 1.0, -1.0 }, { 0.0, 2.0 }, { 1.0, -2.0 } },
|
||||
new double[] { 4.0, 6.0, 1.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0, 0 }));
|
||||
Assert.assertEquals(7.0, optimum.getPoint()[0], 1.0e-10);
|
||||
Assert.assertEquals(3.0, optimum.getPoint()[1], 1.0e-10);
|
||||
Assert.assertEquals(0.0, optimum.getValue(), 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoDependency() {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 2, 0, 0, 0, 0, 0 },
|
||||
{ 0, 2, 0, 0, 0, 0 },
|
||||
{ 0, 0, 2, 0, 0, 0 },
|
||||
{ 0, 0, 0, 2, 0, 0 },
|
||||
{ 0, 0, 0, 0, 2, 0 },
|
||||
{ 0, 0, 0, 0, 0, 2 }
|
||||
}, new double[] { 0.0, 1.1, 2.2, 3.3, 4.4, 5.5 });
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0, 0, 0, 0, 0, 0 }));
|
||||
for (int i = 0; i < problem.target.length; ++i) {
|
||||
Assert.assertEquals(0.55 * i, optimum.getPoint()[i], 1.0e-10);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOneSet() {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1, 0, 0 },
|
||||
{ -1, 1, 0 },
|
||||
{ 0, -1, 1 }
|
||||
}, new double[] { 1, 1, 1});
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0, 0, 0 }));
|
||||
Assert.assertEquals(1.0, optimum.getPoint()[0], 1.0e-10);
|
||||
Assert.assertEquals(2.0, optimum.getPoint()[1], 1.0e-10);
|
||||
Assert.assertEquals(3.0, optimum.getPoint()[2], 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTwoSets() {
|
||||
final double epsilon = 1.0e-7;
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 2, 1, 0, 4, 0, 0 },
|
||||
{ -4, -2, 3, -7, 0, 0 },
|
||||
{ 4, 1, -2, 8, 0, 0 },
|
||||
{ 0, -3, -12, -1, 0, 0 },
|
||||
{ 0, 0, 0, 0, epsilon, 1 },
|
||||
{ 0, 0, 0, 0, 1, 1 }
|
||||
}, new double[] { 2, -9, 2, 2, 1 + epsilon * epsilon, 2});
|
||||
|
||||
final Preconditioner preconditioner
|
||||
= new Preconditioner() {
|
||||
public double[] precondition(double[] point, double[] r) {
|
||||
double[] d = r.clone();
|
||||
d[0] /= 72.0;
|
||||
d[1] /= 30.0;
|
||||
d[2] /= 314.0;
|
||||
d[3] /= 260.0;
|
||||
d[4] /= 2 * (1 + epsilon * epsilon);
|
||||
d[5] /= 4.0;
|
||||
return d;
|
||||
}
|
||||
};
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-13, 1e-13),
|
||||
new BrentSolver(),
|
||||
preconditioner);
|
||||
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0, 0, 0, 0, 0, 0 }));
|
||||
Assert.assertEquals( 3.0, optimum.getPoint()[0], 1.0e-10);
|
||||
Assert.assertEquals( 4.0, optimum.getPoint()[1], 1.0e-10);
|
||||
Assert.assertEquals(-1.0, optimum.getPoint()[2], 1.0e-10);
|
||||
Assert.assertEquals(-2.0, optimum.getPoint()[3], 1.0e-10);
|
||||
Assert.assertEquals( 1.0 + epsilon, optimum.getPoint()[4], 1.0e-10);
|
||||
Assert.assertEquals( 1.0 - epsilon, optimum.getPoint()[5], 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonInversible() {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1, 2, -3 },
|
||||
{ 2, 1, 3 },
|
||||
{ -3, 0, -9 }
|
||||
}, new double[] { 1, 1, 1 });
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0, 0, 0 }));
|
||||
Assert.assertTrue(optimum.getValue() > 0.5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIllConditioned() {
|
||||
LinearProblem problem1 = new LinearProblem(new double[][] {
|
||||
{ 10.0, 7.0, 8.0, 7.0 },
|
||||
{ 7.0, 5.0, 6.0, 5.0 },
|
||||
{ 8.0, 6.0, 10.0, 9.0 },
|
||||
{ 7.0, 5.0, 9.0, 10.0 }
|
||||
}, new double[] { 32, 23, 33, 31 });
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-13, 1e-13),
|
||||
new BrentSolver(1e-15, 1e-15));
|
||||
PointValuePair optimum1
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
problem1.getObjectiveFunction(),
|
||||
problem1.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0, 1, 2, 3 }));
|
||||
Assert.assertEquals(1.0, optimum1.getPoint()[0], 1.0e-4);
|
||||
Assert.assertEquals(1.0, optimum1.getPoint()[1], 1.0e-4);
|
||||
Assert.assertEquals(1.0, optimum1.getPoint()[2], 1.0e-4);
|
||||
Assert.assertEquals(1.0, optimum1.getPoint()[3], 1.0e-4);
|
||||
|
||||
LinearProblem problem2 = new LinearProblem(new double[][] {
|
||||
{ 10.00, 7.00, 8.10, 7.20 },
|
||||
{ 7.08, 5.04, 6.00, 5.00 },
|
||||
{ 8.00, 5.98, 9.89, 9.00 },
|
||||
{ 6.99, 4.99, 9.00, 9.98 }
|
||||
}, new double[] { 32, 23, 33, 31 });
|
||||
PointValuePair optimum2
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
problem2.getObjectiveFunction(),
|
||||
problem2.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 0, 1, 2, 3 }));
|
||||
Assert.assertEquals(-81.0, optimum2.getPoint()[0], 1.0e-1);
|
||||
Assert.assertEquals(137.0, optimum2.getPoint()[1], 1.0e-1);
|
||||
Assert.assertEquals(-34.0, optimum2.getPoint()[2], 1.0e-1);
|
||||
Assert.assertEquals( 22.0, optimum2.getPoint()[3], 1.0e-1);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMoreEstimatedParametersSimple() {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 3.0, 2.0, 0.0, 0.0 },
|
||||
{ 0.0, 1.0, -1.0, 1.0 },
|
||||
{ 2.0, 0.0, 1.0, 0.0 }
|
||||
}, new double[] { 7.0, 3.0, 5.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 7, 6, 5, 4 }));
|
||||
Assert.assertEquals(0, optimum.getValue(), 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMoreEstimatedParametersUnsorted() {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1.0, 1.0, 0.0, 0.0, 0.0, 0.0 },
|
||||
{ 0.0, 0.0, 1.0, 1.0, 1.0, 0.0 },
|
||||
{ 0.0, 0.0, 0.0, 0.0, 1.0, -1.0 },
|
||||
{ 0.0, 0.0, -1.0, 1.0, 0.0, 1.0 },
|
||||
{ 0.0, 0.0, 0.0, -1.0, 1.0, 0.0 }
|
||||
}, new double[] { 3.0, 12.0, -1.0, 7.0, 1.0 });
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 2, 2, 2, 2, 2, 2 }));
|
||||
Assert.assertEquals(0, optimum.getValue(), 1.0e-10);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRedundantEquations() {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1.0, 1.0 },
|
||||
{ 1.0, -1.0 },
|
||||
{ 1.0, 3.0 }
|
||||
}, new double[] { 3.0, 1.0, 5.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 1, 1 }));
|
||||
Assert.assertEquals(2.0, optimum.getPoint()[0], 1.0e-8);
|
||||
Assert.assertEquals(1.0, optimum.getPoint()[1], 1.0e-8);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInconsistentEquations() {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1.0, 1.0 },
|
||||
{ 1.0, -1.0 },
|
||||
{ 1.0, 3.0 }
|
||||
}, new double[] { 3.0, 1.0, 4.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-6, 1e-6));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 1, 1 }));
|
||||
Assert.assertTrue(optimum.getValue() > 0.1);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCircleFitting() {
|
||||
CircleScalar problem = new CircleScalar();
|
||||
problem.addPoint( 30.0, 68.0);
|
||||
problem.addPoint( 50.0, -6.0);
|
||||
problem.addPoint(110.0, -20.0);
|
||||
problem.addPoint( 35.0, 15.0);
|
||||
problem.addPoint( 45.0, 97.0);
|
||||
NonLinearConjugateGradientOptimizer optimizer
|
||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||
new SimpleValueChecker(1e-30, 1e-30),
|
||||
new BrentSolver(1e-15, 1e-13));
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
problem.getObjectiveFunction(),
|
||||
problem.getObjectiveFunctionGradient(),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 98.680, 47.345 }));
|
||||
Vector2D center = new Vector2D(optimum.getPointRef()[0], optimum.getPointRef()[1]);
|
||||
Assert.assertEquals(69.960161753, problem.getRadius(center), 1.0e-8);
|
||||
Assert.assertEquals(96.075902096, center.getX(), 1.0e-8);
|
||||
Assert.assertEquals(48.135167894, center.getY(), 1.0e-8);
|
||||
}
|
||||
|
||||
private static class LinearProblem {
|
||||
final RealMatrix factors;
|
||||
final double[] target;
|
||||
|
||||
public LinearProblem(double[][] factors,
|
||||
double[] target) {
|
||||
this.factors = new BlockRealMatrix(factors);
|
||||
this.target = target;
|
||||
}
|
||||
|
||||
public ObjectiveFunction getObjectiveFunction() {
|
||||
return new ObjectiveFunction(new MultivariateFunction() {
|
||||
public double value(double[] point) {
|
||||
double[] y = factors.operate(point);
|
||||
double sum = 0;
|
||||
for (int i = 0; i < y.length; ++i) {
|
||||
double ri = y[i] - target[i];
|
||||
sum += ri * ri;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public ObjectiveFunctionGradient getObjectiveFunctionGradient() {
|
||||
return new ObjectiveFunctionGradient(new MultivariateVectorFunction() {
|
||||
public double[] value(double[] point) {
|
||||
double[] r = factors.operate(point);
|
||||
for (int i = 0; i < r.length; ++i) {
|
||||
r[i] -= target[i];
|
||||
}
|
||||
double[] p = factors.transpose().operate(r);
|
||||
for (int i = 0; i < p.length; ++i) {
|
||||
p[i] *= 2;
|
||||
}
|
||||
return p;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,627 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.SimpleBounds;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Test for {@link BOBYQAOptimizer}.
|
||||
*/
|
||||
public class BOBYQAOptimizerTest {
|
||||
|
||||
static final int DIM = 13;
|
||||
|
||||
@Test(expected=NumberIsTooLargeException.class)
|
||||
public void testInitOutOfBounds() {
|
||||
double[] startPoint = point(DIM, 3);
|
||||
double[][] boundaries = boundaries(DIM, -1, 2);
|
||||
doTest(new Rosen(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 2000, null);
|
||||
}
|
||||
|
||||
@Test(expected=DimensionMismatchException.class)
|
||||
public void testBoundariesDimensionMismatch() {
|
||||
double[] startPoint = point(DIM, 0.5);
|
||||
double[][] boundaries = boundaries(DIM + 1, -1, 2);
|
||||
doTest(new Rosen(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 2000, null);
|
||||
}
|
||||
|
||||
@Test(expected=NumberIsTooSmallException.class)
|
||||
public void testProblemDimensionTooSmall() {
|
||||
double[] startPoint = point(1, 0.5);
|
||||
doTest(new Rosen(), startPoint, null,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 2000, null);
|
||||
}
|
||||
|
||||
@Test(expected=TooManyEvaluationsException.class)
|
||||
public void testMaxEvaluations() {
|
||||
final int lowMaxEval = 2;
|
||||
double[] startPoint = point(DIM, 0.1);
|
||||
double[][] boundaries = null;
|
||||
doTest(new Rosen(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, lowMaxEval, null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRosen() {
|
||||
double[] startPoint = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected = new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 2000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaximize() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected = new PointValuePair(point(DIM,0.0),1.0);
|
||||
doTest(new MinusElli(), startPoint, boundaries,
|
||||
GoalType.MAXIMIZE,
|
||||
2e-10, 5e-6, 1000, expected);
|
||||
boundaries = boundaries(DIM,-0.3,0.3);
|
||||
startPoint = point(DIM,0.1);
|
||||
doTest(new MinusElli(), startPoint, boundaries,
|
||||
GoalType.MAXIMIZE,
|
||||
2e-10, 5e-6, 1000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEllipse() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Elli(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 1000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testElliRotated() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new ElliRotated(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-12, 1e-6, 10000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCigar() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Cigar(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 100, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTwoAxes() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new TwoAxes(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE, 2*
|
||||
1e-13, 1e-6, 100, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCigTab() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new CigTab(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 5e-5, 100, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSphere() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Sphere(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 100, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTablet() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Tablet(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 100, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDiffPow() {
|
||||
double[] startPoint = point(DIM/2,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM/2,0.0),0.0);
|
||||
doTest(new DiffPow(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-8, 1e-1, 12000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSsDiffPow() {
|
||||
double[] startPoint = point(DIM/2,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM/2,0.0),0.0);
|
||||
doTest(new SsDiffPow(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-2, 1.3e-1, 50000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAckley() {
|
||||
double[] startPoint = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Ackley(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-8, 1e-5, 1000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRastrigin() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Rastrigin(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 1000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConstrainedRosen() {
|
||||
double[] startPoint = point(DIM,0.1);
|
||||
|
||||
double[][] boundaries = boundaries(DIM,-1,2);
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-13, 1e-6, 2000, expected);
|
||||
}
|
||||
|
||||
// See MATH-728
|
||||
@Test
|
||||
public void testConstrainedRosenWithMoreInterpolationPoints() {
|
||||
final double[] startPoint = point(DIM, 0.1);
|
||||
final double[][] boundaries = boundaries(DIM, -1, 2);
|
||||
final PointValuePair expected = new PointValuePair(point(DIM, 1.0), 0.0);
|
||||
|
||||
// This should have been 78 because in the code the hard limit is
|
||||
// said to be
|
||||
// ((DIM + 1) * (DIM + 2)) / 2 - (2 * DIM + 1)
|
||||
// i.e. 78 in this case, but the test fails for 48, 59, 62, 63, 64,
|
||||
// 65, 66, ...
|
||||
final int maxAdditionalPoints = 47;
|
||||
|
||||
for (int num = 1; num <= maxAdditionalPoints; num++) {
|
||||
doTest(new Rosen(), startPoint, boundaries,
|
||||
GoalType.MINIMIZE,
|
||||
1e-12, 1e-6, 2000,
|
||||
num,
|
||||
expected,
|
||||
"num=" + num);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param func Function to optimize.
|
||||
* @param startPoint Starting point.
|
||||
* @param boundaries Upper / lower point limit.
|
||||
* @param goal Minimization or maximization.
|
||||
* @param fTol Tolerance relative error on the objective function.
|
||||
* @param pointTol Tolerance for checking that the optimum is correct.
|
||||
* @param maxEvaluations Maximum number of evaluations.
|
||||
* @param expected Expected point / value.
|
||||
*/
|
||||
private void doTest(MultivariateFunction func,
|
||||
double[] startPoint,
|
||||
double[][] boundaries,
|
||||
GoalType goal,
|
||||
double fTol,
|
||||
double pointTol,
|
||||
int maxEvaluations,
|
||||
PointValuePair expected) {
|
||||
doTest(func,
|
||||
startPoint,
|
||||
boundaries,
|
||||
goal,
|
||||
fTol,
|
||||
pointTol,
|
||||
maxEvaluations,
|
||||
0,
|
||||
expected,
|
||||
"");
|
||||
}
|
||||
|
||||
/**
|
||||
* @param func Function to optimize.
|
||||
* @param startPoint Starting point.
|
||||
* @param boundaries Upper / lower point limit.
|
||||
* @param goal Minimization or maximization.
|
||||
* @param fTol Tolerance relative error on the objective function.
|
||||
* @param pointTol Tolerance for checking that the optimum is correct.
|
||||
* @param maxEvaluations Maximum number of evaluations.
|
||||
* @param additionalInterpolationPoints Number of interpolation to used
|
||||
* in addition to the default (2 * dim + 1).
|
||||
* @param expected Expected point / value.
|
||||
*/
|
||||
private void doTest(MultivariateFunction func,
|
||||
double[] startPoint,
|
||||
double[][] boundaries,
|
||||
GoalType goal,
|
||||
double fTol,
|
||||
double pointTol,
|
||||
int maxEvaluations,
|
||||
int additionalInterpolationPoints,
|
||||
PointValuePair expected,
|
||||
String assertMsg) {
|
||||
|
||||
// System.out.println(func.getClass().getName() + " BEGIN"); // XXX
|
||||
|
||||
int dim = startPoint.length;
|
||||
final int numIterpolationPoints = 2 * dim + 1 + additionalInterpolationPoints;
|
||||
BOBYQAOptimizer optim = new BOBYQAOptimizer(numIterpolationPoints);
|
||||
PointValuePair result = boundaries == null ?
|
||||
optim.optimize(new MaxEval(maxEvaluations),
|
||||
new ObjectiveFunction(func),
|
||||
goal,
|
||||
SimpleBounds.unbounded(dim),
|
||||
new InitialGuess(startPoint)) :
|
||||
optim.optimize(new MaxEval(maxEvaluations),
|
||||
new ObjectiveFunction(func),
|
||||
goal,
|
||||
new InitialGuess(startPoint),
|
||||
new SimpleBounds(boundaries[0],
|
||||
boundaries[1]));
|
||||
// System.out.println(func.getClass().getName() + " = "
|
||||
// + optim.getEvaluations() + " f(");
|
||||
// for (double x: result.getPoint()) System.out.print(x + " ");
|
||||
// System.out.println(") = " + result.getValue());
|
||||
Assert.assertEquals(assertMsg, expected.getValue(), result.getValue(), fTol);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
Assert.assertEquals(expected.getPoint()[i],
|
||||
result.getPoint()[i], pointTol);
|
||||
}
|
||||
|
||||
// System.out.println(func.getClass().getName() + " END"); // XXX
|
||||
}
|
||||
|
||||
private static double[] point(int n, double value) {
|
||||
double[] ds = new double[n];
|
||||
Arrays.fill(ds, value);
|
||||
return ds;
|
||||
}
|
||||
|
||||
private static double[][] boundaries(int dim,
|
||||
double lower, double upper) {
|
||||
double[][] boundaries = new double[2][dim];
|
||||
for (int i = 0; i < dim; i++)
|
||||
boundaries[0][i] = lower;
|
||||
for (int i = 0; i < dim; i++)
|
||||
boundaries[1][i] = upper;
|
||||
return boundaries;
|
||||
}
|
||||
|
||||
private static class Sphere implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Cigar implements MultivariateFunction {
|
||||
private double factor;
|
||||
|
||||
Cigar() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
Cigar(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = x[0] * x[0];
|
||||
for (int i = 1; i < x.length; ++i)
|
||||
f += factor * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Tablet implements MultivariateFunction {
|
||||
private double factor;
|
||||
|
||||
Tablet() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
Tablet(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = factor * x[0] * x[0];
|
||||
for (int i = 1; i < x.length; ++i)
|
||||
f += x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class CigTab implements MultivariateFunction {
|
||||
private double factor;
|
||||
|
||||
CigTab() {
|
||||
this(1e4);
|
||||
}
|
||||
|
||||
CigTab(double axisratio) {
|
||||
factor = axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
int end = x.length - 1;
|
||||
double f = x[0] * x[0] / factor + factor * x[end] * x[end];
|
||||
for (int i = 1; i < end; ++i)
|
||||
f += x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class TwoAxes implements MultivariateFunction {
|
||||
|
||||
private double factor;
|
||||
|
||||
TwoAxes() {
|
||||
this(1e6);
|
||||
}
|
||||
|
||||
TwoAxes(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += (i < x.length / 2 ? factor : 1) * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class ElliRotated implements MultivariateFunction {
|
||||
private Basis B = new Basis();
|
||||
private double factor;
|
||||
|
||||
ElliRotated() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
ElliRotated(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
x = B.Rotate(x);
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Elli implements MultivariateFunction {
|
||||
|
||||
private double factor;
|
||||
|
||||
Elli() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
Elli(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class MinusElli implements MultivariateFunction {
|
||||
private final Elli elli = new Elli();
|
||||
public double value(double[] x) {
|
||||
return 1.0 - elli.value(x);
|
||||
}
|
||||
}
|
||||
|
||||
private static class DiffPow implements MultivariateFunction {
|
||||
// private int fcount = 0;
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += Math.pow(Math.abs(x[i]), 2. + 10 * (double) i
|
||||
/ (x.length - 1.));
|
||||
// System.out.print("" + (fcount++) + ") ");
|
||||
// for (int i = 0; i < x.length; i++)
|
||||
// System.out.print(x[i] + " ");
|
||||
// System.out.println(" = " + f);
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class SsDiffPow implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = Math.pow(new DiffPow().value(x), 0.25);
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Rosen implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length - 1; ++i)
|
||||
f += 1e2 * (x[i] * x[i] - x[i + 1]) * (x[i] * x[i] - x[i + 1])
|
||||
+ (x[i] - 1.) * (x[i] - 1.);
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Ackley implements MultivariateFunction {
|
||||
private double axisratio;
|
||||
|
||||
Ackley(double axra) {
|
||||
axisratio = axra;
|
||||
}
|
||||
|
||||
public Ackley() {
|
||||
this(1);
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
double res2 = 0;
|
||||
double fac = 0;
|
||||
for (int i = 0; i < x.length; ++i) {
|
||||
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
|
||||
f += fac * fac * x[i] * x[i];
|
||||
res2 += Math.cos(2. * Math.PI * fac * x[i]);
|
||||
}
|
||||
f = (20. - 20. * Math.exp(-0.2 * Math.sqrt(f / x.length))
|
||||
+ Math.exp(1.) - Math.exp(res2 / x.length));
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Rastrigin implements MultivariateFunction {
|
||||
|
||||
private double axisratio;
|
||||
private double amplitude;
|
||||
|
||||
Rastrigin() {
|
||||
this(1, 10);
|
||||
}
|
||||
|
||||
Rastrigin(double axisratio, double amplitude) {
|
||||
this.axisratio = axisratio;
|
||||
this.amplitude = amplitude;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
double fac;
|
||||
for (int i = 0; i < x.length; ++i) {
|
||||
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
|
||||
if (i == 0 && x[i] < 0)
|
||||
fac *= 1.;
|
||||
f += fac * fac * x[i] * x[i] + amplitude
|
||||
* (1. - Math.cos(2. * Math.PI * fac * x[i]));
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Basis {
|
||||
double[][] basis;
|
||||
Random rand = new Random(2); // use not always the same basis
|
||||
|
||||
double[] Rotate(double[] x) {
|
||||
GenBasis(x.length);
|
||||
double[] y = new double[x.length];
|
||||
for (int i = 0; i < x.length; ++i) {
|
||||
y[i] = 0;
|
||||
for (int j = 0; j < x.length; ++j)
|
||||
y[i] += basis[i][j] * x[j];
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
void GenBasis(int DIM) {
|
||||
if (basis != null ? basis.length == DIM : false)
|
||||
return;
|
||||
|
||||
double sp;
|
||||
int i, j, k;
|
||||
|
||||
/* generate orthogonal basis */
|
||||
basis = new double[DIM][DIM];
|
||||
for (i = 0; i < DIM; ++i) {
|
||||
/* sample components gaussian */
|
||||
for (j = 0; j < DIM; ++j)
|
||||
basis[i][j] = rand.nextGaussian();
|
||||
/* substract projection of previous vectors */
|
||||
for (j = i - 1; j >= 0; --j) {
|
||||
for (sp = 0., k = 0; k < DIM; ++k)
|
||||
sp += basis[i][k] * basis[j][k]; /* scalar product */
|
||||
for (k = 0; k < DIM; ++k)
|
||||
basis[i][k] -= sp * basis[j][k]; /* substract */
|
||||
}
|
||||
/* normalize */
|
||||
for (sp = 0., k = 0; k < DIM; ++k)
|
||||
sp += basis[i][k] * basis[i][k]; /* squared norm */
|
||||
for (k = 0; k < DIM; ++k)
|
||||
basis[i][k] /= Math.sqrt(sp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,794 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import org.apache.commons.math3.Retry;
|
||||
import org.apache.commons.math3.RetryRunner;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
import org.apache.commons.math3.exception.NotPositiveException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.SimpleBounds;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.random.MersenneTwister;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
/**
|
||||
* Test for {@link CMAESOptimizer}.
|
||||
*/
|
||||
@RunWith(RetryRunner.class)
|
||||
public class CMAESOptimizerTest {
|
||||
|
||||
static final int DIM = 13;
|
||||
static final int LAMBDA = 4 + (int)(3.*Math.log(DIM));
|
||||
|
||||
@Test(expected = NumberIsTooLargeException.class)
|
||||
public void testInitOutofbounds1() {
|
||||
double[] startPoint = point(DIM,3);
|
||||
double[] insigma = point(DIM, 0.3);
|
||||
double[][] boundaries = boundaries(DIM,-1,2);
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
@Test(expected = NumberIsTooSmallException.class)
|
||||
public void testInitOutofbounds2() {
|
||||
double[] startPoint = point(DIM, -2);
|
||||
double[] insigma = point(DIM, 0.3);
|
||||
double[][] boundaries = boundaries(DIM,-1,2);
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test(expected = DimensionMismatchException.class)
|
||||
public void testBoundariesDimensionMismatch() {
|
||||
double[] startPoint = point(DIM,0.5);
|
||||
double[] insigma = point(DIM, 0.3);
|
||||
double[][] boundaries = boundaries(DIM+1,-1,2);
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test(expected = NotPositiveException.class)
|
||||
public void testInputSigmaNegative() {
|
||||
double[] startPoint = point(DIM,0.5);
|
||||
double[] insigma = point(DIM,-0.5);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test(expected = OutOfRangeException.class)
|
||||
public void testInputSigmaOutOfRange() {
|
||||
double[] startPoint = point(DIM,0.5);
|
||||
double[] insigma = point(DIM, 1.1);
|
||||
double[][] boundaries = boundaries(DIM,-0.5,0.5);
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test(expected = DimensionMismatchException.class)
|
||||
public void testInputSigmaDimensionMismatch() {
|
||||
double[] startPoint = point(DIM,0.5);
|
||||
double[] insigma = point(DIM + 1, 0.5);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Retry(3)
|
||||
public void testRosen() {
|
||||
double[] startPoint = point(DIM,0.1);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Retry(3)
|
||||
public void testMaximize() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),1.0);
|
||||
doTest(new MinusElli(), startPoint, insigma, boundaries,
|
||||
GoalType.MAXIMIZE, LAMBDA, true, 0, 1.0-1e-13,
|
||||
2e-10, 5e-6, 100000, expected);
|
||||
doTest(new MinusElli(), startPoint, insigma, boundaries,
|
||||
GoalType.MAXIMIZE, LAMBDA, false, 0, 1.0-1e-13,
|
||||
2e-10, 5e-6, 100000, expected);
|
||||
boundaries = boundaries(DIM,-0.3,0.3);
|
||||
startPoint = point(DIM,0.1);
|
||||
doTest(new MinusElli(), startPoint, insigma, boundaries,
|
||||
GoalType.MAXIMIZE, LAMBDA, true, 0, 1.0-1e-13,
|
||||
2e-10, 5e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEllipse() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Elli(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
doTest(new Elli(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testElliRotated() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new ElliRotated(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
doTest(new ElliRotated(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCigar() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Cigar(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 200000, expected);
|
||||
doTest(new Cigar(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCigarWithBoundaries() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = boundaries(DIM, -1e100, Double.POSITIVE_INFINITY);
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Cigar(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 200000, expected);
|
||||
doTest(new Cigar(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTwoAxes() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new TwoAxes(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 2*LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 200000, expected);
|
||||
doTest(new TwoAxes(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 2*LAMBDA, false, 0, 1e-13,
|
||||
1e-8, 1e-3, 200000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCigTab() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.3);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new CigTab(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 5e-5, 100000, expected);
|
||||
doTest(new CigTab(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 5e-5, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSphere() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Sphere(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
doTest(new Sphere(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTablet() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Tablet(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
doTest(new Tablet(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDiffPow() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new DiffPow(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 10, true, 0, 1e-13,
|
||||
1e-8, 1e-1, 100000, expected);
|
||||
doTest(new DiffPow(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 10, false, 0, 1e-13,
|
||||
1e-8, 2e-1, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSsDiffPow() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new SsDiffPow(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 10, true, 0, 1e-13,
|
||||
1e-4, 1e-1, 200000, expected);
|
||||
doTest(new SsDiffPow(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 10, false, 0, 1e-13,
|
||||
1e-4, 1e-1, 200000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAckley() {
|
||||
double[] startPoint = point(DIM,1.0);
|
||||
double[] insigma = point(DIM,1.0);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Ackley(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 2*LAMBDA, true, 0, 1e-13,
|
||||
1e-9, 1e-5, 100000, expected);
|
||||
doTest(new Ackley(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 2*LAMBDA, false, 0, 1e-13,
|
||||
1e-9, 1e-5, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRastrigin() {
|
||||
double[] startPoint = point(DIM,0.1);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,0.0),0.0);
|
||||
doTest(new Rastrigin(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, (int)(200*Math.sqrt(DIM)), true, 0, 1e-13,
|
||||
1e-13, 1e-6, 200000, expected);
|
||||
doTest(new Rastrigin(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, (int)(200*Math.sqrt(DIM)), false, 0, 1e-13,
|
||||
1e-13, 1e-6, 200000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConstrainedRosen() {
|
||||
double[] startPoint = point(DIM, 0.1);
|
||||
double[] insigma = point(DIM, 0.1);
|
||||
double[][] boundaries = boundaries(DIM, -1, 2);
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 2*LAMBDA, true, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, 2*LAMBDA, false, 0, 1e-13,
|
||||
1e-13, 1e-6, 100000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDiagonalRosen() {
|
||||
double[] startPoint = point(DIM,0.1);
|
||||
double[] insigma = point(DIM,0.1);
|
||||
double[][] boundaries = null;
|
||||
PointValuePair expected =
|
||||
new PointValuePair(point(DIM,1.0),0.0);
|
||||
doTest(new Rosen(), startPoint, insigma, boundaries,
|
||||
GoalType.MINIMIZE, LAMBDA, false, 1, 1e-13,
|
||||
1e-10, 1e-4, 1000000, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath864() {
|
||||
final CMAESOptimizer optimizer
|
||||
= new CMAESOptimizer(30000, 0, true, 10,
|
||||
0, new MersenneTwister(), false, null);
|
||||
final MultivariateFunction fitnessFunction = new MultivariateFunction() {
|
||||
public double value(double[] parameters) {
|
||||
final double target = 1;
|
||||
final double error = target - parameters[0];
|
||||
return error * error;
|
||||
}
|
||||
};
|
||||
|
||||
final double[] start = { 0 };
|
||||
final double[] lower = { -1e6 };
|
||||
final double[] upper = { 1.5 };
|
||||
final double[] sigma = { 1e-1 };
|
||||
final double[] result = optimizer.optimize(new MaxEval(10000),
|
||||
new ObjectiveFunction(fitnessFunction),
|
||||
GoalType.MINIMIZE,
|
||||
new CMAESOptimizer.PopulationSize(5),
|
||||
new CMAESOptimizer.Sigma(sigma),
|
||||
new InitialGuess(start),
|
||||
new SimpleBounds(lower, upper)).getPoint();
|
||||
Assert.assertTrue("Out of bounds (" + result[0] + " > " + upper[0] + ")",
|
||||
result[0] <= upper[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Cf. MATH-867
|
||||
*/
|
||||
@Test
|
||||
public void testFitAccuracyDependsOnBoundary() {
|
||||
final CMAESOptimizer optimizer
|
||||
= new CMAESOptimizer(30000, 0, true, 10,
|
||||
0, new MersenneTwister(), false, null);
|
||||
final MultivariateFunction fitnessFunction = new MultivariateFunction() {
|
||||
public double value(double[] parameters) {
|
||||
final double target = 11.1;
|
||||
final double error = target - parameters[0];
|
||||
return error * error;
|
||||
}
|
||||
};
|
||||
|
||||
final double[] start = { 1 };
|
||||
|
||||
// No bounds.
|
||||
PointValuePair result = optimizer.optimize(new MaxEval(100000),
|
||||
new ObjectiveFunction(fitnessFunction),
|
||||
GoalType.MINIMIZE,
|
||||
SimpleBounds.unbounded(1),
|
||||
new CMAESOptimizer.PopulationSize(5),
|
||||
new CMAESOptimizer.Sigma(new double[] { 1e-1 }),
|
||||
new InitialGuess(start));
|
||||
final double resNoBound = result.getPoint()[0];
|
||||
|
||||
// Optimum is near the lower bound.
|
||||
final double[] lower = { -20 };
|
||||
final double[] upper = { 5e16 };
|
||||
final double[] sigma = { 10 };
|
||||
result = optimizer.optimize(new MaxEval(100000),
|
||||
new ObjectiveFunction(fitnessFunction),
|
||||
GoalType.MINIMIZE,
|
||||
new CMAESOptimizer.PopulationSize(5),
|
||||
new CMAESOptimizer.Sigma(sigma),
|
||||
new InitialGuess(start),
|
||||
new SimpleBounds(lower, upper));
|
||||
final double resNearLo = result.getPoint()[0];
|
||||
|
||||
// Optimum is near the upper bound.
|
||||
lower[0] = -5e16;
|
||||
upper[0] = 20;
|
||||
result = optimizer.optimize(new MaxEval(100000),
|
||||
new ObjectiveFunction(fitnessFunction),
|
||||
GoalType.MINIMIZE,
|
||||
new CMAESOptimizer.PopulationSize(5),
|
||||
new CMAESOptimizer.Sigma(sigma),
|
||||
new InitialGuess(start),
|
||||
new SimpleBounds(lower, upper));
|
||||
final double resNearHi = result.getPoint()[0];
|
||||
|
||||
// System.out.println("resNoBound=" + resNoBound +
|
||||
// " resNearLo=" + resNearLo +
|
||||
// " resNearHi=" + resNearHi);
|
||||
|
||||
// The two values currently differ by a substantial amount, indicating that
|
||||
// the bounds definition can prevent reaching the optimum.
|
||||
Assert.assertEquals(resNoBound, resNearLo, 1e-3);
|
||||
Assert.assertEquals(resNoBound, resNearHi, 1e-3);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param func Function to optimize.
|
||||
* @param startPoint Starting point.
|
||||
* @param inSigma Individual input sigma.
|
||||
* @param boundaries Upper / lower point limit.
|
||||
* @param goal Minimization or maximization.
|
||||
* @param lambda Population size used for offspring.
|
||||
* @param isActive Covariance update mechanism.
|
||||
* @param diagonalOnly Simplified covariance update.
|
||||
* @param stopValue Termination criteria for optimization.
|
||||
* @param fTol Tolerance relative error on the objective function.
|
||||
* @param pointTol Tolerance for checking that the optimum is correct.
|
||||
* @param maxEvaluations Maximum number of evaluations.
|
||||
* @param expected Expected point / value.
|
||||
*/
|
||||
private void doTest(MultivariateFunction func,
|
||||
double[] startPoint,
|
||||
double[] inSigma,
|
||||
double[][] boundaries,
|
||||
GoalType goal,
|
||||
int lambda,
|
||||
boolean isActive,
|
||||
int diagonalOnly,
|
||||
double stopValue,
|
||||
double fTol,
|
||||
double pointTol,
|
||||
int maxEvaluations,
|
||||
PointValuePair expected) {
|
||||
int dim = startPoint.length;
|
||||
// test diagonalOnly = 0 - slow but normally fewer feval#
|
||||
CMAESOptimizer optim = new CMAESOptimizer(30000, stopValue, isActive, diagonalOnly,
|
||||
0, new MersenneTwister(), false, null);
|
||||
PointValuePair result = boundaries == null ?
|
||||
optim.optimize(new MaxEval(maxEvaluations),
|
||||
new ObjectiveFunction(func),
|
||||
goal,
|
||||
new InitialGuess(startPoint),
|
||||
SimpleBounds.unbounded(dim),
|
||||
new CMAESOptimizer.Sigma(inSigma),
|
||||
new CMAESOptimizer.PopulationSize(lambda)) :
|
||||
optim.optimize(new MaxEval(maxEvaluations),
|
||||
new ObjectiveFunction(func),
|
||||
goal,
|
||||
new SimpleBounds(boundaries[0],
|
||||
boundaries[1]),
|
||||
new InitialGuess(startPoint),
|
||||
new CMAESOptimizer.Sigma(inSigma),
|
||||
new CMAESOptimizer.PopulationSize(lambda));
|
||||
|
||||
// System.out.println("sol=" + Arrays.toString(result.getPoint()));
|
||||
Assert.assertEquals(expected.getValue(), result.getValue(), fTol);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
Assert.assertEquals(expected.getPoint()[i], result.getPoint()[i], pointTol);
|
||||
}
|
||||
}
|
||||
|
||||
private static double[] point(int n, double value) {
|
||||
double[] ds = new double[n];
|
||||
Arrays.fill(ds, value);
|
||||
return ds;
|
||||
}
|
||||
|
||||
private static double[][] boundaries(int dim,
|
||||
double lower, double upper) {
|
||||
double[][] boundaries = new double[2][dim];
|
||||
for (int i = 0; i < dim; i++)
|
||||
boundaries[0][i] = lower;
|
||||
for (int i = 0; i < dim; i++)
|
||||
boundaries[1][i] = upper;
|
||||
return boundaries;
|
||||
}
|
||||
|
||||
private static class Sphere implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Cigar implements MultivariateFunction {
|
||||
private double factor;
|
||||
|
||||
Cigar() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
Cigar(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = x[0] * x[0];
|
||||
for (int i = 1; i < x.length; ++i)
|
||||
f += factor * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Tablet implements MultivariateFunction {
|
||||
private double factor;
|
||||
|
||||
Tablet() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
Tablet(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = factor * x[0] * x[0];
|
||||
for (int i = 1; i < x.length; ++i)
|
||||
f += x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class CigTab implements MultivariateFunction {
|
||||
private double factor;
|
||||
|
||||
CigTab() {
|
||||
this(1e4);
|
||||
}
|
||||
|
||||
CigTab(double axisratio) {
|
||||
factor = axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
int end = x.length - 1;
|
||||
double f = x[0] * x[0] / factor + factor * x[end] * x[end];
|
||||
for (int i = 1; i < end; ++i)
|
||||
f += x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class TwoAxes implements MultivariateFunction {
|
||||
|
||||
private double factor;
|
||||
|
||||
TwoAxes() {
|
||||
this(1e6);
|
||||
}
|
||||
|
||||
TwoAxes(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += (i < x.length / 2 ? factor : 1) * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class ElliRotated implements MultivariateFunction {
|
||||
private Basis B = new Basis();
|
||||
private double factor;
|
||||
|
||||
ElliRotated() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
ElliRotated(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
x = B.Rotate(x);
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Elli implements MultivariateFunction {
|
||||
|
||||
private double factor;
|
||||
|
||||
Elli() {
|
||||
this(1e3);
|
||||
}
|
||||
|
||||
Elli(double axisratio) {
|
||||
factor = axisratio * axisratio;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += Math.pow(factor, i / (x.length - 1.)) * x[i] * x[i];
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class MinusElli implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
return 1.0-(new Elli().value(x));
|
||||
}
|
||||
}
|
||||
|
||||
private static class DiffPow implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length; ++i)
|
||||
f += Math.pow(Math.abs(x[i]), 2. + 10 * (double) i
|
||||
/ (x.length - 1.));
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class SsDiffPow implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = Math.pow(new DiffPow().value(x), 0.25);
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Rosen implements MultivariateFunction {
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
for (int i = 0; i < x.length - 1; ++i)
|
||||
f += 1e2 * (x[i] * x[i] - x[i + 1]) * (x[i] * x[i] - x[i + 1])
|
||||
+ (x[i] - 1.) * (x[i] - 1.);
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Ackley implements MultivariateFunction {
|
||||
private double axisratio;
|
||||
|
||||
Ackley(double axra) {
|
||||
axisratio = axra;
|
||||
}
|
||||
|
||||
public Ackley() {
|
||||
this(1);
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
double res2 = 0;
|
||||
double fac = 0;
|
||||
for (int i = 0; i < x.length; ++i) {
|
||||
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
|
||||
f += fac * fac * x[i] * x[i];
|
||||
res2 += Math.cos(2. * Math.PI * fac * x[i]);
|
||||
}
|
||||
f = (20. - 20. * Math.exp(-0.2 * Math.sqrt(f / x.length))
|
||||
+ Math.exp(1.) - Math.exp(res2 / x.length));
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Rastrigin implements MultivariateFunction {
|
||||
|
||||
private double axisratio;
|
||||
private double amplitude;
|
||||
|
||||
Rastrigin() {
|
||||
this(1, 10);
|
||||
}
|
||||
|
||||
Rastrigin(double axisratio, double amplitude) {
|
||||
this.axisratio = axisratio;
|
||||
this.amplitude = amplitude;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
double f = 0;
|
||||
double fac;
|
||||
for (int i = 0; i < x.length; ++i) {
|
||||
fac = Math.pow(axisratio, (i - 1.) / (x.length - 1.));
|
||||
if (i == 0 && x[i] < 0)
|
||||
fac *= 1.;
|
||||
f += fac * fac * x[i] * x[i] + amplitude
|
||||
* (1. - Math.cos(2. * Math.PI * fac * x[i]));
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Basis {
|
||||
double[][] basis;
|
||||
Random rand = new Random(2); // use not always the same basis
|
||||
|
||||
double[] Rotate(double[] x) {
|
||||
GenBasis(x.length);
|
||||
double[] y = new double[x.length];
|
||||
for (int i = 0; i < x.length; ++i) {
|
||||
y[i] = 0;
|
||||
for (int j = 0; j < x.length; ++j)
|
||||
y[i] += basis[i][j] * x[j];
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
void GenBasis(int DIM) {
|
||||
if (basis != null ? basis.length == DIM : false)
|
||||
return;
|
||||
|
||||
double sp;
|
||||
int i, j, k;
|
||||
|
||||
/* generate orthogonal basis */
|
||||
basis = new double[DIM][DIM];
|
||||
for (i = 0; i < DIM; ++i) {
|
||||
/* sample components gaussian */
|
||||
for (j = 0; j < DIM; ++j)
|
||||
basis[i][j] = rand.nextGaussian();
|
||||
/* substract projection of previous vectors */
|
||||
for (j = i - 1; j >= 0; --j) {
|
||||
for (sp = 0., k = 0; k < DIM; ++k)
|
||||
sp += basis[i][k] * basis[j][k]; /* scalar product */
|
||||
for (k = 0; k < DIM; ++k)
|
||||
basis[i][k] -= sp * basis[j][k]; /* substract */
|
||||
}
|
||||
/* normalize */
|
||||
for (sp = 0., k = 0; k < DIM; ++k)
|
||||
sp += basis[i][k] * basis[i][k]; /* squared norm */
|
||||
for (k = 0; k < DIM; ++k)
|
||||
basis[i][k] /= Math.sqrt(sp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,251 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.analysis.SumSincFunction;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Test for {@link PowellOptimizer}.
|
||||
*/
|
||||
public class PowellOptimizerTest {
|
||||
|
||||
@Test
|
||||
public void testSumSinc() {
|
||||
final MultivariateFunction func = new SumSincFunction(-1);
|
||||
|
||||
int dim = 2;
|
||||
final double[] minPoint = new double[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
minPoint[i] = 0;
|
||||
}
|
||||
|
||||
double[] init = new double[dim];
|
||||
|
||||
// Initial is minimum.
|
||||
for (int i = 0; i < dim; i++) {
|
||||
init[i] = minPoint[i];
|
||||
}
|
||||
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-9);
|
||||
|
||||
// Initial is far from minimum.
|
||||
for (int i = 0; i < dim; i++) {
|
||||
init[i] = minPoint[i] + 3;
|
||||
}
|
||||
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-5);
|
||||
// More stringent line search tolerance enhances the precision
|
||||
// of the result.
|
||||
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-9, 1e-7);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuadratic() {
|
||||
final MultivariateFunction func = new MultivariateFunction() {
|
||||
public double value(double[] x) {
|
||||
final double a = x[0] - 1;
|
||||
final double b = x[1] - 1;
|
||||
return a * a + b * b + 1;
|
||||
}
|
||||
};
|
||||
|
||||
int dim = 2;
|
||||
final double[] minPoint = new double[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
minPoint[i] = 1;
|
||||
}
|
||||
|
||||
double[] init = new double[dim];
|
||||
|
||||
// Initial is minimum.
|
||||
for (int i = 0; i < dim; i++) {
|
||||
init[i] = minPoint[i];
|
||||
}
|
||||
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-8);
|
||||
|
||||
// Initial is far from minimum.
|
||||
for (int i = 0; i < dim; i++) {
|
||||
init[i] = minPoint[i] - 20;
|
||||
}
|
||||
doTest(func, minPoint, init, GoalType.MINIMIZE, 1e-9, 1e-8);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaximizeQuadratic() {
|
||||
final MultivariateFunction func = new MultivariateFunction() {
|
||||
public double value(double[] x) {
|
||||
final double a = x[0] - 1;
|
||||
final double b = x[1] - 1;
|
||||
return -a * a - b * b + 1;
|
||||
}
|
||||
};
|
||||
|
||||
int dim = 2;
|
||||
final double[] maxPoint = new double[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
maxPoint[i] = 1;
|
||||
}
|
||||
|
||||
double[] init = new double[dim];
|
||||
|
||||
// Initial is minimum.
|
||||
for (int i = 0; i < dim; i++) {
|
||||
init[i] = maxPoint[i];
|
||||
}
|
||||
doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-9, 1e-8);
|
||||
|
||||
// Initial is far from minimum.
|
||||
for (int i = 0; i < dim; i++) {
|
||||
init[i] = maxPoint[i] - 20;
|
||||
}
|
||||
doTest(func, maxPoint, init, GoalType.MAXIMIZE, 1e-9, 1e-8);
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure that we do not increase the number of function evaluations when
|
||||
* the function values are scaled up.
|
||||
* Note that the tolerances parameters passed to the constructor must
|
||||
* still hold sensible values because they are used to set the line search
|
||||
* tolerances.
|
||||
*/
|
||||
@Test
|
||||
public void testRelativeToleranceOnScaledValues() {
|
||||
final MultivariateFunction func = new MultivariateFunction() {
|
||||
public double value(double[] x) {
|
||||
final double a = x[0] - 1;
|
||||
final double b = x[1] - 1;
|
||||
return a * a * FastMath.sqrt(FastMath.abs(a)) + b * b + 1;
|
||||
}
|
||||
};
|
||||
|
||||
int dim = 2;
|
||||
final double[] minPoint = new double[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
minPoint[i] = 1;
|
||||
}
|
||||
|
||||
double[] init = new double[dim];
|
||||
// Initial is far from minimum.
|
||||
for (int i = 0; i < dim; i++) {
|
||||
init[i] = minPoint[i] - 20;
|
||||
}
|
||||
|
||||
final double relTol = 1e-10;
|
||||
|
||||
final int maxEval = 1000;
|
||||
// Very small absolute tolerance to rely solely on the relative
|
||||
// tolerance as a stopping criterion
|
||||
final PowellOptimizer optim = new PowellOptimizer(relTol, 1e-100);
|
||||
|
||||
final PointValuePair funcResult = optim.optimize(new MaxEval(maxEval),
|
||||
new ObjectiveFunction(func),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(init));
|
||||
final double funcValue = func.value(funcResult.getPoint());
|
||||
final int funcEvaluations = optim.getEvaluations();
|
||||
|
||||
final double scale = 1e10;
|
||||
final MultivariateFunction funcScaled = new MultivariateFunction() {
|
||||
public double value(double[] x) {
|
||||
return scale * func.value(x);
|
||||
}
|
||||
};
|
||||
|
||||
final PointValuePair funcScaledResult = optim.optimize(new MaxEval(maxEval),
|
||||
new ObjectiveFunction(funcScaled),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(init));
|
||||
final double funcScaledValue = funcScaled.value(funcScaledResult.getPoint());
|
||||
final int funcScaledEvaluations = optim.getEvaluations();
|
||||
|
||||
// Check that both minima provide the same objective funciton values,
|
||||
// within the relative function tolerance.
|
||||
Assert.assertEquals(1, funcScaledValue / (scale * funcValue), relTol);
|
||||
|
||||
// Check that the numbers of evaluations are the same.
|
||||
Assert.assertEquals(funcEvaluations, funcScaledEvaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param func Function to optimize.
|
||||
* @param optimum Expected optimum.
|
||||
* @param init Starting point.
|
||||
* @param goal Minimization or maximization.
|
||||
* @param fTol Tolerance (relative error on the objective function) for
|
||||
* "Powell" algorithm.
|
||||
* @param pointTol Tolerance for checking that the optimum is correct.
|
||||
*/
|
||||
private void doTest(MultivariateFunction func,
|
||||
double[] optimum,
|
||||
double[] init,
|
||||
GoalType goal,
|
||||
double fTol,
|
||||
double pointTol) {
|
||||
final PowellOptimizer optim = new PowellOptimizer(fTol, Math.ulp(1d));
|
||||
|
||||
final PointValuePair result = optim.optimize(new MaxEval(1000),
|
||||
new ObjectiveFunction(func),
|
||||
goal,
|
||||
new InitialGuess(init));
|
||||
final double[] point = result.getPoint();
|
||||
|
||||
for (int i = 0, dim = optimum.length; i < dim; i++) {
|
||||
Assert.assertEquals("found[" + i + "]=" + point[i] + " value=" + result.getValue(),
|
||||
optimum[i], point[i], pointTol);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param func Function to optimize.
|
||||
* @param optimum Expected optimum.
|
||||
* @param init Starting point.
|
||||
* @param goal Minimization or maximization.
|
||||
* @param fTol Tolerance (relative error on the objective function) for
|
||||
* "Powell" algorithm.
|
||||
* @param fLineTol Tolerance (relative error on the objective function)
|
||||
* for the internal line search algorithm.
|
||||
* @param pointTol Tolerance for checking that the optimum is correct.
|
||||
*/
|
||||
private void doTest(MultivariateFunction func,
|
||||
double[] optimum,
|
||||
double[] init,
|
||||
GoalType goal,
|
||||
double fTol,
|
||||
double fLineTol,
|
||||
double pointTol) {
|
||||
final PowellOptimizer optim = new PowellOptimizer(fTol, Math.ulp(1d),
|
||||
fLineTol, Math.ulp(1d));
|
||||
|
||||
final PointValuePair result = optim.optimize(new MaxEval(1000),
|
||||
new ObjectiveFunction(func),
|
||||
goal,
|
||||
new InitialGuess(init));
|
||||
final double[] point = result.getPoint();
|
||||
|
||||
for (int i = 0, dim = optimum.length; i < dim; i++) {
|
||||
Assert.assertEquals("found[" + i + "]=" + point[i] + " value=" + result.getValue(),
|
||||
optimum[i], point[i], pointTol);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,228 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.SimpleValueChecker;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class SimplexOptimizerMultiDirectionalTest {
|
||||
@Test
|
||||
public void testMinimize1() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -3, 0 }),
|
||||
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 4e-6);
|
||||
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 3e-6);
|
||||
Assert.assertEquals(fourExtrema.valueXmYp, optimum.getValue(), 8e-13);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 120);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 150);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMinimize2() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 1, 0 }),
|
||||
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 2e-8);
|
||||
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 3e-6);
|
||||
Assert.assertEquals(fourExtrema.valueXpYm, optimum.getValue(), 2e-12);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 120);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 150);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaximize1() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-11, 1e-30);
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MAXIMIZE,
|
||||
new InitialGuess(new double[] { -3.0, 0.0 }),
|
||||
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 7e-7);
|
||||
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 3e-7);
|
||||
Assert.assertEquals(fourExtrema.valueXmYm, optimum.getValue(), 2e-14);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 120);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 150);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaximize2() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(new SimpleValueChecker(1e-15, 1e-30));
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MAXIMIZE,
|
||||
new InitialGuess(new double[] { 1, 0 }),
|
||||
new MultiDirectionalSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 2e-8);
|
||||
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 3e-6);
|
||||
Assert.assertEquals(fourExtrema.valueXpYp, optimum.getValue(), 2e-12);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 180);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 220);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRosenbrock() {
|
||||
MultivariateFunction rosenbrock
|
||||
= new MultivariateFunction() {
|
||||
public double value(double[] x) {
|
||||
++count;
|
||||
double a = x[1] - x[0] * x[0];
|
||||
double b = 1.0 - x[0];
|
||||
return 100 * a * a + b * b;
|
||||
}
|
||||
};
|
||||
|
||||
count = 0;
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
new ObjectiveFunction(rosenbrock),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -1.2, 1 }),
|
||||
new MultiDirectionalSimplex(new double[][] {
|
||||
{ -1.2, 1.0 },
|
||||
{ 0.9, 1.2 },
|
||||
{ 3.5, -2.3 } }));
|
||||
|
||||
Assert.assertEquals(count, optimizer.getEvaluations());
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 50);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 100);
|
||||
Assert.assertTrue(optimum.getValue() > 1e-2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPowell() {
|
||||
MultivariateFunction powell
|
||||
= new MultivariateFunction() {
|
||||
public double value(double[] x) {
|
||||
++count;
|
||||
double a = x[0] + 10 * x[1];
|
||||
double b = x[2] - x[3];
|
||||
double c = x[1] - 2 * x[2];
|
||||
double d = x[0] - x[3];
|
||||
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
|
||||
}
|
||||
};
|
||||
|
||||
count = 0;
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(1000),
|
||||
new ObjectiveFunction(powell),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 3, -1, 0, 1 }),
|
||||
new MultiDirectionalSimplex(4));
|
||||
Assert.assertEquals(count, optimizer.getEvaluations());
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 800);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 900);
|
||||
Assert.assertTrue(optimum.getValue() > 1e-2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMath283() {
|
||||
// fails because MultiDirectional.iterateSimplex is looping forever
|
||||
// the while(true) should be replaced with a convergence check
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-14, 1e-14);
|
||||
final Gaussian2D function = new Gaussian2D(0, 0, 1);
|
||||
PointValuePair estimate = optimizer.optimize(new MaxEval(1000),
|
||||
new ObjectiveFunction(function),
|
||||
GoalType.MAXIMIZE,
|
||||
new InitialGuess(function.getMaximumPosition()),
|
||||
new MultiDirectionalSimplex(2));
|
||||
final double EPSILON = 1e-5;
|
||||
final double expectedMaximum = function.getMaximum();
|
||||
final double actualMaximum = estimate.getValue();
|
||||
Assert.assertEquals(expectedMaximum, actualMaximum, EPSILON);
|
||||
|
||||
final double[] expectedPosition = function.getMaximumPosition();
|
||||
final double[] actualPosition = estimate.getPoint();
|
||||
Assert.assertEquals(expectedPosition[0], actualPosition[0], EPSILON );
|
||||
Assert.assertEquals(expectedPosition[1], actualPosition[1], EPSILON );
|
||||
}
|
||||
|
||||
private static class FourExtrema implements MultivariateFunction {
|
||||
// The following function has 4 local extrema.
|
||||
final double xM = -3.841947088256863675365;
|
||||
final double yM = -1.391745200270734924416;
|
||||
final double xP = 0.2286682237349059125691;
|
||||
final double yP = -yM;
|
||||
final double valueXmYm = 0.2373295333134216789769; // Local maximum.
|
||||
final double valueXmYp = -valueXmYm; // Local minimum.
|
||||
final double valueXpYm = -0.7290400707055187115322; // Global minimum.
|
||||
final double valueXpYp = -valueXpYm; // Global maximum.
|
||||
|
||||
public double value(double[] variables) {
|
||||
final double x = variables[0];
|
||||
final double y = variables[1];
|
||||
return (x == 0 || y == 0) ? 0 :
|
||||
FastMath.atan(x) * FastMath.atan(x + 2) * FastMath.atan(y) * FastMath.atan(y) / (x * y);
|
||||
}
|
||||
}
|
||||
|
||||
private static class Gaussian2D implements MultivariateFunction {
|
||||
private final double[] maximumPosition;
|
||||
private final double std;
|
||||
|
||||
public Gaussian2D(double xOpt, double yOpt, double std) {
|
||||
maximumPosition = new double[] { xOpt, yOpt };
|
||||
this.std = std;
|
||||
}
|
||||
|
||||
public double getMaximum() {
|
||||
return value(maximumPosition);
|
||||
}
|
||||
|
||||
public double[] getMaximumPosition() {
|
||||
return maximumPosition.clone();
|
||||
}
|
||||
|
||||
public double value(double[] point) {
|
||||
final double x = point[0], y = point[1];
|
||||
final double twoS2 = 2.0 * std * std;
|
||||
return 1.0 / (twoS2 * FastMath.PI) * FastMath.exp(-(x * x + y * y) / twoS2);
|
||||
}
|
||||
}
|
||||
|
||||
private int count;
|
||||
}
|
|
@ -0,0 +1,295 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
|
||||
|
||||
|
||||
import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.optim.GoalType;
|
||||
import org.apache.commons.math3.optim.InitialGuess;
|
||||
import org.apache.commons.math3.optim.MaxEval;
|
||||
import org.apache.commons.math3.optim.ObjectiveFunction;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.LeastSquaresConverter;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class SimplexOptimizerNelderMeadTest {
|
||||
@Test
|
||||
public void testMinimize1() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -3, 0 }),
|
||||
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 2e-7);
|
||||
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 2e-5);
|
||||
Assert.assertEquals(fourExtrema.valueXmYp, optimum.getValue(), 6e-12);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 60);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 90);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMinimize2() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 1, 0 }),
|
||||
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 5e-6);
|
||||
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 6e-6);
|
||||
Assert.assertEquals(fourExtrema.valueXpYm, optimum.getValue(), 1e-11);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 60);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 90);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaximize1() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MAXIMIZE,
|
||||
new InitialGuess(new double[] { -3, 0 }),
|
||||
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xM, optimum.getPoint()[0], 1e-5);
|
||||
Assert.assertEquals(fourExtrema.yM, optimum.getPoint()[1], 3e-6);
|
||||
Assert.assertEquals(fourExtrema.valueXmYm, optimum.getValue(), 3e-12);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 60);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 90);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaximize2() {
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
|
||||
final FourExtrema fourExtrema = new FourExtrema();
|
||||
|
||||
final PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
new ObjectiveFunction(fourExtrema),
|
||||
GoalType.MAXIMIZE,
|
||||
new InitialGuess(new double[] { 1, 0 }),
|
||||
new NelderMeadSimplex(new double[] { 0.2, 0.2 }));
|
||||
Assert.assertEquals(fourExtrema.xP, optimum.getPoint()[0], 4e-6);
|
||||
Assert.assertEquals(fourExtrema.yP, optimum.getPoint()[1], 5e-6);
|
||||
Assert.assertEquals(fourExtrema.valueXpYp, optimum.getValue(), 7e-12);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 60);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 90);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRosenbrock() {
|
||||
|
||||
Rosenbrock rosenbrock = new Rosenbrock();
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(100),
|
||||
new ObjectiveFunction(rosenbrock),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { -1.2, 1 }),
|
||||
new NelderMeadSimplex(new double[][] {
|
||||
{ -1.2, 1 },
|
||||
{ 0.9, 1.2 },
|
||||
{ 3.5, -2.3 } }));
|
||||
|
||||
Assert.assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 40);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 50);
|
||||
Assert.assertTrue(optimum.getValue() < 8e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPowell() {
|
||||
Powell powell = new Powell();
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
|
||||
PointValuePair optimum =
|
||||
optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(powell),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 3, -1, 0, 1 }),
|
||||
new NelderMeadSimplex(4));
|
||||
Assert.assertEquals(powell.getCount(), optimizer.getEvaluations());
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 110);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 130);
|
||||
Assert.assertTrue(optimum.getValue() < 2e-3);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeastSquares1() {
|
||||
final RealMatrix factors
|
||||
= new Array2DRowRealMatrix(new double[][] {
|
||||
{ 1, 0 },
|
||||
{ 0, 1 }
|
||||
}, false);
|
||||
LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorFunction() {
|
||||
public double[] value(double[] variables) {
|
||||
return factors.operate(variables);
|
||||
}
|
||||
}, new double[] { 2.0, -3.0 });
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6);
|
||||
PointValuePair optimum =
|
||||
optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(ls),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 10, 10 }),
|
||||
new NelderMeadSimplex(2));
|
||||
Assert.assertEquals( 2, optimum.getPointRef()[0], 3e-5);
|
||||
Assert.assertEquals(-3, optimum.getPointRef()[1], 4e-4);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 60);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 80);
|
||||
Assert.assertTrue(optimum.getValue() < 1.0e-6);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeastSquares2() {
|
||||
final RealMatrix factors
|
||||
= new Array2DRowRealMatrix(new double[][] {
|
||||
{ 1, 0 },
|
||||
{ 0, 1 }
|
||||
}, false);
|
||||
LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorFunction() {
|
||||
public double[] value(double[] variables) {
|
||||
return factors.operate(variables);
|
||||
}
|
||||
}, new double[] { 2, -3 }, new double[] { 10, 0.1 });
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6);
|
||||
PointValuePair optimum =
|
||||
optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(ls),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 10, 10 }),
|
||||
new NelderMeadSimplex(2));
|
||||
Assert.assertEquals( 2, optimum.getPointRef()[0], 5e-5);
|
||||
Assert.assertEquals(-3, optimum.getPointRef()[1], 8e-4);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 60);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 80);
|
||||
Assert.assertTrue(optimum.getValue() < 1e-6);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeastSquares3() {
|
||||
final RealMatrix factors =
|
||||
new Array2DRowRealMatrix(new double[][] {
|
||||
{ 1, 0 },
|
||||
{ 0, 1 }
|
||||
}, false);
|
||||
LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorFunction() {
|
||||
public double[] value(double[] variables) {
|
||||
return factors.operate(variables);
|
||||
}
|
||||
}, new double[] { 2, -3 }, new Array2DRowRealMatrix(new double [][] {
|
||||
{ 1, 1.2 }, { 1.2, 2 }
|
||||
}));
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-6);
|
||||
PointValuePair optimum
|
||||
= optimizer.optimize(new MaxEval(200),
|
||||
new ObjectiveFunction(ls),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 10, 10 }),
|
||||
new NelderMeadSimplex(2));
|
||||
Assert.assertEquals( 2, optimum.getPointRef()[0], 2e-3);
|
||||
Assert.assertEquals(-3, optimum.getPointRef()[1], 8e-4);
|
||||
Assert.assertTrue(optimizer.getEvaluations() > 60);
|
||||
Assert.assertTrue(optimizer.getEvaluations() < 80);
|
||||
Assert.assertTrue(optimum.getValue() < 1e-6);
|
||||
}
|
||||
|
||||
@Test(expected=TooManyEvaluationsException.class)
|
||||
public void testMaxIterations() {
|
||||
Powell powell = new Powell();
|
||||
SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
|
||||
optimizer.optimize(new MaxEval(20),
|
||||
new ObjectiveFunction(powell),
|
||||
GoalType.MINIMIZE,
|
||||
new InitialGuess(new double[] { 3, -1, 0, 1 }),
|
||||
new NelderMeadSimplex(4));
|
||||
}
|
||||
|
||||
private static class FourExtrema implements MultivariateFunction {
|
||||
// The following function has 4 local extrema.
|
||||
final double xM = -3.841947088256863675365;
|
||||
final double yM = -1.391745200270734924416;
|
||||
final double xP = 0.2286682237349059125691;
|
||||
final double yP = -yM;
|
||||
final double valueXmYm = 0.2373295333134216789769; // Local maximum.
|
||||
final double valueXmYp = -valueXmYm; // Local minimum.
|
||||
final double valueXpYm = -0.7290400707055187115322; // Global minimum.
|
||||
final double valueXpYp = -valueXpYm; // Global maximum.
|
||||
|
||||
public double value(double[] variables) {
|
||||
final double x = variables[0];
|
||||
final double y = variables[1];
|
||||
return (x == 0 || y == 0) ? 0 :
|
||||
FastMath.atan(x) * FastMath.atan(x + 2) * FastMath.atan(y) * FastMath.atan(y) / (x * y);
|
||||
}
|
||||
}
|
||||
|
||||
private static class Rosenbrock implements MultivariateFunction {
|
||||
private int count;
|
||||
|
||||
public Rosenbrock() {
|
||||
count = 0;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
++count;
|
||||
double a = x[1] - x[0] * x[0];
|
||||
double b = 1.0 - x[0];
|
||||
return 100 * a * a + b * b;
|
||||
}
|
||||
|
||||
public int getCount() {
|
||||
return count;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Powell implements MultivariateFunction {
|
||||
private int count;
|
||||
|
||||
public Powell() {
|
||||
count = 0;
|
||||
}
|
||||
|
||||
public double value(double[] x) {
|
||||
++count;
|
||||
double a = x[0] + 10 * x[1];
|
||||
double b = x[2] - x[3];
|
||||
double c = x[1] - 2 * x[2];
|
||||
double d = x[0] - x[3];
|
||||
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
|
||||
}
|
||||
|
||||
public int getCount() {
|
||||
return count;
|
||||
}
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue