started refactoring of optimization framework:

- created subpackages optimization.direct, optimization.general,
   optimization.linear (currently empty) and optimization.univariate
 - removed packages analysis.minimization and estimation
 - renamed all Cost-related interfaces/classes into Objective
   (this allows both minimization and maximization)
 - added a few new general interfaces

This work is not complete yet. The direct and general packages classes
are very close to the former design, they have almost not been changed
structurally.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@748274 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2009-02-26 19:17:39 +00:00
parent 1cf41555f1
commit 722fc97a7a
47 changed files with 1676 additions and 1128 deletions

View File

@ -26,7 +26,7 @@
<!-- the following equality tests are part of the reference algorithms -->
<!-- which already know about limited precision of the double numbers -->
<Match>
<Class name="org.apache.commons.math.analysis.minimization.BrentMinimizer" />
<Class name="org.apache.commons.math.optimization.univariate.BrentMinimizer" />
<Method name="localMin" params="double,double,double,double,org.apache.commons.math.analysis.UnivariateRealFunction" returns="double" />
<Bug pattern="FE_FLOATING_POINT_EQUALITY" />
</Match>

View File

@ -17,7 +17,7 @@
-->
<!-- $Revision$ $Date$ -->
<body>
Parent package for common numerical analysis procedures, including root finding, minimization,
Parent package for common numerical analysis procedures, including root finding,
function interpolation and integration.
</body>
</html>

View File

@ -1,55 +0,0 @@
<html>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<!-- $Revision$ -->
<body>
This package provides classes to solve estimation problems.
<p>The estimation problems considered here are parametric problems where a user model
depends on initially unknown scalar parameters and several measurements made on
values that depend on the model are available. As an example, one can consider the
flow rate of a river given rain data on its vicinity, or the center and radius of a
circle given points on a ring.</p>
<p>One important class of estimation problems is weighted least squares problems.
They basically consist in finding the values for some parameters p<sub>k</sub> such
that a cost function J = sum(w<sub>i</sub> r<sub>i</sub><sup>2</sup>) is minimized.
The various r<sub>i</sub> terms represent the deviation r<sub>i</sub> =
mes<sub>i</sub> - mod<sub>i</sub> between the measurements and the parameterized
models. The w<sub>i</sub> factors are the measurements weights, they are often chosen
either all equal to 1.0 or proportional to the inverse of the variance of the
measurement type. The solver adjusts the values of the estimated parameters
p<sub>k</sub> which are not bound. It does not touch the parameters which have been
put in a bound state by the user.</p>
<p>This package provides the {@link
org.apache.commons.math.estimation.EstimatedParameter EstimatedParameter} class to
represent each estimated parameter, and the {@link
org.apache.commons.math.estimation.WeightedMeasurement WeightedMeasurement} abstract
class the user can extend to define its measurements. All parameters and measurements
are then provided to some {@link org.apache.commons.math.estimation.Estimator
Estimator} packed together in an {@link
org.apache.commons.math.estimation.EstimationProblem EstimationProblem} instance
which acts only as a container. The package provides two common estimators for
weighted least squares problems, one based on the {@link
org.apache.commons.math.estimation.GaussNewtonEstimator Gauss-Newton} method and the
other one based on the {@link
org.apache.commons.math.estimation.LevenbergMarquardtEstimator Levenberg-Marquardt}
method.</p>
</body>
</html>

View File

@ -17,6 +17,9 @@
package org.apache.commons.math.optimization;
import org.apache.commons.math.optimization.direct.DirectSearchOptimizer;
/** This interface specifies how to check if a {@link
* DirectSearchOptimizer direct search method} has converged.
*
@ -32,10 +35,16 @@ package org.apache.commons.math.optimization;
public interface ConvergenceChecker {
/** Check if the optimization algorithm has converged on the simplex.
* @param simplex ordered simplex (all points in the simplex have
* been eavluated and are sorted from lowest to largest cost)
* <p>
* When this method is called, all points in the simplex have been evaluated
* and are sorted from lowest to largest value. The values are either the
* original objective function values if the optimizer was configured for
* minimization, or the opposites of the original objective function values
* if the optimizer was configured for maximization.
* </p>
* @param simplex ordered simplex
* @return true if the algorithm is considered to have converged
*/
public boolean converged (PointCostPair[] simplex);
boolean converged(PointValuePair[] simplex);
}

View File

@ -17,21 +17,19 @@
package org.apache.commons.math.optimization;
import java.io.Serializable;
/**
* This interface represents a cost function to be minimized.
* Goal type for an optimization problem.
* @version $Revision$ $Date$
* @since 1.2
* @since 2.0
*/
public interface CostFunction {
public enum GoalType implements Serializable {
/** Maximization goal. */
MAXIMIZE,
/**
* Compute the cost associated to the given parameters array.
* @param x parameters array
* @return cost associated to the parameters array
* @exception CostException if no cost can be computed for the parameters
* @see PointCostPair
*/
public double cost(double[] x) throws CostException;
/** Minimization goal. */
MINIMIZE
}

View File

@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
import org.apache.commons.math.linear.RealMatrix;
/** This class converts {@link MultiObjectiveFunction vectorial
* objective functions} to {@link ObjectiveFunction scalar objective functions}
* when the goal is to minimize them.
* <p>
* This class is mostly used when the vectorial objective function represents
* residuals, i.e. differences between a theoretical result computed from a
* variables set applied to a model and a reference. Residuals are intended to be
* minimized in order to get the variables set that best fit the model to the
* reference. The reference may be obtained for example from physical measurements
* whether the model is built from theoretical considerations.
* </p>
* <p>
* This class computes a possibly weighted squared sum of the residuals, which is
* a scalar value. It implements the {@link ObjectiveFunction} interface and can
* therefore be minimized by any optimizer supporting scalar objectives functions.
* This correspond to a least square estimation.
* </p>
* <p>
* This class support combination of residuals with or without weights and correlations.
* </p>
*
* @see ObjectiveFunction
* @see MultiObjectiveFunction
* @version $Revision$ $Date$
* @since 2.0
*/
public class LeastSquaresConverter implements ObjectiveFunction {
/** Serializable version identifier. */
private static final long serialVersionUID = -5174886571116126798L;
/** Underlying vectorial function. */
private final MultiObjectiveFunction function;
/** Optional weights for the residuals. */
private final double[] weights;
/** Optional scaling matrix (weight and correlations) for the residuals. */
private final RealMatrix scale;
/** Build a simple converter for uncorrelated residuals with the same weight.
* @param function vectorial residuals function to wrap
*/
public LeastSquaresConverter (final MultiObjectiveFunction function) {
this.function = function;
this.weights = null;
this.scale = null;
}
/** Build a simple converter for uncorrelated residuals with the specific weights.
* <p>
* The scalar objective function value is computed as:
* <pre>
* objective = &sum;(weight<sub>i</sub>residual<sub>i</sub>)<sup>2</sup>
* </pre>
* </p>
* <p>
* Weights can be used for example to combine residuals with different standard
* deviations. As an example, consider a 2000 elements residuals array in which
* even elements are angular measurements in degrees with a 0.01&deg; standard
* deviation and off elements are distance measurements in meters with a 15m
* standard deviation. In this case, the weights array should be initialized with
* value 1.0/0.01 in the even elements and 1.0/15.0 in the odd elements.
* </p>
* <p>
* The residuals array computed by the function and the weights array must
* have consistent sizes or a {@link ObjectiveException} will be triggered while
* computing the scalar objective.
* </p>
* @param function vectorial residuals function to wrap
* @param weights weights to apply to the residuals
*/
public LeastSquaresConverter (final MultiObjectiveFunction function,
final double[] weights) {
this.function = function;
this.weights = weights.clone();
this.scale = null;
}
/** Build a simple convertor for correlated residuals with the specific weights.
* <p>
* The scalar objective function value is computed as:
* <pre>
* objective = &sum;(y<sub>i</sub>)<sup>2</sup> with y = scale&times;residual
* </pre>
* </p>
* <p>
* The residuals array computed by the function and the scaling matrix must
* have consistent sizes or a {@link ObjectiveException} will be triggered while
* computing the scalar objective.
* </p>
* @param function vectorial residuals function to wrap
* @param scale scaling matrix (
*/
public LeastSquaresConverter (final MultiObjectiveFunction function,
final RealMatrix scale) {
this.function = function;
this.weights = null;
this.scale = scale.copy();
}
/** {@inheritDoc} */
public double objective(final double[] variables) throws ObjectiveException {
final double[] residuals = function.objective(variables);
double sumSquares = 0;
if (weights != null) {
if (weights.length != residuals.length) {
throw new ObjectiveException("dimension mismatch {0} != {1}",
weights.length, residuals.length);
}
for (int i = 0; i < weights.length; ++i) {
final double ai = residuals[i] * weights[i];
sumSquares += ai * ai;
}
} else if (scale != null) {
if (scale.getColumnDimension() != residuals.length) {
throw new ObjectiveException("dimension mismatch {0} != {1}",
scale.getColumnDimension(), residuals.length);
}
for (final double yi : scale.operate(residuals)) {
sumSquares += yi * yi;
}
} else {
for (final double ri : residuals) {
sumSquares += ri * ri;
}
}
return sumSquares;
}
}

View File

@ -1,126 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
/**
* This class implements the multi-directional direct search method.
*
* @version $Revision$ $Date$
* @see NelderMead
* @since 1.2
*/
public class MultiDirectional
extends DirectSearchOptimizer {
/** Build a multi-directional optimizer with default coefficients.
* <p>The default values are 2.0 for khi and 0.5 for gamma.</p>
*/
public MultiDirectional() {
super();
this.khi = 2.0;
this.gamma = 0.5;
}
/** Build a multi-directional optimizer with specified coefficients.
* @param khi expansion coefficient
* @param gamma contraction coefficient
*/
public MultiDirectional(double khi, double gamma) {
super();
this.khi = khi;
this.gamma = gamma;
}
/** Compute the next simplex of the algorithm.
* @exception CostException if the function cannot be evaluated at
* some point
*/
protected void iterateSimplex()
throws CostException {
while (true) {
// save the original vertex
PointCostPair[] original = simplex;
double originalCost = original[0].getCost();
// perform a reflection step
double reflectedCost = evaluateNewSimplex(original, 1.0);
if (reflectedCost < originalCost) {
// compute the expanded simplex
PointCostPair[] reflected = simplex;
double expandedCost = evaluateNewSimplex(original, khi);
if (reflectedCost <= expandedCost) {
// accept the reflected simplex
simplex = reflected;
}
return;
}
// compute the contracted simplex
double contractedCost = evaluateNewSimplex(original, gamma);
if (contractedCost < originalCost) {
// accept the contracted simplex
return;
}
}
}
/** Compute and evaluate a new simplex.
* @param original original simplex (to be preserved)
* @param coeff linear coefficient
* @return smallest cost in the transformed simplex
* @exception CostException if the function cannot be evaluated at
* some point
*/
private double evaluateNewSimplex(PointCostPair[] original, double coeff)
throws CostException {
double[] xSmallest = original[0].getPoint();
int n = xSmallest.length;
// create the linearly transformed simplex
simplex = new PointCostPair[n + 1];
simplex[0] = original[0];
for (int i = 1; i <= n; ++i) {
double[] xOriginal = original[i].getPoint();
double[] xTransformed = new double[n];
for (int j = 0; j < n; ++j) {
xTransformed[j] = xSmallest[j] + coeff * (xSmallest[j] - xOriginal[j]);
}
simplex[i] = new PointCostPair(xTransformed, Double.NaN);
}
// evaluate it
evaluateSimplex();
return simplex[0].getCost();
}
/** Expansion coefficient. */
private double khi;
/** Contraction coefficient. */
private double gamma;
}

View File

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
import java.io.Serializable;
/**
* This interface represents a vectorial objective function to be either minimized or maximized.
* @see LeastSquaresConverter
* @version $Revision$ $Date$
* @since 2.0
*/
public interface MultiObjectiveFunction extends Serializable {
/**
* Compute the function value for the given variables set.
* @param variables variables set
* @return function value for the given variables set
* @exception ObjectiveException if no cost can be computed for the parameters
*/
double[] objective(double[] variables) throws ObjectiveException;
}

View File

@ -1,176 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
/**
* This class implements the Nelder-Mead direct search method.
*
* @version $Revision$ $Date$
* @see MultiDirectional
* @since 1.2
*/
public class NelderMead
extends DirectSearchOptimizer {
/** Build a Nelder-Mead optimizer with default coefficients.
* <p>The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
* for both gamma and sigma.</p>
*/
public NelderMead() {
super();
this.rho = 1.0;
this.khi = 2.0;
this.gamma = 0.5;
this.sigma = 0.5;
}
/** Build a Nelder-Mead optimizer with specified coefficients.
* @param rho reflection coefficient
* @param khi expansion coefficient
* @param gamma contraction coefficient
* @param sigma shrinkage coefficient
*/
public NelderMead(double rho, double khi, double gamma, double sigma) {
super();
this.rho = rho;
this.khi = khi;
this.gamma = gamma;
this.sigma = sigma;
}
/** Compute the next simplex of the algorithm.
* @exception CostException if the function cannot be evaluated at
* some point
*/
protected void iterateSimplex()
throws CostException {
// the simplex has n+1 point if dimension is n
int n = simplex.length - 1;
// interesting costs
double smallest = simplex[0].getCost();
double secondLargest = simplex[n-1].getCost();
double largest = simplex[n].getCost();
double[] xLargest = simplex[n].getPoint();
// compute the centroid of the best vertices
// (dismissing the worst point at index n)
double[] centroid = new double[n];
for (int i = 0; i < n; ++i) {
double[] x = simplex[i].getPoint();
for (int j = 0; j < n; ++j) {
centroid[j] += x[j];
}
}
double scaling = 1.0 / n;
for (int j = 0; j < n; ++j) {
centroid[j] *= scaling;
}
// compute the reflection point
double[] xR = new double[n];
for (int j = 0; j < n; ++j) {
xR[j] = centroid[j] + rho * (centroid[j] - xLargest[j]);
}
double costR = evaluateCost(xR);
if ((smallest <= costR) && (costR < secondLargest)) {
// accept the reflected point
replaceWorstPoint(new PointCostPair(xR, costR));
} else if (costR < smallest) {
// compute the expansion point
double[] xE = new double[n];
for (int j = 0; j < n; ++j) {
xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
}
double costE = evaluateCost(xE);
if (costE < costR) {
// accept the expansion point
replaceWorstPoint(new PointCostPair(xE, costE));
} else {
// accept the reflected point
replaceWorstPoint(new PointCostPair(xR, costR));
}
} else {
if (costR < largest) {
// perform an outside contraction
double[] xC = new double[n];
for (int j = 0; j < n; ++j) {
xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
}
double costC = evaluateCost(xC);
if (costC <= costR) {
// accept the contraction point
replaceWorstPoint(new PointCostPair(xC, costC));
return;
}
} else {
// perform an inside contraction
double[] xC = new double[n];
for (int j = 0; j < n; ++j) {
xC[j] = centroid[j] - gamma * (centroid[j] - xLargest[j]);
}
double costC = evaluateCost(xC);
if (costC < largest) {
// accept the contraction point
replaceWorstPoint(new PointCostPair(xC, costC));
return;
}
}
// perform a shrink
double[] xSmallest = simplex[0].getPoint();
for (int i = 1; i < simplex.length; ++i) {
double[] x = simplex[i].getPoint();
for (int j = 0; j < n; ++j) {
x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
}
simplex[i] = new PointCostPair(x, Double.NaN);
}
evaluateSimplex();
}
}
/** Reflection coefficient. */
private double rho;
/** Expansion coefficient. */
private double khi;
/** Contraction coefficient. */
private double gamma;
/** Shrinkage coefficient. */
private double sigma;
}

View File

@ -20,37 +20,37 @@ package org.apache.commons.math.optimization;
import org.apache.commons.math.MathException;
/**
* This class represents exceptions thrown by cost functions.
* This class represents exceptions thrown by obective functions.
*
* @version $Revision$ $Date$
* @since 1.2
*/
public class CostException
public class ObjectiveException
extends MathException {
/** Serializable version identifier. */
private static final long serialVersionUID = 467695563268795689L;
private static final long serialVersionUID = 8738657724051397417L;
/**
* Constructs a new <code>MathException</code> with specified
* Constructs a new <code>ObjectiveException</code> with specified
* formatted detail message.
* Message formatting is delegated to {@link java.text.MessageFormat}.
* @param pattern format specifier
* @param arguments format arguments
*/
public CostException(String pattern, Object ... arguments) {
public ObjectiveException(String pattern, Object ... arguments) {
super(pattern, arguments);
}
/**
* Constructs a new <code>MathException</code> with specified
* Constructs a new <code>ObjectiveException</code> with specified
* nested <code>Throwable</code> root cause.
*
* @param rootCause the exception or error that caused this exception
* to be thrown.
*/
public CostException(Throwable rootCause) {
public ObjectiveException(Throwable rootCause) {
super(rootCause);
}

View File

@ -17,43 +17,21 @@
package org.apache.commons.math.optimization;
import java.io.Serializable;
/**
* This class holds a point and its associated cost.
* <p>This is a simple immutable container.</p>
* This interface represents a scalar objective function to be either minimized or maximized.
* @version $Revision$ $Date$
* @see CostFunction
* @since 1.2
*/
public class PointCostPair {
public interface ObjectiveFunction extends Serializable {
/** Build a point/cost pair.
* @param point point coordinates (the built instance will store
* a copy of the array, not the array passed as argument)
* @param cost point cost
*/
public PointCostPair(double[] point, double cost) {
this.point = (double[]) point.clone();
this.cost = cost;
}
/** Get the point.
* @return a copy of the stored point
*/
public double[] getPoint() {
return (double[]) point.clone();
}
/** Get the cost.
* @return the stored cost
*/
public double getCost() {
return cost;
}
/** Point coordinates. */
private final double[] point;
/** Cost associated to the point. */
private final double cost;
/**
* Compute the function value for the given variables set.
* @param variables variables set
* @return function value for the given variables set
* @exception ObjectiveException if no value can be computed for the parameters
*/
double objective(double[] variables) throws ObjectiveException;
}

View File

@ -15,9 +15,9 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization;
import org.apache.commons.math.MathException;
import org.apache.commons.math.ConvergenceException;
/**
* This class represents exceptions thrown by the estimation solvers.
@ -27,11 +27,10 @@ import org.apache.commons.math.MathException;
*
*/
public class EstimationException
extends MathException {
public class OptimizationException extends ConvergenceException {
/** Serializable version identifier. */
private static final long serialVersionUID = -573038581493881337L;
private static final long serialVersionUID = -781139167958631145L;
/**
* Simple constructor.
@ -39,7 +38,7 @@ extends MathException {
* @param specifier format specifier (to be translated)
* @param parts to insert in the format (no translation)
*/
public EstimationException(String specifier, Object ... parts) {
public OptimizationException(String specifier, Object ... parts) {
super(specifier, parts);
}

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
import java.io.Serializable;
/**
* This interface represents an optimization algorithm.
* @version $Revision$ $Date$
* @since 2.0
*/
public interface Optimizer extends Serializable {
/** Set the maximal number of objective function calls.
* @param maxEvaluations maximal number of function calls for each
* start (note that the number may be checked <em>after</em>
* a few related calls have been made, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem and kind of optimizer).
*/
void setMaxEvaluations(int maxEvaluations);
/** Set the convergence checker.
* @param checker object to use to check for convergence
*/
void setConvergenceChecker(ConvergenceChecker checker);
/** Optimizes an objective function.
* @param f objective function
* @param goalType type of optimization goal: either {@link GoalType#MAXIMIZE}
* or {@link GoalType#MINIMIZE}
* @return the point/value pair giving the optimal value for objective function
* @exception ObjectiveException if the objective function throws one during
* the search
* @exception OptimizationException if the algorithm failed to converge
*/
PointValuePair optimize(final ObjectiveFunction f, final GoalType goalType)
throws ObjectiveException, OptimizationException;
}

View File

@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
import java.io.Serializable;
/**
* This class holds a point and the value of an objective function at this point.
* <p>This is a simple immutable container.</p>
* @version $Revision$ $Date$
* @see ObjectiveFunction
* @since 2.0
*/
public class PointValuePair implements Serializable {
/** Serializable version identifier. */
private static final long serialVersionUID = 2254035971797977063L;
/** Point coordinates. */
private final double[] point;
/** Value of the objective function at the point. */
private final double value;
/** Build a point/objective function value pair.
* @param point point coordinates (the built instance will store
* a copy of the array, not the array passed as argument)
* @param value value of an objective function at the point
*/
public PointValuePair(final double[] point, final double value) {
this.point = point.clone();
this.value = value;
}
/** Get the point.
* @return a copy of the stored point
*/
public double[] getPoint() {
return point.clone();
}
/** Get the value of the objective function.
* @return the stored value of the objective function
*/
public double getValue() {
return value;
}
}

View File

@ -15,8 +15,9 @@
* limitations under the License.
*/
package org.apache.commons.math.optimization;
package org.apache.commons.math.optimization.direct;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
@ -25,6 +26,10 @@ import org.apache.commons.math.DimensionMismatchException;
import org.apache.commons.math.MathRuntimeException;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.decomposition.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.ObjectiveException;
import org.apache.commons.math.optimization.ObjectiveFunction;
import org.apache.commons.math.optimization.PointValuePair;
import org.apache.commons.math.random.CorrelatedRandomVectorGenerator;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.random.RandomGenerator;
@ -38,7 +43,7 @@ import org.apache.commons.math.stat.descriptive.moment.VectorialMean;
* This class implements simplex-based direct search optimization
* algorithms.
*
* <p>Direct search methods only use cost function values, they don't
* <p>Direct search methods only use objective function values, they don't
* need derivatives and don't either try to compute approximation of
* the derivatives. According to a 1996 paper by Margaret H. Wright
* (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct
@ -51,40 +56,80 @@ import org.apache.commons.math.stat.descriptive.moment.VectorialMean;
* direct search methods can be useful.</p>
*
* <p>Simplex-based direct search methods are based on comparison of
* the cost function values at the vertices of a simplex (which is a
* the objective function values at the vertices of a simplex (which is a
* set of n+1 points in dimension n) that is updated by the algorithms
* steps.</p>
*
* <p>Minimization can be attempted either in single-start or in
* <p>Optimization can be attempted either in single-start or in
* multi-start mode. Multi-start is a traditional way to try to avoid
* being trapped in a local minimum and miss the global minimum of a
* being trapped in a local optimum and miss the global optimum of a
* function. It can also be used to verify the convergence of an
* algorithm. The various multi-start-enabled <code>minimize</code>
* methods return the best minimum found after all starts, and the
* {@link #getMinima getMinima} method can be used to retrieve all
* minima from all starts (including the one already provided by the
* {@link #minimize(CostFunction, int, ConvergenceChecker, double[],
* double[]) minimize} method).</p>
* algorithm. The various multi-start-enabled <code>optimize</code>
* methods return the best optimum found after all starts, and the
* {@link #getOptimum getOptimum} method can be used to retrieve all
* optima from all starts (including the one already provided by the
* {@link #optimize(ObjectiveFunction, int, ConvergenceChecker, double[],
* double[]) optimize} method).</p>
*
* <p>This class is the base class performing the boilerplate simplex
* initialization and handling. The simplex update by itself is
* performed by the derived classes according to the implemented
* algorithms.</p>
*
* @see CostFunction
* @see ObjectiveFunction
* @see NelderMead
* @see MultiDirectional
* @version $Revision$ $Date$
* @since 1.2
*/
public abstract class DirectSearchOptimizer {
public abstract class DirectSearchOptimizer implements Serializable {
/** Serializable version identifier. */
private static final long serialVersionUID = -3913013760494455466L;
/** Comparator for {@link PointValuePair} objects. */
private static final Comparator<PointValuePair> PAIR_COMPARATOR =
new Comparator<PointValuePair>() {
public int compare(PointValuePair o1, PointValuePair o2) {
if (o1 == null) {
return (o2 == null) ? 0 : +1;
} else if (o2 == null) {
return -1;
}
return (o1.getValue() < o2.getValue()) ? -1 : ((o1 == o2) ? 0 : +1);
}
};
/** Simplex. */
protected PointValuePair[] simplex;
/** Objective function. */
private ObjectiveFunction f;
/** Indicator for minimization. */
private boolean minimizing;
/** Number of evaluations already performed for the current start. */
private int evaluations;
/** Number of evaluations already performed for all starts. */
private int totalEvaluations;
/** Number of starts to go. */
private int starts;
/** Random generator for multi-start. */
private RandomVectorGenerator generator;
/** Found optima. */
private PointValuePair[] optima;
/** Simple constructor.
*/
protected DirectSearchOptimizer() {
}
/** Minimizes a cost function.
/** Optimizes an objective function.
* <p>The initial simplex is built from two vertices that are
* considered to represent two opposite vertices of a box parallel
* to the canonical axes of the space. The simplex is the subset of
@ -93,36 +138,37 @@ public abstract class DirectSearchOptimizer {
* regular simplex using the projected separation between the given
* points as the scaling factor along each coordinate axis.</p>
* <p>The optimization is performed in single-start mode.</p>
* @param f cost function
* @param f objective function
* @param maxEvaluations maximal number of function calls for each
* start (note that the number will be checked <em>after</em>
* complete simplices have been evaluated, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem)
* @param checker object to use to check for convergence
* @param minimizing if true, function must be minimize otherwise it must be maximized
* @param vertexA first vertex
* @param vertexB last vertex
* @return the point/cost pairs giving the minimal cost
* @exception CostException if the cost function throws one during
* @return the point/value pairs giving the optimal value for objective function
* @exception ObjectiveException if the objective function throws one during
* the search
* @exception ConvergenceException if none of the starts did
* converge (it is not thrown if at least one start did converge)
*/
public PointCostPair minimize(CostFunction f, int maxEvaluations,
ConvergenceChecker checker,
double[] vertexA, double[] vertexB)
throws CostException, ConvergenceException {
public PointValuePair optimize(final ObjectiveFunction f, final int maxEvaluations,
final ConvergenceChecker checker, final boolean minimizing,
final double[] vertexA, final double[] vertexB)
throws ObjectiveException, ConvergenceException {
// set up optimizer
buildSimplex(vertexA, vertexB);
setSingleStart();
// compute minimum
return minimize(f, maxEvaluations, checker);
// compute optimum
return optimize(f, maxEvaluations, checker, minimizing);
}
/** Minimizes a cost function.
/** Optimizes an objective function.
* <p>The initial simplex is built from two vertices that are
* considered to represent two opposite vertices of a box parallel
* to the canonical axes of the space. The simplex is the subset of
@ -131,30 +177,31 @@ public abstract class DirectSearchOptimizer {
* regular simplex using the projected separation between the given
* points as the scaling factor along each coordinate axis.</p>
* <p>The optimization is performed in multi-start mode.</p>
* @param f cost function
* @param f objective function
* @param maxEvaluations maximal number of function calls for each
* start (note that the number will be checked <em>after</em>
* complete simplices have been evaluated, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem)
* @param checker object to use to check for convergence
* @param minimizing if true, function must be minimize otherwise it must be maximized
* @param vertexA first vertex
* @param vertexB last vertex
* @param starts number of starts to perform (including the
* first one), multi-start is disabled if value is less than or
* equal to 1
* @param seed seed for the random vector generator
* @return the point/cost pairs giving the minimal cost
* @exception CostException if the cost function throws one during
* @return the point/value pairs giving the optimal value for objective function
* @exception ObjectiveException if the obective function throws one during
* the search
* @exception ConvergenceException if none of the starts did
* converge (it is not thrown if at least one start did converge)
*/
public PointCostPair minimize(CostFunction f, int maxEvaluations,
ConvergenceChecker checker,
double[] vertexA, double[] vertexB,
int starts, long seed)
throws CostException, ConvergenceException {
public PointValuePair optimize(final ObjectiveFunction f, final int maxEvaluations,
final ConvergenceChecker checker, final boolean minimizing,
final double[] vertexA, final double[] vertexB,
final int starts, final long seed)
throws ObjectiveException, ConvergenceException {
// set up the simplex traveling around the box
buildSimplex(vertexA, vertexB);
@ -162,111 +209,112 @@ public abstract class DirectSearchOptimizer {
// we consider the simplex could have been produced by a generator
// having its mean value at the center of the box, the standard
// deviation along each axe being the corresponding half size
double[] mean = new double[vertexA.length];
double[] standardDeviation = new double[vertexA.length];
final double[] mean = new double[vertexA.length];
final double[] standardDeviation = new double[vertexA.length];
for (int i = 0; i < vertexA.length; ++i) {
mean[i] = 0.5 * (vertexA[i] + vertexB[i]);
standardDeviation[i] = 0.5 * Math.abs(vertexA[i] - vertexB[i]);
}
RandomGenerator rg = new JDKRandomGenerator();
final RandomGenerator rg = new JDKRandomGenerator();
rg.setSeed(seed);
UniformRandomGenerator urg = new UniformRandomGenerator(rg);
RandomVectorGenerator rvg =
final UniformRandomGenerator urg = new UniformRandomGenerator(rg);
final RandomVectorGenerator rvg =
new UncorrelatedRandomVectorGenerator(mean, standardDeviation, urg);
setMultiStart(starts, rvg);
// compute minimum
return minimize(f, maxEvaluations, checker);
// compute optimum
return optimize(f, maxEvaluations, checker, minimizing);
}
/** Minimizes a cost function.
/** Optimizes an objective function.
* <p>The simplex is built from all its vertices.</p>
* <p>The optimization is performed in single-start mode.</p>
* @param f cost function
* @param f objective function
* @param maxEvaluations maximal number of function calls for each
* start (note that the number will be checked <em>after</em>
* complete simplices have been evaluated, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem)
* @param checker object to use to check for convergence
* @param minimizing if true, function must be minimize otherwise it must be maximized
* @param vertices array containing all vertices of the simplex
* @return the point/cost pairs giving the minimal cost
* @exception CostException if the cost function throws one during
* @return the point/value pairs giving the optimal value for objective function
* @exception ObjectiveException if the objective function throws one during
* the search
* @exception ConvergenceException if none of the starts did
* converge (it is not thrown if at least one start did converge)
*/
public PointCostPair minimize(CostFunction f, int maxEvaluations,
ConvergenceChecker checker,
double[][] vertices)
throws CostException, ConvergenceException {
public PointValuePair optimize(final ObjectiveFunction f, final int maxEvaluations,
final ConvergenceChecker checker, final boolean minimizing,
final double[][] vertices)
throws ObjectiveException, ConvergenceException {
// set up optimizer
buildSimplex(vertices);
setSingleStart();
// compute minimum
return minimize(f, maxEvaluations, checker);
// compute optimum
return optimize(f, maxEvaluations, checker, minimizing);
}
/** Minimizes a cost function.
/** Optimizes an objective function.
* <p>The simplex is built from all its vertices.</p>
* <p>The optimization is performed in multi-start mode.</p>
* @param f cost function
* @param f objective function
* @param maxEvaluations maximal number of function calls for each
* start (note that the number will be checked <em>after</em>
* complete simplices have been evaluated, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem)
* @param checker object to use to check for convergence
* @param minimizing if true, function must be minimize otherwise it must be maximized
* @param vertices array containing all vertices of the simplex
* @param starts number of starts to perform (including the
* first one), multi-start is disabled if value is less than or
* equal to 1
* @param seed seed for the random vector generator
* @return the point/cost pairs giving the minimal cost
* @return the point/value pairs giving the optimal value for objective function
* @exception NotPositiveDefiniteMatrixException if the vertices
* array is degenerated
* @exception CostException if the cost function throws one during
* @exception ObjectiveException if the objective function throws one during
* the search
* @exception ConvergenceException if none of the starts did
* converge (it is not thrown if at least one start did converge)
*/
public PointCostPair minimize(CostFunction f, int maxEvaluations,
ConvergenceChecker checker,
double[][] vertices,
int starts, long seed)
throws NotPositiveDefiniteMatrixException,
CostException, ConvergenceException {
public PointValuePair optimize(final ObjectiveFunction f, final int maxEvaluations,
final ConvergenceChecker checker, final boolean minimizing,
final double[][] vertices,
final int starts, final long seed)
throws NotPositiveDefiniteMatrixException, ObjectiveException, ConvergenceException {
try {
// store the points into the simplex
buildSimplex(vertices);
// compute the statistical properties of the simplex points
VectorialMean meanStat = new VectorialMean(vertices[0].length);
VectorialCovariance covStat = new VectorialCovariance(vertices[0].length, true);
final VectorialMean meanStat = new VectorialMean(vertices[0].length);
final VectorialCovariance covStat = new VectorialCovariance(vertices[0].length, true);
for (int i = 0; i < vertices.length; ++i) {
meanStat.increment(vertices[i]);
covStat.increment(vertices[i]);
}
double[] mean = meanStat.getResult();
RealMatrix covariance = covStat.getResult();
final double[] mean = meanStat.getResult();
final RealMatrix covariance = covStat.getResult();
RandomGenerator rg = new JDKRandomGenerator();
final RandomGenerator rg = new JDKRandomGenerator();
rg.setSeed(seed);
RandomVectorGenerator rvg =
final RandomVectorGenerator rvg =
new CorrelatedRandomVectorGenerator(mean,
covariance, 1.0e-12 * covariance.getNorm(),
new UniformRandomGenerator(rg));
setMultiStart(starts, rvg);
// compute minimum
return minimize(f, maxEvaluations, checker);
// compute optimum
return optimize(f, maxEvaluations, checker, minimizing);
} catch (DimensionMismatchException dme) {
// this should not happen
@ -275,69 +323,71 @@ public abstract class DirectSearchOptimizer {
}
/** Minimizes a cost function.
/** Optimizes an objective function.
* <p>The simplex is built randomly.</p>
* <p>The optimization is performed in single-start mode.</p>
* @param f cost function
* @param f objective function
* @param maxEvaluations maximal number of function calls for each
* start (note that the number will be checked <em>after</em>
* complete simplices have been evaluated, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem)
* @param checker object to use to check for convergence
* @param minimizing if true, function must be minimize otherwise it must be maximized
* @param generator random vector generator
* @return the point/cost pairs giving the minimal cost
* @exception CostException if the cost function throws one during
* @return the point/value pairs giving the optimal value for objective function
* @exception ObjectiveException if the objective function throws one during
* the search
* @exception ConvergenceException if none of the starts did
* converge (it is not thrown if at least one start did converge)
*/
public PointCostPair minimize(CostFunction f, int maxEvaluations,
ConvergenceChecker checker,
RandomVectorGenerator generator)
throws CostException, ConvergenceException {
public PointValuePair optimize(final ObjectiveFunction f, final int maxEvaluations,
final ConvergenceChecker checker, final boolean minimizing,
final RandomVectorGenerator generator)
throws ObjectiveException, ConvergenceException {
// set up optimizer
buildSimplex(generator);
setSingleStart();
// compute minimum
return minimize(f, maxEvaluations, checker);
// compute optimum
return optimize(f, maxEvaluations, checker, minimizing);
}
/** Minimizes a cost function.
/** Optimizes an objective function.
* <p>The simplex is built randomly.</p>
* <p>The optimization is performed in multi-start mode.</p>
* @param f cost function
* @param f objective function
* @param maxEvaluations maximal number of function calls for each
* start (note that the number will be checked <em>after</em>
* complete simplices have been evaluated, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem)
* @param checker object to use to check for convergence
* @param minimizing if true, function must be minimize otherwise it must be maximized
* @param generator random vector generator
* @param starts number of starts to perform (including the
* first one), multi-start is disabled if value is less than or
* equal to 1
* @return the point/cost pairs giving the minimal cost
* @exception CostException if the cost function throws one during
* @return the point/value pairs giving the optimal value for objective function
* @exception ObjectiveException if the objective function throws one during
* the search
* @exception ConvergenceException if none of the starts did
* converge (it is not thrown if at least one start did converge)
*/
public PointCostPair minimize(CostFunction f, int maxEvaluations,
ConvergenceChecker checker,
RandomVectorGenerator generator,
int starts)
throws CostException, ConvergenceException {
public PointValuePair optimize(final ObjectiveFunction f, final int maxEvaluations,
final ConvergenceChecker checker, final boolean minimizing,
final RandomVectorGenerator generator,
final int starts)
throws ObjectiveException, ConvergenceException {
// set up optimizer
buildSimplex(generator);
setMultiStart(starts, generator);
// compute minimum
return minimize(f, maxEvaluations, checker);
// compute optimum
return optimize(f, maxEvaluations, checker, minimizing);
}
@ -352,21 +402,21 @@ public abstract class DirectSearchOptimizer {
* @param vertexA first vertex
* @param vertexB last vertex
*/
private void buildSimplex(double[] vertexA, double[] vertexB) {
private void buildSimplex(final double[] vertexA, final double[] vertexB) {
int n = vertexA.length;
simplex = new PointCostPair[n + 1];
final int n = vertexA.length;
simplex = new PointValuePair[n + 1];
// set up the simplex traveling around the box
for (int i = 0; i <= n; ++i) {
double[] vertex = new double[n];
final double[] vertex = new double[n];
if (i > 0) {
System.arraycopy(vertexB, 0, vertex, 0, i);
}
if (i < n) {
System.arraycopy(vertexA, i, vertex, i, n - i);
}
simplex[i] = new PointCostPair(vertex, Double.NaN);
simplex[i] = new PointValuePair(vertex, Double.NaN);
}
}
@ -374,28 +424,28 @@ public abstract class DirectSearchOptimizer {
/** Build a simplex from all its points.
* @param vertices array containing all vertices of the simplex
*/
private void buildSimplex(double[][] vertices) {
int n = vertices.length - 1;
simplex = new PointCostPair[n + 1];
private void buildSimplex(final double[][] vertices) {
final int n = vertices.length - 1;
simplex = new PointValuePair[n + 1];
for (int i = 0; i <= n; ++i) {
simplex[i] = new PointCostPair(vertices[i], Double.NaN);
simplex[i] = new PointValuePair(vertices[i], Double.NaN);
}
}
/** Build a simplex randomly.
* @param generator random vector generator
*/
private void buildSimplex(RandomVectorGenerator generator) {
private void buildSimplex(final RandomVectorGenerator generator) {
// use first vector size to compute the number of points
double[] vertex = generator.nextVector();
int n = vertex.length;
simplex = new PointCostPair[n + 1];
simplex[0] = new PointCostPair(vertex, Double.NaN);
final double[] vertex = generator.nextVector();
final int n = vertex.length;
simplex = new PointValuePair[n + 1];
simplex[0] = new PointValuePair(vertex, Double.NaN);
// fill up the vertex
for (int i = 1; i <= n; ++i) {
simplex[i] = new PointCostPair(generator.nextVector(), Double.NaN);
simplex[i] = new PointValuePair(generator.nextVector(), Double.NaN);
}
}
@ -405,7 +455,7 @@ public abstract class DirectSearchOptimizer {
private void setSingleStart() {
starts = 1;
generator = null;
minima = null;
optima = null;
}
/** Set up multi-start mode.
@ -414,65 +464,69 @@ public abstract class DirectSearchOptimizer {
* equal to 1
* @param generator random vector generator to use for restarts
*/
private void setMultiStart(int starts, RandomVectorGenerator generator) {
private void setMultiStart(final int starts, final RandomVectorGenerator generator) {
if (starts < 2) {
this.starts = 1;
this.generator = null;
minima = null;
optima = null;
} else {
this.starts = starts;
this.generator = generator;
minima = null;
optima = null;
}
}
/** Get all the minima found during the last call to {@link
* #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
/** Get all the optima found during the last call to {@link
* #optimize(ObjectiveFunction, int, ConvergenceChecker, double[], double[])
* minimize}.
* <p>The optimizer stores all the minima found during a set of
* <p>The optimizer stores all the optima found during a set of
* restarts when multi-start mode is enabled. The {@link
* #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
* minimize} method returns the best point only. This method
* #optimize(ObjectiveFunction, int, ConvergenceChecker, double[], double[])
* optimize} method returns the best point only. This method
* returns all the points found at the end of each starts, including
* the best one already returned by the {@link #minimize(CostFunction,
* int, ConvergenceChecker, double[], double[]) minimize} method.
* the best one already returned by the {@link #optimize(ObjectiveFunction,
* int, ConvergenceChecker, double[], double[]) optimize} method.
* The array as one element for each start as specified in the constructor
* (it has one element only if optimizer has been set up for single-start).</p>
* <p>The array containing the minima is ordered with the results
* <p>The array containing the optimum is ordered with the results
* from the runs that did converge first, sorted from lowest to
* highest minimum cost, and null elements corresponding to the runs
* that did not converge (all elements will be null if the {@link
* #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
* minimize} method did throw a {@link ConvergenceException
* ConvergenceException}).</p>
* @return array containing the minima, or null if {@link
* #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
* minimize} has not been called
* highest objective value if minimizing (from highest to lowest if maximizing),
* and null elements corresponding to the runs that did not converge. This means
* all elements will be null if the {@link #optimize(ObjectiveFunction, int,
* ConvergenceChecker, double[], double[]) optimize} method did throw a {@link
* ConvergenceException ConvergenceException}). This also means that if the first
* element is non null, it is the best point found accross all starts.</p>
* @return array containing the optima, or null if {@link
* #optimize(ObjectiveFunction, int, ConvergenceChecker, double[], double[])
* optimize} has not been called
*/
public PointCostPair[] getMinima() {
return (PointCostPair[]) minima.clone();
public PointValuePair[] getOptima() {
return (PointValuePair[]) optima.clone();
}
/** Minimizes a cost function.
* @param f cost function
/** Optimizes an objective function.
* @param f objective function
* @param maxEvaluations maximal number of function calls for each
* start (note that the number will be checked <em>after</em>
* complete simplices have been evaluated, this means that in some
* cases this number will be exceeded by a few units, depending on
* the dimension of the problem)
* @param checker object to use to check for convergence
* @return the point/cost pairs giving the minimal cost
* @exception CostException if the cost function throws one during
* @param minimizing if true, function must be minimize otherwise it must be maximized
* @return the point/value pairs giving the optimal value for objective function
* @exception ObjectiveException if the objective function throws one during
* the search
* @exception ConvergenceException if none of the starts did
* converge (it is not thrown if at least one start did converge)
*/
private PointCostPair minimize(CostFunction f, int maxEvaluations,
ConvergenceChecker checker)
throws CostException, ConvergenceException {
private PointValuePair optimize(final ObjectiveFunction f, final int maxEvaluations,
final ConvergenceChecker checker, final boolean minimizing)
throws ObjectiveException, ConvergenceException {
this.f = f;
minima = new PointCostPair[starts];
this.f = f;
this.minimizing = minimizing;
optima = new PointValuePair[starts];
totalEvaluations = 0;
// multi-start loop
for (int i = 0; i < starts; ++i) {
@ -482,18 +536,20 @@ public abstract class DirectSearchOptimizer {
for (boolean loop = true; loop;) {
if (checker.converged(simplex)) {
// we have found a minimum
minima[i] = simplex[0];
// we have found an optimum
optima[i] = simplex[0];
loop = false;
} else if (evaluations >= maxEvaluations) {
// this start did not converge, try a new one
minima[i] = null;
optima[i] = null;
loop = false;
} else {
iterateSimplex();
}
}
totalEvaluations += evaluations;
if (i < (starts - 1)) {
// restart
buildSimplex(generator);
@ -501,103 +557,90 @@ public abstract class DirectSearchOptimizer {
}
// sort the minima from lowest cost to highest cost, followed by
// sort the optima from best to poorest, followed by
// null elements
Arrays.sort(minima, pointCostPairComparator);
Arrays.sort(optima, PAIR_COMPARATOR);
// return the found point given the lowest cost
if (minima[0] == null) {
if (!minimizing) {
// revert objective function sign to match user original definition
for (int i = 0; i < optima.length; ++i) {
final PointValuePair current = optima[i];
if (current != null) {
optima[i] = new PointValuePair(current.getPoint(), -current.getValue());
}
}
}
// return the found point given the best objective function value
if (optima[0] == null) {
throw new ConvergenceException(
"none of the {0} start points lead to convergence",
starts);
}
return minima[0];
return optima[0];
}
/** Get the total number of evaluations of the objective function.
* <p>
* The total number of evaluations includes all evaluations for all
* starts if in optimization was done in multi-start mode.
* </p>
* @return total number of evaluations of the objective function
*/
public int getTotalEvaluations() {
return totalEvaluations;
}
/** Compute the next simplex of the algorithm.
* @exception CostException if the function cannot be evaluated at
* @exception ObjectiveException if the function cannot be evaluated at
* some point
*/
protected abstract void iterateSimplex()
throws CostException;
protected abstract void iterateSimplex() throws ObjectiveException;
/** Evaluate the cost on one point.
/** Evaluate the objective function on one point.
* <p>A side effect of this method is to count the number of
* function evaluations</p>
* @param x point on which the cost function should be evaluated
* @return cost at the given point
* @exception CostException if no cost can be computed for the parameters
* @param x point on which the objective function should be evaluated
* @return objective function value at the given point
* @exception ObjectiveException if no value can be computed for the parameters
*/
protected double evaluateCost(double[] x)
throws CostException {
protected double evaluate(final double[] x) throws ObjectiveException {
evaluations++;
return f.cost(x);
return minimizing ? f.objective(x) : -f.objective(x);
}
/** Evaluate all the non-evaluated points of the simplex.
* @exception CostException if no cost can be computed for the parameters
* @exception ObjectiveException if no value can be computed for the parameters
*/
protected void evaluateSimplex()
throws CostException {
protected void evaluateSimplex() throws ObjectiveException {
// evaluate the cost at all non-evaluated simplex points
// evaluate the objective function at all non-evaluated simplex points
for (int i = 0; i < simplex.length; ++i) {
PointCostPair pair = simplex[i];
if (Double.isNaN(pair.getCost())) {
simplex[i] = new PointCostPair(pair.getPoint(), evaluateCost(pair.getPoint()));
PointValuePair pair = simplex[i];
if (Double.isNaN(pair.getValue())) {
simplex[i] = new PointValuePair(pair.getPoint(), evaluate(pair.getPoint()));
}
}
// sort the simplex from lowest cost to highest cost
Arrays.sort(simplex, pointCostPairComparator);
// sort the simplex from best to poorest
Arrays.sort(simplex, PAIR_COMPARATOR);
}
/** Replace the worst point of the simplex by a new point.
* @param pointCostPair point to insert
* @param pointValuePair point to insert
*/
protected void replaceWorstPoint(PointCostPair pointCostPair) {
protected void replaceWorstPoint(PointValuePair pointValuePair) {
int n = simplex.length - 1;
for (int i = 0; i < n; ++i) {
if (simplex[i].getCost() > pointCostPair.getCost()) {
PointCostPair tmp = simplex[i];
simplex[i] = pointCostPair;
pointCostPair = tmp;
if (simplex[i].getValue() > pointValuePair.getValue()) {
PointValuePair tmp = simplex[i];
simplex[i] = pointValuePair;
pointValuePair = tmp;
}
}
simplex[n] = pointCostPair;
simplex[n] = pointValuePair;
}
/** Comparator for {@link PointCostPair PointCostPair} objects. */
private static Comparator<PointCostPair> pointCostPairComparator =
new Comparator<PointCostPair>() {
public int compare(PointCostPair o1, PointCostPair o2) {
if (o1 == null) {
return (o2 == null) ? 0 : +1;
} else if (o2 == null) {
return -1;
}
return (o1.getCost() < o2.getCost()) ? -1 : ((o1 == o2) ? 0 : +1);
}
};
/** Simplex. */
protected PointCostPair[] simplex;
/** Cost function. */
private CostFunction f;
/** Number of evaluations already performed. */
private int evaluations;
/** Number of starts to go. */
private int starts;
/** Random generator for multi-start. */
private RandomVectorGenerator generator;
/** Found minima. */
private PointCostPair[] minima;
}

View File

@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.direct;
import org.apache.commons.math.optimization.ObjectiveException;
import org.apache.commons.math.optimization.PointValuePair;
/**
* This class implements the multi-directional direct search method.
*
* @version $Revision$ $Date$
* @see NelderMead
* @since 1.2
*/
public class MultiDirectional extends DirectSearchOptimizer {
/** Serializable version identifier. */
private static final long serialVersionUID = -5347711305645019145L;
/** Expansion coefficient. */
private final double khi;
/** Contraction coefficient. */
private final double gamma;
/** Build a multi-directional optimizer with default coefficients.
* <p>The default values are 2.0 for khi and 0.5 for gamma.</p>
*/
public MultiDirectional() {
this.khi = 2.0;
this.gamma = 0.5;
}
/** Build a multi-directional optimizer with specified coefficients.
* @param khi expansion coefficient
* @param gamma contraction coefficient
*/
public MultiDirectional(final double khi, final double gamma) {
this.khi = khi;
this.gamma = gamma;
}
/** Compute the next simplex of the algorithm.
* @exception ObjectiveException if the function cannot be evaluated at
* some point
*/
protected void iterateSimplex() throws ObjectiveException {
while (true) {
// save the original vertex
final PointValuePair[] original = simplex;
final double originalValue = original[0].getValue();
// perform a reflection step
final double reflectedValue = evaluateNewSimplex(original, 1.0);
if (reflectedValue < originalValue) {
// compute the expanded simplex
final PointValuePair[] reflected = simplex;
final double expandedValue = evaluateNewSimplex(original, khi);
if (reflectedValue <= expandedValue) {
// accept the reflected simplex
simplex = reflected;
}
return;
}
// compute the contracted simplex
final double contractedValue = evaluateNewSimplex(original, gamma);
if (contractedValue < originalValue) {
// accept the contracted simplex
return;
}
}
}
/** Compute and evaluate a new simplex.
* @param original original simplex (to be preserved)
* @param coeff linear coefficient
* @return smallest value in the transformed simplex
* @exception ObjectiveException if the function cannot be evaluated at
* some point
*/
private double evaluateNewSimplex(final PointValuePair[] original,
final double coeff)
throws ObjectiveException {
final double[] xSmallest = original[0].getPoint();
final int n = xSmallest.length;
// create the linearly transformed simplex
simplex = new PointValuePair[n + 1];
simplex[0] = original[0];
for (int i = 1; i <= n; ++i) {
final double[] xOriginal = original[i].getPoint();
final double[] xTransformed = new double[n];
for (int j = 0; j < n; ++j) {
xTransformed[j] = xSmallest[j] + coeff * (xSmallest[j] - xOriginal[j]);
}
simplex[i] = new PointValuePair(xTransformed, Double.NaN);
}
// evaluate it
evaluateSimplex();
return simplex[0].getValue();
}
}

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.direct;
import org.apache.commons.math.optimization.ObjectiveException;
import org.apache.commons.math.optimization.PointValuePair;
/**
* This class implements the Nelder-Mead direct search method.
*
* @version $Revision$ $Date$
* @see MultiDirectional
* @since 1.2
*/
public class NelderMead extends DirectSearchOptimizer {
/** Serializable version identifier. */
private static final long serialVersionUID = -5810365844886183056L;
/** Reflection coefficient. */
private final double rho;
/** Expansion coefficient. */
private final double khi;
/** Contraction coefficient. */
private final double gamma;
/** Shrinkage coefficient. */
private final double sigma;
/** Build a Nelder-Mead optimizer with default coefficients.
* <p>The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
* for both gamma and sigma.</p>
*/
public NelderMead() {
this.rho = 1.0;
this.khi = 2.0;
this.gamma = 0.5;
this.sigma = 0.5;
}
/** Build a Nelder-Mead optimizer with specified coefficients.
* @param rho reflection coefficient
* @param khi expansion coefficient
* @param gamma contraction coefficient
* @param sigma shrinkage coefficient
*/
public NelderMead(final double rho, final double khi,
final double gamma, final double sigma) {
this.rho = rho;
this.khi = khi;
this.gamma = gamma;
this.sigma = sigma;
}
/** {@inheritDoc} */
protected void iterateSimplex() throws ObjectiveException {
// the simplex has n+1 point if dimension is n
final int n = simplex.length - 1;
// interesting values
final double smallest = simplex[0].getValue();
final double secondLargest = simplex[n-1].getValue();
final double largest = simplex[n].getValue();
final double[] xLargest = simplex[n].getPoint();
// compute the centroid of the best vertices
// (dismissing the worst point at index n)
final double[] centroid = new double[n];
for (int i = 0; i < n; ++i) {
final double[] x = simplex[i].getPoint();
for (int j = 0; j < n; ++j) {
centroid[j] += x[j];
}
}
final double scaling = 1.0 / n;
for (int j = 0; j < n; ++j) {
centroid[j] *= scaling;
}
// compute the reflection point
final double[] xR = new double[n];
for (int j = 0; j < n; ++j) {
xR[j] = centroid[j] + rho * (centroid[j] - xLargest[j]);
}
final double valueR = evaluate(xR);
if ((smallest <= valueR) && (valueR < secondLargest)) {
// accept the reflected point
replaceWorstPoint(new PointValuePair(xR, valueR));
} else if (valueR < smallest) {
// compute the expansion point
final double[] xE = new double[n];
for (int j = 0; j < n; ++j) {
xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
}
final double valueE = evaluate(xE);
if (valueE < valueR) {
// accept the expansion point
replaceWorstPoint(new PointValuePair(xE, valueE));
} else {
// accept the reflected point
replaceWorstPoint(new PointValuePair(xR, valueR));
}
} else {
if (valueR < largest) {
// perform an outside contraction
final double[] xC = new double[n];
for (int j = 0; j < n; ++j) {
xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
}
final double valueC = evaluate(xC);
if (valueC <= valueR) {
// accept the contraction point
replaceWorstPoint(new PointValuePair(xC, valueC));
return;
}
} else {
// perform an inside contraction
final double[] xC = new double[n];
for (int j = 0; j < n; ++j) {
xC[j] = centroid[j] - gamma * (centroid[j] - xLargest[j]);
}
final double valueC = evaluate(xC);
if (valueC < largest) {
// accept the contraction point
replaceWorstPoint(new PointValuePair(xC, valueC));
return;
}
}
// perform a shrink
final double[] xSmallest = simplex[0].getPoint();
for (int i = 1; i < simplex.length; ++i) {
final double[] x = simplex[i].getPoint();
for (int j = 0; j < n; ++j) {
x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
}
simplex[i] = new PointValuePair(x, Double.NaN);
}
evaluateSimplex();
}
}
}

View File

@ -0,0 +1,24 @@
<html>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<!-- $Revision$ -->
<body>
<p>
This package provides optimization algorithms that don't require derivatives.
</p>
</body>
</html>

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.util.Arrays;
@ -23,6 +23,7 @@ import org.apache.commons.math.linear.InvalidMatrixException;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
import org.apache.commons.math.optimization.OptimizationException;
/**
* Base class for implementing estimators.
@ -98,14 +99,14 @@ public abstract class AbstractEstimator implements Estimator {
/**
* Update the residuals array and cost function value.
* @exception EstimationException if the number of cost evaluations
* @exception OptimizationException if the number of cost evaluations
* exceeds the maximum allowed
*/
protected void updateResidualsAndCost()
throws EstimationException {
throws OptimizationException {
if (++costEvaluations > maxCostEval) {
throw new EstimationException("maximal number of evaluations exceeded ({0})",
throw new OptimizationException("maximal number of evaluations exceeded ({0})",
maxCostEval);
}
@ -160,11 +161,11 @@ public abstract class AbstractEstimator implements Estimator {
* Get the covariance matrix of unbound estimated parameters.
* @param problem estimation problem
* @return covariance matrix
* @exception EstimationException if the covariance matrix
* @exception OptimizationException if the covariance matrix
* cannot be computed (singular problem)
*/
public double[][] getCovariances(EstimationProblem problem)
throws EstimationException {
throws OptimizationException {
// set up the jacobian
updateJacobian();
@ -191,7 +192,7 @@ public abstract class AbstractEstimator implements Estimator {
new LUDecompositionImpl(MatrixUtils.createRealMatrix(jTj)).getSolver().getInverse();
return inverse.getData();
} catch (InvalidMatrixException ime) {
throw new EstimationException("unable to compute covariances: singular problem");
throw new OptimizationException("unable to compute covariances: singular problem");
}
}
@ -201,16 +202,16 @@ public abstract class AbstractEstimator implements Estimator {
* <p>Guessing is covariance-based, it only gives rough order of magnitude.</p>
* @param problem estimation problem
* @return errors in estimated parameters
* @exception EstimationException if the covariances matrix cannot be computed
* @exception OptimizationException if the covariances matrix cannot be computed
* or the number of degrees of freedom is not positive (number of measurements
* lesser or equal to number of parameters)
*/
public double[] guessParametersErrors(EstimationProblem problem)
throws EstimationException {
throws OptimizationException {
int m = problem.getMeasurements().length;
int p = problem.getUnboundParameters().length;
if (m <= p) {
throw new EstimationException(
throw new OptimizationException(
"no degrees of freedom ({0} measurements, {1} parameters)",
m, p);
}
@ -261,11 +262,11 @@ public abstract class AbstractEstimator implements Estimator {
* EstimationProblem.getAllParameters} method.</p>
*
* @param problem estimation problem to solve
* @exception EstimationException if the problem cannot be solved
* @exception OptimizationException if the problem cannot be solved
*
*/
public abstract void estimate(EstimationProblem problem)
throws EstimationException;
throws OptimizationException;
/** Array of measurements. */
protected WeightedMeasurement[] measurements;

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.io.Serializable;

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
/**
* This interface represents an estimation problem.
@ -48,18 +48,18 @@ public interface EstimationProblem {
* Get the measurements of an estimation problem.
* @return measurements
*/
public WeightedMeasurement[] getMeasurements();
WeightedMeasurement[] getMeasurements();
/**
* Get the unbound parameters of the problem.
* @return unbound parameters
*/
public EstimatedParameter[] getUnboundParameters();
EstimatedParameter[] getUnboundParameters();
/**
* Get all the parameters of the problem.
* @return parameters
*/
public EstimatedParameter[] getAllParameters();
EstimatedParameter[] getAllParameters();
}

View File

@ -15,7 +15,9 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import org.apache.commons.math.optimization.OptimizationException;
/**
* This interface represents solvers for estimation problems.
@ -48,11 +50,11 @@ public interface Estimator {
* EstimationProblem.getAllParameters} method.</p>
*
* @param problem estimation problem to solve
* @exception EstimationException if the problem cannot be solved
* @exception OptimizationException if the problem cannot be solved
*
*/
public void estimate(EstimationProblem problem)
throws EstimationException;
void estimate(EstimationProblem problem)
throws OptimizationException;
/**
* Get the Root Mean Square value.
@ -66,26 +68,26 @@ public interface Estimator {
* @param problem estimation problem
* @return RMS value
*/
public double getRMS(EstimationProblem problem);
double getRMS(EstimationProblem problem);
/**
* Get the covariance matrix of estimated parameters.
* @param problem estimation problem
* @return covariance matrix
* @exception EstimationException if the covariance matrix
* @exception OptimizationException if the covariance matrix
* cannot be computed (singular problem)
*/
public double[][] getCovariances(EstimationProblem problem)
throws EstimationException;
double[][] getCovariances(EstimationProblem problem)
throws OptimizationException;
/**
* Guess the errors in estimated parameters.
* @see #getRMS(EstimationProblem)
* @param problem estimation problem
* @return errors in estimated parameters
* @exception EstimationException if the error cannot be guessed
* @exception OptimizationException if the error cannot be guessed
*/
public double[] guessParametersErrors(EstimationProblem problem)
throws EstimationException;
double[] guessParametersErrors(EstimationProblem problem)
throws OptimizationException;
}

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.io.Serializable;
@ -25,6 +25,7 @@ import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.RealVectorImpl;
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
import org.apache.commons.math.optimization.OptimizationException;
/**
* This class implements a solver for estimation problems.
@ -87,7 +88,7 @@ public class GaussNewtonEstimator extends AbstractEstimator implements Serializa
*
* <p>If neither conditions are fulfilled before a given number of
* iterations, the algorithm is considered to have failed and an
* {@link EstimationException} is thrown.</p>
* {@link OptimizationException} is thrown.</p>
*
* @param maxCostEval maximal number of cost evaluations allowed
* @param convergence criterion threshold below which we do not need
@ -144,17 +145,17 @@ public class GaussNewtonEstimator extends AbstractEstimator implements Serializa
* useless or when the algorithm is unable to improve it (even if it
* is still high). The first condition that is met stops the
* iterations. If the convergence it not reached before the maximum
* number of iterations, an {@link EstimationException} is
* number of iterations, an {@link OptimizationException} is
* thrown.</p>
*
* @param problem estimation problem to solve
* @exception EstimationException if the problem cannot be solved
* @exception OptimizationException if the problem cannot be solved
*
* @see EstimationProblem
*
*/
public void estimate(EstimationProblem problem)
throws EstimationException {
throws OptimizationException {
initializeEstimate(problem);
@ -210,7 +211,7 @@ public class GaussNewtonEstimator extends AbstractEstimator implements Serializa
}
} catch(InvalidMatrixException e) {
throw new EstimationException("unable to solve: singular problem");
throw new OptimizationException("unable to solve: singular problem");
}

View File

@ -14,11 +14,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.io.Serializable;
import java.util.Arrays;
import org.apache.commons.math.optimization.OptimizationException;
/**
* This class solves a least squares problem.
@ -188,7 +190,7 @@ public class LevenbergMarquardtEstimator extends AbstractEstimator implements Se
* <p>Luc Maisonobe did the Java translation.</p>
*
* @param problem estimation problem to solve
* @exception EstimationException if convergence cannot be
* @exception OptimizationException if convergence cannot be
* reached with the specified algorithm settings or if there are more variables
* than equations
* @see #setInitialStepBoundFactor
@ -197,7 +199,7 @@ public class LevenbergMarquardtEstimator extends AbstractEstimator implements Se
* @see #setOrthoTolerance
*/
public void estimate(EstimationProblem problem)
throws EstimationException {
throws OptimizationException {
initializeEstimate(problem);
@ -397,17 +399,17 @@ public class LevenbergMarquardtEstimator extends AbstractEstimator implements Se
// tests for termination and stringent tolerances
// (2.2204e-16 is the machine epsilon for IEEE754)
if ((Math.abs(actRed) <= 2.2204e-16) && (preRed <= 2.2204e-16) && (ratio <= 2.0)) {
throw new EstimationException("cost relative tolerance is too small ({0})," +
throw new OptimizationException("cost relative tolerance is too small ({0})," +
" no further reduction in the" +
" sum of squares is possible",
costRelativeTolerance);
} else if (delta <= 2.2204e-16 * xNorm) {
throw new EstimationException("parameters relative tolerance is too small" +
throw new OptimizationException("parameters relative tolerance is too small" +
" ({0}), no further improvement in" +
" the approximate solution is possible",
parRelativeTolerance);
} else if (maxCosine <= 2.2204e-16) {
throw new EstimationException("orthogonality tolerance is too small ({0})," +
throw new OptimizationException("orthogonality tolerance is too small ({0})," +
" solution is orthogonal to the jacobian",
orthoTolerance);
}
@ -732,9 +734,9 @@ public class LevenbergMarquardtEstimator extends AbstractEstimator implements Se
* are performed in non-increasing columns norms order thanks to columns
* pivoting. The diagonal elements of the R matrix are therefore also in
* non-increasing absolute values order.</p>
* @exception EstimationException if the decomposition cannot be performed
* @exception OptimizationException if the decomposition cannot be performed
*/
private void qrDecomposition() throws EstimationException {
private void qrDecomposition() throws OptimizationException {
// initializations
for (int k = 0; k < cols; ++k) {
@ -761,7 +763,7 @@ public class LevenbergMarquardtEstimator extends AbstractEstimator implements Se
norm2 += aki * aki;
}
if (Double.isInfinite(norm2) || Double.isNaN(norm2)) {
throw new EstimationException(
throw new OptimizationException(
"unable to perform Q.R decomposition on the {0}x{1} jacobian matrix",
rows, cols);
}

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.util.ArrayList;
import java.util.List;

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.io.Serializable;

View File

@ -0,0 +1,22 @@
<html>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<!-- $Revision$ -->
<body>
This package provides optimization algorithms that require derivatives.
</body>
</html>

View File

@ -18,7 +18,36 @@
<!-- $Revision$ -->
<body>
<p>
This package provides parametric optimization algorithms.
This package provides common interfaces for the optimization algorithms
provided in sub-packages. The main interfaces are {@link
org.apache.commons.math.optimization.ObjectiveFunction ObjectiveFunction}
and {@link org.apache.commons.math.optimization.MultiObjectiveFunction
MultiObjectiveFunction}. Both interfaces are intended to be implemented by
user code to represent the problem they want to optimize.
</p>
<p>
The {@link org.apache.commons.math.optimization.ObjectiveFunction ObjectiveFunction}
interface represent a scalar function that should be either minimized or maximized,
by changing its input variables set until an optimal set is found. This function is
often called a cost function when the goal is to minimize it.
</p>
<p>
The {@link org.apache.commons.math.optimization.MultiObjectiveFunction
MultiObjectiveFunction} interface represent a vectorial function that should be either
minimized or maximized, by changing its input variables set until an optimal set is
found.
</p>
<p>
The {@link org.apache.commons.math.optimization.LeastSquaresConverter
LeastSquaresConverter} class can be used to convert vectorial objective functions to
scalar objective functions when the goal is to minimize them. This class is mostly used
when the vectorial objective function represents residuals, i.e. differences between a
theoretical result computed from a variables set applied to a model and a reference.
Residuals are intended to be minimized in order to get the variables set that best fit
the model to the reference. The reference may be obtained for example from physical
measurements whether the model is built from theoretical considerations.
</p>
</body>
</html>

View File

@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.analysis.minimization;
package org.apache.commons.math.optimization.univariate;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.MaxIterationsExceededException;

View File

@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.analysis.minimization;
package org.apache.commons.math.optimization.univariate;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.ConvergingAlgorithm;
@ -47,8 +47,7 @@ public interface UnivariateRealMinimizer extends ConvergingAlgorithm {
* satisfy the requirements specified by the minimizer
*/
double minimize(UnivariateRealFunction f, double min, double max)
throws ConvergenceException,
FunctionEvaluationException;
throws ConvergenceException, FunctionEvaluationException;
/**
* Find a minimum in the given interval, start at startValue.

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.commons.math.analysis.minimization;
package org.apache.commons.math.optimization.univariate;
import org.apache.commons.math.ConvergingAlgorithmImpl;
import org.apache.commons.math.MathRuntimeException;

View File

@ -15,12 +15,12 @@
* limitations under the License.
*/
import org.apache.commons.math.estimation.EstimationException;
import org.apache.commons.math.estimation.EstimatedParameter;
import org.apache.commons.math.estimation.EstimationProblem;
import org.apache.commons.math.estimation.LevenbergMarquardtEstimator;
import org.apache.commons.math.estimation.SimpleEstimationProblem;
import org.apache.commons.math.estimation.WeightedMeasurement;
import org.apache.commons.math.optimization.general.EstimationException;
import org.apache.commons.math.optimization.general.EstimatedParameter;
import org.apache.commons.math.optimization.general.EstimationProblem;
import org.apache.commons.math.optimization.general.LevenbergMarquardtEstimator;
import org.apache.commons.math.optimization.general.SimpleEstimationProblem;
import org.apache.commons.math.optimization.general.WeightedMeasurement;
public class TrajectoryDeterminationProblem extends SimpleEstimationProblem {

View File

@ -237,19 +237,7 @@ double c = solver.solve(function, 1.0, 5.0);</source>
</table>
</p>
</subsection>
<subsection name="4.3 Minimization" href="minimization">
<p>
A <a href="../apidocs/org/apache/commons/math/analysis/minimization/UnivariateRealMinimizer.html">
org.apache.commons.math.analysis.minimization.UnivariateRealMinimizer</a>
is used to find the minimal values of a univariate real-valued function <code>f</code>.
</p>
<p>
Minimization algorithms usage is very similar to root-finding algorithms usage explained
above. The main difference is that the <code>solve</code> methods in root finding algorithms
is replaced by <code>minimize</code> methods.
</p>
</subsection>
<subsection name="4.4 Interpolation" href="interpolation">
<subsection name="4.3 Interpolation" href="interpolation">
<p>
A <a href="../apidocs/org/apache/commons/math/analysis/interpolation/UnivariateRealInterpolator.html">
org.apache.commons.math.analysis.interpolation.UnivariateRealInterpolator</a>
@ -300,7 +288,7 @@ System.out println("f(" + interpolationX + ") = " + interpolatedY);</source>
adding more points does not always lead to a better interpolation.
</p>
</subsection>
<subsection name="4.5 Integration" href="integration">
<subsection name="4.4 Integration" href="integration">
<p>
A <a href="../apidocs/org/apache/commons/math/analysis/integration/UnivariateRealIntegrator.html">
org.apache.commons.math.analysis.integration.UnivariateRealIntegrator.</a>
@ -317,7 +305,7 @@ System.out println("f(" + interpolationX + ") = " + interpolatedY);</source>
</ul>
</p>
</subsection>
<subsection name="4.6 Polynomials" href="polynomials">
<subsection name="4.5 Polynomials" href="polynomials">
<p>
The <a href="../apidocs/org/apache/commons/math/analysis/polynomials/package-summary.html">
org.apache.commons.math.analysis.polynomials</a> package provides real coefficients

View File

@ -67,10 +67,9 @@
<ul>
<li><a href="analysis.html#a4.1_Overview">4.1 Overview</a></li>
<li><a href="analysis.html#a4.2_Root-finding">4.2 Root-finding</a></li>
<li><a href="analysis.html#a4.3_Minimization">4.3 Minimization</a></li>
<li><a href="analysis.html#a4.4_Interpolation">4.4 Interpolation</a></li>
<li><a href="analysis.html#a4.5_Integration">4.5 Integration</a></li>
<li><a href="analysis.html#a4.6_Polynomials">4.6 Polynomials</a></li>
<li><a href="analysis.html#a4.3_Interpolation">4.3 Interpolation</a></li>
<li><a href="analysis.html#a4.4_Integration">4.4 Integration</a></li>
<li><a href="analysis.html#a4.5_Polynomials">4.5 Polynomials</a></li>
</ul></li>
<li><a href="special.html">5. Special Functions</a>
<ul>
@ -124,7 +123,10 @@
<li><a href="optimization.html">13. Optimization</a>
<ul>
<li><a href="optimization.html#a13.1_Overview">13.1 Overview</a></li>
<li><a href="optimization.html#a13.2_Direct_Methods">13.2 Direct Methods</a></li>
<li><a href="analysis.html#a13.2_Univariate_Functions">13.2 Univariate Functions</a></li>
<li><a href="analysis.html#a13.3_Linear_Programming">13.3 Linear Programming</a></li>
<li><a href="optimization.html#a13.4_Direct_Methods">13.4 Direct Methods</a></li>
<li><a href="analysis.html#a13.5_General_Case">13.5 General Case</a></li>
</ul></li>
<li><a href="ode.html">14. Ordinary Differential Equations Integration</a>
<ul>

View File

@ -29,24 +29,37 @@
<section name="13 Optimization">
<subsection name="13.1 Overview" href="overview">
<p>
The optimization package provides simplex-based direct search optimization algorithms.
</p>
<p>
The aim of this package is similar to the aim of the estimation package, but the
algorithms are entirely differents as:
The optimization package provides algorithms to optimize (i.e. minimize) some
objective or cost function. The package is split in several sub-packages
dedicated to different kind of functions or algorithms.
<ul>
<li>
they do not need the partial derivatives of the measurements
with respect to the free parameters
</li>
<li>
they do not rely on residuals-based quadratic cost functions but
handle any cost functions, including non-continuous ones!
</li>
<li>the univariate package handles univariate real functions,</li>
<li>the linear package handles multivariate vector linear functions
with linear constraints,</li>
<li>the direct package handles multivariate real functions using direct
search methods (i.e. not using derivatives),</li>
<li>the general package handles multivariate real or vector functions
using derivatives.</li>
</ul>
</p>
</subsection>
<subsection name="13.2 Direct Methods" href="direct">
<subsection name="13.2 Univariate Functions" href="univariate">
<p>
A <a href="../apidocs/org/apache/commons/math/analysis/minimization/UnivariateRealMinimizer.html">
org.apache.commons.math.optimization.univariate.UnivariateRealMinimizer</a>
is used to find the minimal values of a univariate real-valued function <code>f</code>.
</p>
<p>
Minimization algorithms usage is very similar to root-finding algorithms usage explained
in the analysis package. The main difference is that the <code>solve</code> methods in root
finding algorithms is replaced by <code>minimize</code> methods.
</p>
</subsection>
<subsection name="13.3 Linear Programming" href="linear">
<p>
</p>
</subsection>
<subsection name="13.4 Direct Methods" href="direct">
<p>
Direct search methods only use cost function values, they don't
need derivatives and don't either try to compute approximation of
@ -77,13 +90,17 @@
already provided by the <code>minimizes</code> method).
</p>
<p>
The package provides two solvers. The first one is the classical
<a href="../apidocs/org/apache/commons/math/optimization/NelderMead.html">
The <code>direct</code> package provides two solvers. The first one is the classical
<a href="../apidocs/org/apache/commons/math/optimization/direct/NelderMead.html">
Nelder-Mead</a> method. The second one is Virginia Torczon's
<a href="../apidocs/org/apache/commons/math/optimization/MultiDirectional.html">
<a href="../apidocs/org/apache/commons/math/optimization/direct/MultiDirectional.html">
multi-directional</a> method.
</p>
</subsection>
<subsection name="13.5 General Case" href="general">
<p>
</p>
</subsection>
</section>
</body>
</document>

View File

@ -85,7 +85,7 @@
<li><a href="transform.html">org.apache.commons.math.transform</a> - transform methods (Fast Fourier)</li>
<li><a href="geometry.html">org.apache.commons.math.geometry</a> - 3D geometry (vectors and rotations)</li>
<li><a href="estimation.html">org.apache.commons.math.estimation</a> - parametric estimation problems</li>
<li><a href="optimization.html">org.apache.commons.math.optimization</a> - multi-dimensional functions minimization</li>
<li><a href="optimization.html">org.apache.commons.math.optimization</a> - functions minimization</li>
<li><a href="ode.html">org.apache.commons.math.ode</a> - Ordinary Differential Equations integration</li>
</ol>
Package javadocs are <a href="../apidocs/index.html">here</a>

View File

@ -1,143 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.CostException;
import org.apache.commons.math.optimization.CostFunction;
import org.apache.commons.math.optimization.MultiDirectional;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.optimization.PointCostPair;
import junit.framework.*;
public class MultiDirectionalTest
extends TestCase {
public MultiDirectionalTest(String name) {
super(name);
}
public void testCostExceptions() throws ConvergenceException {
CostFunction wrong =
new CostFunction() {
public double cost(double[] x) throws CostException {
if (x[0] < 0) {
throw new CostException("{0}", "oops");
} else if (x[0] > 1) {
throw new CostException(new RuntimeException("oops"));
} else {
return x[0] * (1 - x[0]);
}
}
};
try {
new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
new double[] { -0.5 }, new double[] { 0.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
try {
new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
new double[] { 0.5 }, new double[] { 1.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNotNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
}
public void testRosenbrock()
throws CostException, ConvergenceException {
CostFunction rosenbrock =
new CostFunction() {
public double cost(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
};
count = 0;
PointCostPair optimum =
new MultiDirectional().minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
new double[][] {
{ -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 }
});
assertTrue(count > 60);
assertTrue(optimum.getCost() > 0.01);
}
public void testPowell()
throws CostException, ConvergenceException {
CostFunction powell =
new CostFunction() {
public double cost(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
};
count = 0;
PointCostPair optimum =
new MultiDirectional().minimize(powell, 1000, new ValueChecker(1.0e-3),
new double[] { 3.0, -1.0, 0.0, 1.0 },
new double[] { 4.0, 0.0, 1.0, 2.0 });
assertTrue(count > 850);
assertTrue(optimum.getCost() > 0.015);
}
private static class ValueChecker implements ConvergenceChecker {
public ValueChecker(double threshold) {
this.threshold = threshold;
}
public boolean converged(PointCostPair[] simplex) {
PointCostPair smallest = simplex[0];
PointCostPair largest = simplex[simplex.length - 1];
return (largest.getCost() - smallest.getCost()) < threshold;
}
private double threshold;
};
public static Test suite() {
return new TestSuite(MultiDirectionalTest.class);
}
private int count;
}

View File

@ -1,202 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
import org.apache.commons.math.linear.decomposition.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.CostException;
import org.apache.commons.math.optimization.CostFunction;
import org.apache.commons.math.optimization.NelderMead;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.optimization.PointCostPair;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.random.RandomGenerator;
import org.apache.commons.math.random.RandomVectorGenerator;
import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
import org.apache.commons.math.random.UniformRandomGenerator;
import junit.framework.*;
public class NelderMeadTest
extends TestCase {
public NelderMeadTest(String name) {
super(name);
}
public void testCostExceptions() throws ConvergenceException {
CostFunction wrong =
new CostFunction() {
public double cost(double[] x) throws CostException {
if (x[0] < 0) {
throw new CostException("{0}", "oops");
} else if (x[0] > 1) {
throw new CostException(new RuntimeException("oops"));
} else {
return x[0] * (1 - x[0]);
}
}
};
try {
new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
new double[] { -0.5 }, new double[] { 0.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
try {
new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
new double[] { 0.5 }, new double[] { 1.5 });
fail("an exception should have been thrown");
} catch (CostException ce) {
// expected behavior
assertNotNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
}
public void testRosenbrock()
throws CostException, ConvergenceException, NotPositiveDefiniteMatrixException {
CostFunction rosenbrock =
new CostFunction() {
public double cost(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
};
count = 0;
NelderMead nm = new NelderMead();
try {
nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
new double[][] {
{ -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
}, 1, 5384353l);
fail("an exception should have been thrown");
} catch (ConvergenceException ce) {
// expected behavior
} catch (Exception e) {
e.printStackTrace(System.err);
fail("wrong exception caught: " + e.getMessage());
}
count = 0;
PointCostPair optimum =
nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
new double[][] {
{ -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
}, 10, 1642738l);
assertTrue(count > 700);
assertTrue(count < 800);
assertEquals(0.0, optimum.getCost(), 5.0e-5);
assertEquals(1.0, optimum.getPoint()[0], 0.01);
assertEquals(1.0, optimum.getPoint()[1], 0.01);
PointCostPair[] minima = nm.getMinima();
assertEquals(10, minima.length);
assertNotNull(minima[0]);
assertNull(minima[minima.length - 1]);
for (int i = 0; i < minima.length; ++i) {
if (minima[i] == null) {
if ((i + 1) < minima.length) {
assertTrue(minima[i+1] == null);
}
} else {
if (i > 0) {
assertTrue(minima[i-1].getCost() <= minima[i].getCost());
}
}
}
RandomGenerator rg = new JDKRandomGenerator();
rg.setSeed(64453353l);
RandomVectorGenerator rvg =
new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
new double[] { 0.2, 0.2 },
new UniformRandomGenerator(rg));
optimum =
nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg);
assertEquals(0.0, optimum.getCost(), 2.0e-4);
optimum =
nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg, 3);
assertEquals(0.0, optimum.getCost(), 3.0e-5);
}
public void testPowell()
throws CostException, ConvergenceException {
CostFunction powell =
new CostFunction() {
public double cost(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
};
count = 0;
NelderMead nm = new NelderMead();
PointCostPair optimum =
nm.minimize(powell, 200, new ValueChecker(1.0e-3),
new double[] { 3.0, -1.0, 0.0, 1.0 },
new double[] { 4.0, 0.0, 1.0, 2.0 },
1, 1642738l);
assertTrue(count < 150);
assertEquals(0.0, optimum.getCost(), 6.0e-4);
assertEquals(0.0, optimum.getPoint()[0], 0.07);
assertEquals(0.0, optimum.getPoint()[1], 0.07);
assertEquals(0.0, optimum.getPoint()[2], 0.07);
assertEquals(0.0, optimum.getPoint()[3], 0.07);
}
private static class ValueChecker implements ConvergenceChecker {
public ValueChecker(double threshold) {
this.threshold = threshold;
}
public boolean converged(PointCostPair[] simplex) {
PointCostPair smallest = simplex[0];
PointCostPair largest = simplex[simplex.length - 1];
return (largest.getCost() - smallest.getCost()) < threshold;
}
private double threshold;
};
public static Test suite() {
return new TestSuite(NelderMeadTest.class);
}
private int count;
}

View File

@ -0,0 +1,229 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.direct;
import org.apache.commons.math.linear.decomposition.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.ObjectiveException;
import org.apache.commons.math.optimization.ObjectiveFunction;
import org.apache.commons.math.optimization.PointValuePair;
import org.apache.commons.math.ConvergenceException;
import junit.framework.*;
public class MultiDirectionalTest
extends TestCase {
public MultiDirectionalTest(String name) {
super(name);
}
public void testObjectiveExceptions() throws ConvergenceException {
ObjectiveFunction wrong =
new ObjectiveFunction() {
private static final long serialVersionUID = 4751314470965489371L;
public double objective(double[] x) throws ObjectiveException {
if (x[0] < 0) {
throw new ObjectiveException("{0}", "oops");
} else if (x[0] > 1) {
throw new ObjectiveException(new RuntimeException("oops"));
} else {
return x[0] * (1 - x[0]);
}
}
};
try {
new MultiDirectional(1.9, 0.4).optimize(wrong, 10, new ValueChecker(1.0e-3), true,
new double[] { -0.5 }, new double[] { 0.5 });
fail("an exception should have been thrown");
} catch (ObjectiveException ce) {
// expected behavior
assertNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
try {
new MultiDirectional(1.9, 0.4).optimize(wrong, 10, new ValueChecker(1.0e-3), true,
new double[] { 0.5 }, new double[] { 1.5 });
fail("an exception should have been thrown");
} catch (ObjectiveException ce) {
// expected behavior
assertNotNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
}
public void testMinimizeMaximize()
throws ObjectiveException, ConvergenceException, NotPositiveDefiniteMatrixException {
// the following function has 4 local extrema:
final double xM = -3.841947088256863675365;
final double yM = -1.391745200270734924416;
final double xP = 0.2286682237349059125691;
final double yP = -yM;
final double valueXmYm = 0.2373295333134216789769; // local maximum
final double valueXmYp = -valueXmYm; // local minimum
final double valueXpYm = -0.7290400707055187115322; // global minimum
final double valueXpYp = -valueXpYm; // global maximum
ObjectiveFunction fourExtrema = new ObjectiveFunction() {
private static final long serialVersionUID = -7039124064449091152L;
public double objective(double[] variables) {
final double x = variables[0];
final double y = variables[1];
return Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y);
}
};
MultiDirectional md = new MultiDirectional();
// minimization
md.optimize(fourExtrema, 200, new ValueChecker(1.0e-8), true,
new double[] { -4, -2 }, new double[] { 1, 2 }, 10, 38821113105892l);
PointValuePair[] optima = md.getOptima();
assertEquals(10, optima.length);
int localCount = 0;
int globalCount = 0;
for (PointValuePair optimum : optima) {
if (optimum != null) {
if (optimum.getPoint()[0] < 0) {
// this should be the local minimum
++localCount;
assertEquals(xM, optimum.getPoint()[0], 1.0e-3);
assertEquals(yP, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXmYp, optimum.getValue(), 3.0e-8);
} else {
// this should be the global minimum
++globalCount;
assertEquals(xP, optimum.getPoint()[0], 1.0e-3);
assertEquals(yM, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXpYm, optimum.getValue(), 3.0e-8);
}
}
}
assertTrue(localCount > 0);
assertTrue(globalCount > 0);
assertTrue(md.getTotalEvaluations() > 1400);
assertTrue(md.getTotalEvaluations() < 1700);
// minimization
md.optimize(fourExtrema, 200, new ValueChecker(1.0e-8), false,
new double[] { -3.5, -1 }, new double[] { 0.5, 1.5 }, 10, 38821113105892l);
optima = md.getOptima();
assertEquals(10, optima.length);
localCount = 0;
globalCount = 0;
for (PointValuePair optimum : optima) {
if (optimum != null) {
if (optimum.getPoint()[0] < 0) {
// this should be the local maximum
++localCount;
assertEquals(xM, optimum.getPoint()[0], 1.0e-3);
assertEquals(yM, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXmYm, optimum.getValue(), 4.0e-8);
} else {
// this should be the global maximum
++globalCount;
assertEquals(xP, optimum.getPoint()[0], 1.0e-3);
assertEquals(yP, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXpYp, optimum.getValue(), 4.0e-8);
}
}
}
assertTrue(localCount > 0);
assertTrue(globalCount > 0);
assertTrue(md.getTotalEvaluations() > 1400);
assertTrue(md.getTotalEvaluations() < 1700);
}
public void testRosenbrock()
throws ObjectiveException, ConvergenceException {
ObjectiveFunction rosenbrock =
new ObjectiveFunction() {
private static final long serialVersionUID = -9044950469615237490L;
public double objective(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
};
count = 0;
PointValuePair optimum =
new MultiDirectional().optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true,
new double[][] {
{ -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 }
});
assertTrue(count > 60);
assertTrue(optimum.getValue() > 0.01);
}
public void testPowell()
throws ObjectiveException, ConvergenceException {
ObjectiveFunction powell =
new ObjectiveFunction() {
private static final long serialVersionUID = -832162886102041840L;
public double objective(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
};
count = 0;
PointValuePair optimum =
new MultiDirectional().optimize(powell, 1000, new ValueChecker(1.0e-3), true,
new double[] { 3.0, -1.0, 0.0, 1.0 },
new double[] { 4.0, 0.0, 1.0, 2.0 });
assertTrue(count > 850);
assertTrue(optimum.getValue() > 0.015);
}
private static class ValueChecker implements ConvergenceChecker {
public ValueChecker(double threshold) {
this.threshold = threshold;
}
public boolean converged(PointValuePair[] simplex) {
PointValuePair smallest = simplex[0];
PointValuePair largest = simplex[simplex.length - 1];
return (largest.getValue() - smallest.getValue()) < threshold;
}
private double threshold;
};
public static Test suite() {
return new TestSuite(MultiDirectionalTest.class);
}
private int count;
}

View File

@ -0,0 +1,287 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization.direct;
import org.apache.commons.math.linear.decomposition.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.ObjectiveException;
import org.apache.commons.math.optimization.ObjectiveFunction;
import org.apache.commons.math.optimization.PointValuePair;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.random.RandomGenerator;
import org.apache.commons.math.random.RandomVectorGenerator;
import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
import org.apache.commons.math.random.UniformRandomGenerator;
import junit.framework.*;
public class NelderMeadTest
extends TestCase {
public NelderMeadTest(String name) {
super(name);
}
public void testObjectiveExceptions() throws ConvergenceException {
ObjectiveFunction wrong =
new ObjectiveFunction() {
private static final long serialVersionUID = 2624035220997628868L;
public double objective(double[] x) throws ObjectiveException {
if (x[0] < 0) {
throw new ObjectiveException("{0}", "oops");
} else if (x[0] > 1) {
throw new ObjectiveException(new RuntimeException("oops"));
} else {
return x[0] * (1 - x[0]);
}
}
};
try {
new NelderMead(0.9, 1.9, 0.4, 0.6).optimize(wrong, 10, new ValueChecker(1.0e-3), true,
new double[] { -0.5 }, new double[] { 0.5 });
fail("an exception should have been thrown");
} catch (ObjectiveException ce) {
// expected behavior
assertNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
try {
new NelderMead(0.9, 1.9, 0.4, 0.6).optimize(wrong, 10, new ValueChecker(1.0e-3), true,
new double[] { 0.5 }, new double[] { 1.5 });
fail("an exception should have been thrown");
} catch (ObjectiveException ce) {
// expected behavior
assertNotNull(ce.getCause());
} catch (Exception e) {
fail("wrong exception caught: " + e.getMessage());
}
}
public void testMinimizeMaximize()
throws ObjectiveException, ConvergenceException, NotPositiveDefiniteMatrixException {
// the following function has 4 local extrema:
final double xM = -3.841947088256863675365;
final double yM = -1.391745200270734924416;
final double xP = 0.2286682237349059125691;
final double yP = -yM;
final double valueXmYm = 0.2373295333134216789769; // local maximum
final double valueXmYp = -valueXmYm; // local minimum
final double valueXpYm = -0.7290400707055187115322; // global minimum
final double valueXpYp = -valueXpYm; // global maximum
ObjectiveFunction fourExtrema = new ObjectiveFunction() {
private static final long serialVersionUID = -7039124064449091152L;
public double objective(double[] variables) {
final double x = variables[0];
final double y = variables[1];
return Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y);
}
};
NelderMead nm = new NelderMead();
// minimization
nm.optimize(fourExtrema, 100, new ValueChecker(1.0e-8), true,
new double[] { -5, -5 }, new double[] { 5, 5 }, 10, 38821113105892l);
PointValuePair[] optima = nm.getOptima();
assertEquals(10, optima.length);
int localCount = 0;
int globalCount = 0;
for (PointValuePair optimum : optima) {
if (optimum != null) {
if (optimum.getPoint()[0] < 0) {
// this should be the local minimum
++localCount;
assertEquals(xM, optimum.getPoint()[0], 1.0e-3);
assertEquals(yP, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXmYp, optimum.getValue(), 2.0e-8);
} else {
// this should be the global minimum
++globalCount;
assertEquals(xP, optimum.getPoint()[0], 1.0e-3);
assertEquals(yM, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXpYm, optimum.getValue(), 2.0e-8);
}
}
}
assertTrue(localCount > 0);
assertTrue(globalCount > 0);
assertTrue(nm.getTotalEvaluations() > 600);
assertTrue(nm.getTotalEvaluations() < 800);
// minimization
nm.optimize(fourExtrema, 100, new ValueChecker(1.0e-8), false,
new double[] { -5, -5 }, new double[] { 5, 5 }, 10, 38821113105892l);
optima = nm.getOptima();
assertEquals(10, optima.length);
localCount = 0;
globalCount = 0;
for (PointValuePair optimum : optima) {
if (optimum != null) {
if (optimum.getPoint()[0] < 0) {
// this should be the local maximum
++localCount;
assertEquals(xM, optimum.getPoint()[0], 1.0e-3);
assertEquals(yM, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXmYm, optimum.getValue(), 2.0e-8);
} else {
// this should be the global maximum
++globalCount;
assertEquals(xP, optimum.getPoint()[0], 1.0e-3);
assertEquals(yP, optimum.getPoint()[1], 1.0e-3);
assertEquals(valueXpYp, optimum.getValue(), 2.0e-8);
}
}
}
assertTrue(localCount > 0);
assertTrue(globalCount > 0);
assertTrue(nm.getTotalEvaluations() > 600);
assertTrue(nm.getTotalEvaluations() < 800);
}
public void testRosenbrock()
throws ObjectiveException, ConvergenceException, NotPositiveDefiniteMatrixException {
ObjectiveFunction rosenbrock =
new ObjectiveFunction() {
private static final long serialVersionUID = -7039124064449091152L;
public double objective(double[] x) {
++count;
double a = x[1] - x[0] * x[0];
double b = 1.0 - x[0];
return 100 * a * a + b * b;
}
};
count = 0;
NelderMead nm = new NelderMead();
try {
nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true,
new double[][] {
{ -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
}, 1, 5384353l);
fail("an exception should have been thrown");
} catch (ConvergenceException ce) {
// expected behavior
} catch (Exception e) {
e.printStackTrace(System.err);
fail("wrong exception caught: " + e.getMessage());
}
count = 0;
PointValuePair optimum =
nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true,
new double[][] {
{ -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
}, 10, 1642738l);
assertTrue(count > 700);
assertTrue(count < 800);
assertEquals(0.0, optimum.getValue(), 5.0e-5);
assertEquals(1.0, optimum.getPoint()[0], 0.01);
assertEquals(1.0, optimum.getPoint()[1], 0.01);
PointValuePair[] minima = nm.getOptima();
assertEquals(10, minima.length);
assertNotNull(minima[0]);
assertNull(minima[minima.length - 1]);
for (int i = 0; i < minima.length; ++i) {
if (minima[i] == null) {
if ((i + 1) < minima.length) {
assertTrue(minima[i+1] == null);
}
} else {
if (i > 0) {
assertTrue(minima[i-1].getValue() <= minima[i].getValue());
}
}
}
RandomGenerator rg = new JDKRandomGenerator();
rg.setSeed(64453353l);
RandomVectorGenerator rvg =
new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
new double[] { 0.2, 0.2 },
new UniformRandomGenerator(rg));
optimum =
nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true, rvg);
assertEquals(0.0, optimum.getValue(), 2.0e-4);
optimum =
nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true, rvg, 3);
assertEquals(0.0, optimum.getValue(), 3.0e-5);
}
public void testPowell()
throws ObjectiveException, ConvergenceException {
ObjectiveFunction powell =
new ObjectiveFunction() {
private static final long serialVersionUID = -7681075710859391520L;
public double objective(double[] x) {
++count;
double a = x[0] + 10 * x[1];
double b = x[2] - x[3];
double c = x[1] - 2 * x[2];
double d = x[0] - x[3];
return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
}
};
count = 0;
NelderMead nm = new NelderMead();
PointValuePair optimum =
nm.optimize(powell, 200, new ValueChecker(1.0e-3), true,
new double[] { 3.0, -1.0, 0.0, 1.0 },
new double[] { 4.0, 0.0, 1.0, 2.0 },
1, 1642738l);
assertTrue(count < 150);
assertEquals(0.0, optimum.getValue(), 6.0e-4);
assertEquals(0.0, optimum.getPoint()[0], 0.07);
assertEquals(0.0, optimum.getPoint()[1], 0.07);
assertEquals(0.0, optimum.getPoint()[2], 0.07);
assertEquals(0.0, optimum.getPoint()[3], 0.07);
}
private static class ValueChecker implements ConvergenceChecker {
public ValueChecker(double threshold) {
this.threshold = threshold;
}
public boolean converged(PointValuePair[] simplex) {
PointValuePair smallest = simplex[0];
PointValuePair largest = simplex[simplex.length - 1];
return (largest.getValue() - smallest.getValue()) < threshold;
}
private double threshold;
};
public static Test suite() {
return new TestSuite(NelderMeadTest.class);
}
private int count;
}

View File

@ -15,9 +15,8 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import org.apache.commons.math.estimation.EstimatedParameter;
import junit.framework.*;

View File

@ -15,11 +15,14 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.util.ArrayList;
import java.util.HashSet;
import org.apache.commons.math.optimization.OptimizationException;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
@ -93,7 +96,7 @@ public class GaussNewtonEstimatorTest
super(name);
}
public void testTrivial() throws EstimationException {
public void testTrivial() throws OptimizationException {
LinearProblem problem =
new LinearProblem(new LinearMeasurement[] {
new LinearMeasurement(new double[] {2},
@ -109,7 +112,7 @@ public class GaussNewtonEstimatorTest
1.0e-10);
}
public void testQRColumnsPermutation() throws EstimationException {
public void testQRColumnsPermutation() throws OptimizationException {
EstimatedParameter[] x = {
new EstimatedParameter("p0", 0), new EstimatedParameter("p1", 0)
@ -134,7 +137,7 @@ public class GaussNewtonEstimatorTest
}
public void testNoDependency() throws EstimationException {
public void testNoDependency() throws OptimizationException {
EstimatedParameter[] p = new EstimatedParameter[] {
new EstimatedParameter("p0", 0),
new EstimatedParameter("p1", 0),
@ -159,7 +162,7 @@ public class GaussNewtonEstimatorTest
}
}
public void testOneSet() throws EstimationException {
public void testOneSet() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
@ -187,7 +190,7 @@ public class GaussNewtonEstimatorTest
}
public void testTwoSets() throws EstimationException {
public void testTwoSets() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
new EstimatedParameter("p1", 1),
@ -236,7 +239,7 @@ public class GaussNewtonEstimatorTest
}
public void testNonInversible() throws EstimationException {
public void testNonInversible() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
@ -260,14 +263,14 @@ public class GaussNewtonEstimatorTest
try {
estimator.estimate(problem);
fail("an exception should have been caught");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception type caught");
}
}
public void testIllConditioned() throws EstimationException {
public void testIllConditioned() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
new EstimatedParameter("p1", 1),
@ -321,7 +324,7 @@ public class GaussNewtonEstimatorTest
}
public void testMoreEstimatedParametersSimple() throws EstimationException {
public void testMoreEstimatedParametersSimple() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 7),
@ -345,7 +348,7 @@ public class GaussNewtonEstimatorTest
try {
estimator.estimate(problem);
fail("an exception should have been caught");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception type caught");
@ -353,7 +356,7 @@ public class GaussNewtonEstimatorTest
}
public void testMoreEstimatedParametersUnsorted() throws EstimationException {
public void testMoreEstimatedParametersUnsorted() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 2),
new EstimatedParameter("p1", 2),
@ -384,7 +387,7 @@ public class GaussNewtonEstimatorTest
try {
estimator.estimate(problem);
fail("an exception should have been caught");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception type caught");
@ -392,7 +395,7 @@ public class GaussNewtonEstimatorTest
}
public void testRedundantEquations() throws EstimationException {
public void testRedundantEquations() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 1),
new EstimatedParameter("p1", 1)
@ -420,7 +423,7 @@ public class GaussNewtonEstimatorTest
}
public void testInconsistentEquations() throws EstimationException {
public void testInconsistentEquations() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 1),
new EstimatedParameter("p1", 1)
@ -443,7 +446,7 @@ public class GaussNewtonEstimatorTest
}
public void testBoundParameters() throws EstimationException {
public void testBoundParameters() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("unbound0", 2, false),
new EstimatedParameter("unbound1", 2, false),
@ -492,14 +495,14 @@ public class GaussNewtonEstimatorTest
GaussNewtonEstimator estimator = new GaussNewtonEstimator(4, 1.0e-14, 1.0e-14);
estimator.estimate(circle);
fail("an exception should have been caught");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception type caught");
}
}
public void testCircleFitting() throws EstimationException {
public void testCircleFitting() throws OptimizationException {
Circle circle = new Circle(98.680, 47.345);
circle.addPoint( 30.0, 68.0);
circle.addPoint( 50.0, -6.0);
@ -515,7 +518,7 @@ public class GaussNewtonEstimatorTest
assertEquals(48.13516790438953, circle.getY(), 1.0e-10);
}
public void testCircleFittingBadInit() throws EstimationException {
public void testCircleFittingBadInit() throws OptimizationException {
Circle circle = new Circle(-12, -12);
double[][] points = new double[][] {
{-0.312967, 0.072366}, {-0.339248, 0.132965}, {-0.379780, 0.202724},
@ -555,7 +558,7 @@ public class GaussNewtonEstimatorTest
try {
estimator.estimate(circle);
fail("an exception should have been caught");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception type caught");

View File

@ -15,11 +15,14 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.util.ArrayList;
import java.util.HashSet;
import org.apache.commons.math.optimization.OptimizationException;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
@ -93,7 +96,7 @@ public class LevenbergMarquardtEstimatorTest
super(name);
}
public void testTrivial() throws EstimationException {
public void testTrivial() throws OptimizationException {
LinearProblem problem =
new LinearProblem(new LinearMeasurement[] {
new LinearMeasurement(new double[] {2},
@ -107,7 +110,7 @@ public class LevenbergMarquardtEstimatorTest
try {
estimator.guessParametersErrors(problem);
fail("an exception should have been thrown");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
@ -117,7 +120,7 @@ public class LevenbergMarquardtEstimatorTest
1.0e-10);
}
public void testQRColumnsPermutation() throws EstimationException {
public void testQRColumnsPermutation() throws OptimizationException {
EstimatedParameter[] x = {
new EstimatedParameter("p0", 0), new EstimatedParameter("p1", 0)
@ -142,7 +145,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testNoDependency() throws EstimationException {
public void testNoDependency() throws OptimizationException {
EstimatedParameter[] p = new EstimatedParameter[] {
new EstimatedParameter("p0", 0),
new EstimatedParameter("p1", 0),
@ -167,7 +170,7 @@ public class LevenbergMarquardtEstimatorTest
}
}
public void testOneSet() throws EstimationException {
public void testOneSet() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
@ -195,7 +198,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testTwoSets() throws EstimationException {
public void testTwoSets() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
new EstimatedParameter("p1", 1),
@ -244,7 +247,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testNonInversible() throws EstimationException {
public void testNonInversible() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
@ -272,7 +275,7 @@ public class LevenbergMarquardtEstimatorTest
try {
estimator.getCovariances(problem);
fail("an exception should have been thrown");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
@ -291,7 +294,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testIllConditioned() throws EstimationException {
public void testIllConditioned() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 0),
new EstimatedParameter("p1", 1),
@ -345,7 +348,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testMoreEstimatedParametersSimple() throws EstimationException {
public void testMoreEstimatedParametersSimple() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 7),
@ -371,7 +374,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testMoreEstimatedParametersUnsorted() throws EstimationException {
public void testMoreEstimatedParametersUnsorted() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 2),
new EstimatedParameter("p1", 2),
@ -408,7 +411,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testRedundantEquations() throws EstimationException {
public void testRedundantEquations() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 1),
new EstimatedParameter("p1", 1)
@ -433,7 +436,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testInconsistentEquations() throws EstimationException {
public void testInconsistentEquations() throws OptimizationException {
EstimatedParameter[] p = {
new EstimatedParameter("p0", 1),
new EstimatedParameter("p1", 1)
@ -456,7 +459,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testControlParameters() throws EstimationException {
public void testControlParameters() throws OptimizationException {
Circle circle = new Circle(98.680, 47.345);
circle.addPoint( 30.0, 68.0);
circle.addPoint( 50.0, -6.0);
@ -483,14 +486,14 @@ public class LevenbergMarquardtEstimatorTest
estimator.setOrthoTolerance(orthoTolerance);
estimator.estimate(problem);
assertTrue(! shouldFail);
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
assertTrue(shouldFail);
} catch (Exception e) {
fail("wrong exception type caught");
}
}
public void testCircleFitting() throws EstimationException {
public void testCircleFitting() throws OptimizationException {
Circle circle = new Circle(98.680, 47.345);
circle.addPoint( 30.0, 68.0);
circle.addPoint( 50.0, -6.0);
@ -535,7 +538,7 @@ public class LevenbergMarquardtEstimatorTest
}
public void testCircleFittingBadInit() throws EstimationException {
public void testCircleFittingBadInit() throws OptimizationException {
Circle circle = new Circle(-12, -12);
double[][] points = new double[][] {
{-0.312967, 0.072366}, {-0.339248, 0.132965}, {-0.379780, 0.202724},
@ -591,7 +594,7 @@ public class LevenbergMarquardtEstimatorTest
problem.addPoint (4, 1.7785661310051026, 0.0);
new LevenbergMarquardtEstimator().estimate(problem);
fail("an exception should have been thrown");
} catch (EstimationException ee) {
} catch (OptimizationException ee) {
// expected behavior
}

View File

@ -15,15 +15,12 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import java.util.Arrays;
import org.apache.commons.math.estimation.EstimatedParameter;
import org.apache.commons.math.estimation.EstimationException;
import org.apache.commons.math.estimation.EstimationProblem;
import org.apache.commons.math.estimation.LevenbergMarquardtEstimator;
import org.apache.commons.math.estimation.WeightedMeasurement;
import org.apache.commons.math.optimization.OptimizationException;
import junit.framework.*;
@ -97,7 +94,7 @@ public class MinpackTest
}
public void testMinpackLinearFullRank()
throws EstimationException {
throws OptimizationException {
minpackTest(new LinearFullRankFunction(10, 5, 1.0,
5.0, 2.23606797749979), false);
minpackTest(new LinearFullRankFunction(50, 5, 1.0,
@ -105,7 +102,7 @@ public class MinpackTest
}
public void testMinpackLinearRank1()
throws EstimationException {
throws OptimizationException {
minpackTest(new LinearRank1Function(10, 5, 1.0,
291.521868819476, 1.4638501094228), false);
minpackTest(new LinearRank1Function(50, 5, 1.0,
@ -113,13 +110,13 @@ public class MinpackTest
}
public void testMinpackLinearRank1ZeroColsAndRows()
throws EstimationException {
throws OptimizationException {
minpackTest(new LinearRank1ZeroColsAndRowsFunction(10, 5, 1.0), false);
minpackTest(new LinearRank1ZeroColsAndRowsFunction(50, 5, 1.0), false);
}
public void testMinpackRosenbrok()
throws EstimationException {
throws OptimizationException {
minpackTest(new RosenbrockFunction(new double[] { -1.2, 1.0 },
Math.sqrt(24.2)), false);
minpackTest(new RosenbrockFunction(new double[] { -12.0, 10.0 },
@ -129,7 +126,7 @@ public class MinpackTest
}
public void testMinpackHelicalValley()
throws EstimationException {
throws OptimizationException {
minpackTest(new HelicalValleyFunction(new double[] { -1.0, 0.0, 0.0 },
50.0), false);
minpackTest(new HelicalValleyFunction(new double[] { -10.0, 0.0, 0.0 },
@ -139,7 +136,7 @@ public class MinpackTest
}
public void testMinpackPowellSingular()
throws EstimationException {
throws OptimizationException {
minpackTest(new PowellSingularFunction(new double[] { 3.0, -1.0, 0.0, 1.0 },
14.6628782986152), false);
minpackTest(new PowellSingularFunction(new double[] { 30.0, -10.0, 0.0, 10.0 },
@ -149,7 +146,7 @@ public class MinpackTest
}
public void testMinpackFreudensteinRoth()
throws EstimationException {
throws OptimizationException {
minpackTest(new FreudensteinRothFunction(new double[] { 0.5, -2.0 },
20.0124960961895, 6.99887517584575,
new double[] {
@ -171,7 +168,7 @@ public class MinpackTest
}
public void testMinpackBard()
throws EstimationException {
throws OptimizationException {
minpackTest(new BardFunction(1.0, 6.45613629515967, 0.0906359603390466,
new double[] {
0.0824105765758334,
@ -193,7 +190,7 @@ public class MinpackTest
}
public void testMinpackKowalikOsborne()
throws EstimationException {
throws OptimizationException {
minpackTest(new KowalikOsborneFunction(new double[] { 0.25, 0.39, 0.415, 0.39 },
0.0728915102882945,
0.017535837721129,
@ -224,7 +221,7 @@ public class MinpackTest
}
public void testMinpackMeyer()
throws EstimationException {
throws OptimizationException {
minpackTest(new MeyerFunction(new double[] { 0.02, 4000.0, 250.0 },
41153.4665543031, 9.37794514651874,
new double[] {
@ -242,7 +239,7 @@ public class MinpackTest
}
public void testMinpackWatson()
throws EstimationException {
throws OptimizationException {
minpackTest(new WatsonFunction(6, 0.0,
5.47722557505166, 0.0478295939097601,
@ -328,13 +325,13 @@ public class MinpackTest
}
public void testMinpackBox3Dimensional()
throws EstimationException {
throws OptimizationException {
minpackTest(new Box3DimensionalFunction(10, new double[] { 0.0, 10.0, 20.0 },
32.1115837449572), false);
}
public void testMinpackJennrichSampson()
throws EstimationException {
throws OptimizationException {
minpackTest(new JennrichSampsonFunction(10, new double[] { 0.3, 0.4 },
64.5856498144943, 11.1517793413499,
new double[] {
@ -343,7 +340,7 @@ public class MinpackTest
}
public void testMinpackBrownDennis()
throws EstimationException {
throws OptimizationException {
minpackTest(new BrownDennisFunction(20,
new double[] { 25.0, 5.0, -5.0, -1.0 },
2815.43839161816, 292.954288244866,
@ -368,7 +365,7 @@ public class MinpackTest
}
public void testMinpackChebyquad()
throws EstimationException {
throws OptimizationException {
minpackTest(new ChebyquadFunction(1, 8, 1.0,
1.88623796907732, 1.88623796907732,
new double[] { 0.5 }), false);
@ -407,7 +404,7 @@ public class MinpackTest
}
public void testMinpackBrownAlmostLinear()
throws EstimationException {
throws OptimizationException {
minpackTest(new BrownAlmostLinearFunction(10, 0.5,
16.5302162063499, 0.0,
new double[] {
@ -476,7 +473,7 @@ public class MinpackTest
}
public void testMinpackOsborne1()
throws EstimationException {
throws OptimizationException {
minpackTest(new Osborne1Function(new double[] { 0.5, 1.5, -1.0, 0.01, 0.02, },
0.937564021037838, 0.00739249260904843,
new double[] {
@ -487,7 +484,7 @@ public class MinpackTest
}
public void testMinpackOsborne2()
throws EstimationException {
throws OptimizationException {
minpackTest(new Osborne2Function(new double[] {
1.3, 0.65, 0.65, 0.7, 0.6,
@ -514,7 +511,7 @@ public class MinpackTest
try {
estimator.estimate(function);
assertFalse(exceptionExpected);
} catch (EstimationException lsse) {
} catch (OptimizationException lsse) {
assertTrue(exceptionExpected);
}
assertTrue(function.checkTheoreticalMinCost(estimator.getRMS(function)));

View File

@ -15,10 +15,8 @@
* limitations under the License.
*/
package org.apache.commons.math.estimation;
package org.apache.commons.math.optimization.general;
import org.apache.commons.math.estimation.EstimatedParameter;
import org.apache.commons.math.estimation.WeightedMeasurement;
import junit.framework.*;

View File

@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.analysis.minimization;
package org.apache.commons.math.optimization.univariate;
import junit.framework.Test;
import junit.framework.TestCase;