MATH-874
First steps to upgrade the API of the classes in "o.a.c.m.optimization.general". Please note the introduction of a "Weight" matrix (whereas the current code assumes that the weights are given as an array). git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1402607 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
2796e77e2e
commit
d0d4760c97
|
@ -23,10 +23,15 @@ import org.apache.commons.math3.exception.TooManyEvaluationsException;
|
|||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.NullArgumentException;
|
||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
||||
import org.apache.commons.math3.optimization.OptimizationData;
|
||||
import org.apache.commons.math3.optimization.InitialGuess;
|
||||
import org.apache.commons.math3.optimization.Target;
|
||||
import org.apache.commons.math3.optimization.Weight;
|
||||
import org.apache.commons.math3.optimization.BaseMultivariateVectorOptimizer;
|
||||
import org.apache.commons.math3.optimization.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optimization.PointVectorValuePair;
|
||||
import org.apache.commons.math3.optimization.SimpleVectorValueChecker;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
|
||||
/**
|
||||
* Base class for implementing optimizers for multivariate scalar functions.
|
||||
|
@ -46,12 +51,16 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
|
|||
private ConvergenceChecker<PointVectorValuePair> checker;
|
||||
/** Target value for the objective functions at optimum. */
|
||||
private double[] target;
|
||||
/** Weight for the least squares cost computation. */
|
||||
/** Weight matrix. */
|
||||
private RealMatrix weightMatrix;
|
||||
/** Weight for the least squares cost computation.
|
||||
* @deprecated
|
||||
*/
|
||||
private double[] weight;
|
||||
/** Initial guess. */
|
||||
private double[] start;
|
||||
/** Objective function. */
|
||||
private MultivariateVectorFunction function;
|
||||
private FUNC function;
|
||||
|
||||
/**
|
||||
* Simple constructor with default settings.
|
||||
|
@ -101,12 +110,46 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
|
|||
return function.value(point);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
/** {@inheritDoc}
|
||||
*
|
||||
* @deprecated As of 3.1. Please use
|
||||
* {@link #optimize(int,MultivariateVectorFunction,OptimizationData[])}
|
||||
* instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public PointVectorValuePair optimize(int maxEval, FUNC f, double[] t, double[] w,
|
||||
double[] startPoint) {
|
||||
return optimizeInternal(maxEval, f, t, w, startPoint);
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize an objective function.
|
||||
*
|
||||
* @param maxEval Allowed number of evaluations of the objective function.
|
||||
* @param f Objective function.
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link Target}</li>
|
||||
* <li>{@link Weight}</li>
|
||||
* <li>{@link InitialGuess}</li>
|
||||
* </ul>
|
||||
* @return the point/value pair giving the optimal value of the objective
|
||||
* function.
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
* @throws DimensionMismatchException if the initial guess, target, and weight
|
||||
* arguments have inconsistent dimensions.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
protected PointVectorValuePair optimize(int maxEval,
|
||||
FUNC f,
|
||||
OptimizationData... optData)
|
||||
throws TooManyEvaluationsException,
|
||||
DimensionMismatchException {
|
||||
return optimizeInternal(maxEval, f, optData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize an objective function.
|
||||
* Optimization is considered to be a weighted least-squares minimization.
|
||||
|
@ -126,8 +169,12 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
|
|||
* if the maximal number of evaluations is exceeded.
|
||||
* @throws org.apache.commons.math3.exception.NullArgumentException if
|
||||
* any argument is {@code null}.
|
||||
* @deprecated As of 3.1. Please use
|
||||
* {@link #optimizeInternal(int,MultivariateVectorFunction,OptimizationData[])}
|
||||
* instead.
|
||||
*/
|
||||
protected PointVectorValuePair optimizeInternal(final int maxEval, final MultivariateVectorFunction f,
|
||||
@Deprecated
|
||||
protected PointVectorValuePair optimizeInternal(final int maxEval, final FUNC f,
|
||||
final double[] t, final double[] w,
|
||||
final double[] startPoint) {
|
||||
// Checks.
|
||||
|
@ -147,28 +194,122 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
|
|||
throw new DimensionMismatchException(t.length, w.length);
|
||||
}
|
||||
|
||||
// Reset.
|
||||
evaluations.setMaximalCount(maxEval);
|
||||
evaluations.resetCount();
|
||||
|
||||
// Store optimization problem characteristics.
|
||||
function = f;
|
||||
target = t.clone();
|
||||
weight = w.clone();
|
||||
start = startPoint.clone();
|
||||
|
||||
// Perform computation.
|
||||
return doOptimize();
|
||||
|
||||
return optimizeInternal(maxEval, f,
|
||||
new Target(t),
|
||||
new Weight(w),
|
||||
new InitialGuess(startPoint));
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize an objective function.
|
||||
*
|
||||
* @param maxEval Allowed number of evaluations of the objective function.
|
||||
* @param f Objective function.
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link Target}</li>
|
||||
* <li>{@link Weight}</li>
|
||||
* <li>{@link InitialGuess}</li>
|
||||
* </ul>
|
||||
* @return the point/value pair giving the optimal value of the objective
|
||||
* function.
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
* @throws DimensionMismatchException if the initial guess, target, and weight
|
||||
* arguments have inconsistent dimensions.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
protected PointVectorValuePair optimizeInternal(int maxEval,
|
||||
FUNC f,
|
||||
OptimizationData... optData)
|
||||
throws TooManyEvaluationsException,
|
||||
DimensionMismatchException {
|
||||
// Set internal state.
|
||||
evaluations.setMaximalCount(maxEval);
|
||||
evaluations.resetCount();
|
||||
function = f;
|
||||
// Retrieve other settings.
|
||||
parseOptimizationData(optData);
|
||||
// Check input consistency.
|
||||
checkParameters();
|
||||
// Allow subclasses to reset their own internal state.
|
||||
setUp();
|
||||
// Perform computation.
|
||||
return doOptimize();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the initial values of the optimized parameters.
|
||||
*
|
||||
* @return the initial guess.
|
||||
*/
|
||||
public double[] getStartPoint() {
|
||||
return start.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the weight matrix of the observations.
|
||||
*
|
||||
* @return the weight matrix.
|
||||
* @since 3.1
|
||||
*/
|
||||
public RealMatrix getWeight() {
|
||||
return weightMatrix.copy();
|
||||
}
|
||||
/**
|
||||
* Gets the observed values to be matched by the objective vector
|
||||
* function.
|
||||
*
|
||||
* @return the target values.
|
||||
* @since 3.1
|
||||
*/
|
||||
public double[] getTarget() {
|
||||
return target.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the residuals.
|
||||
* The residual is the difference between the observed (target)
|
||||
* values and the model (objective function) value, for the given
|
||||
* parameters.
|
||||
* There is one residual for each element of the vector-valued
|
||||
* function.
|
||||
*
|
||||
* @param point Parameters of the model.
|
||||
* @return the residuals.
|
||||
* @throws DimensionMismatchException if {@code point} has a wrong
|
||||
* length.
|
||||
* @since 3.1
|
||||
*/
|
||||
protected double[] computeResidual(double[] point) {
|
||||
if (point.length != start.length) {
|
||||
throw new DimensionMismatchException(point.length,
|
||||
start.length);
|
||||
}
|
||||
|
||||
final double[] objective = computeObjectiveValue(point);
|
||||
|
||||
final double[] residuals = new double[target.length];
|
||||
for (int i = 0; i < target.length; i++) {
|
||||
residuals[i] = target[i] - objective[i];
|
||||
}
|
||||
|
||||
return residuals;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Gets the objective vector function.
|
||||
* Note that this access bypasses the evaluation counter.
|
||||
*
|
||||
* @return the objective vector function.
|
||||
* @since 3.1
|
||||
*/
|
||||
protected FUNC getObjectiveFunction() {
|
||||
return function;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform the bulk of the optimization algorithm.
|
||||
*
|
||||
|
@ -179,14 +320,80 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
|
|||
|
||||
/**
|
||||
* @return a reference to the {@link #target array}.
|
||||
* @deprecated As of 3.1.
|
||||
*/
|
||||
@Deprecated
|
||||
protected double[] getTargetRef() {
|
||||
return target;
|
||||
}
|
||||
/**
|
||||
* @return a reference to the {@link #weight array}.
|
||||
* @deprecated As of 3.1.
|
||||
*/
|
||||
@Deprecated
|
||||
protected double[] getWeightRef() {
|
||||
return weight;
|
||||
}
|
||||
|
||||
/**
|
||||
* Method which a subclass <em>must</em> override whenever its internal
|
||||
* state depend on the {@link OptimizationData input} parsed by this base
|
||||
* class.
|
||||
* It will be called after the parsing step performed in the
|
||||
* {@link #optimize(int,MultivariateVectorFunction,OptimizationData[])
|
||||
* optimize} method and just before {@link #doOptimize()}.
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
protected void setUp() {
|
||||
// XXX Temporary code until the new internal data is used everywhere.
|
||||
final int dim = target.length;
|
||||
weight = new double[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
weight[i] = weightMatrix.getEntry(i, i);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scans the list of (required and optional) optimization data that
|
||||
* characterize the problem.
|
||||
*
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link Target}</li>
|
||||
* <li>{@link Weight}</li>
|
||||
* <li>{@link InitialGuess}</li>
|
||||
* </ul>
|
||||
*/
|
||||
private void parseOptimizationData(OptimizationData... optData) {
|
||||
// The existing values (as set by the previous call) are reused if
|
||||
// not provided in the argument list.
|
||||
for (OptimizationData data : optData) {
|
||||
if (data instanceof Target) {
|
||||
target = ((Target) data).getTarget();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof Weight) {
|
||||
weightMatrix = ((Weight) data).getWeight();
|
||||
continue;
|
||||
}
|
||||
if (data instanceof InitialGuess) {
|
||||
start = ((InitialGuess) data).getInitialGuess();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check parameters consistency.
|
||||
*
|
||||
* @throws DimensionMismatchException if {@link #target} and
|
||||
* {@link #weightMatrix} have inconsistent dimensions.
|
||||
*/
|
||||
private void checkParameters() {
|
||||
if (target.length != weightMatrix.getColumnDimension()) {
|
||||
throw new DimensionMismatchException(target.length,
|
||||
weightMatrix.getColumnDimension());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,10 @@ import org.apache.commons.math3.exception.util.LocalizedFormats;
|
|||
import org.apache.commons.math3.linear.DecompositionSolver;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.QRDecomposition;
|
||||
import org.apache.commons.math3.optimization.OptimizationData;
|
||||
import org.apache.commons.math3.optimization.InitialGuess;
|
||||
import org.apache.commons.math3.optimization.Target;
|
||||
import org.apache.commons.math3.optimization.Weight;
|
||||
import org.apache.commons.math3.optimization.ConvergenceChecker;
|
||||
import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
|
||||
import org.apache.commons.math3.optimization.PointVectorValuePair;
|
||||
|
@ -308,8 +312,10 @@ public abstract class AbstractLeastSquaresOptimizer
|
|||
}
|
||||
|
||||
/** {@inheritDoc}
|
||||
* @deprecated as of 3.1 replaced by {@link #optimize(int,
|
||||
* MultivariateDifferentiableVectorFunction, double[], double[], double[])}
|
||||
* @deprecated As of 3.1. Please use
|
||||
* {@link BaseAbstractMultivariateVectorOptimizer#optimize(int,MultivariateVectorFunction,OptimizationData[])
|
||||
* optimize(int,MultivariateDifferentiableVectorFunction,OptimizationData...)}
|
||||
* instead.
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
|
@ -317,8 +323,11 @@ public abstract class AbstractLeastSquaresOptimizer
|
|||
final DifferentiableMultivariateVectorFunction f,
|
||||
final double[] target, final double[] weights,
|
||||
final double[] startPoint) {
|
||||
return optimize(maxEval, FunctionUtils.toMultivariateDifferentiableVectorFunction(f),
|
||||
target, weights, startPoint);
|
||||
return optimizeInternal(maxEval,
|
||||
FunctionUtils.toMultivariateDifferentiableVectorFunction(f),
|
||||
new Target(target),
|
||||
new Weight(weights),
|
||||
new InitialGuess(startPoint));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -340,29 +349,78 @@ public abstract class AbstractLeastSquaresOptimizer
|
|||
* if the maximal number of evaluations is exceeded.
|
||||
* @throws org.apache.commons.math3.exception.NullArgumentException if
|
||||
* any argument is {@code null}.
|
||||
* @deprecated As of 3.1. Please use
|
||||
* {@link BaseAbstractMultivariateVectorOptimizer#optimize(int,MultivariateVectorFunction,OptimizationData[])
|
||||
* optimize(int,MultivariateDifferentiableVectorFunction,OptimizationData...)}
|
||||
* instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public PointVectorValuePair optimize(final int maxEval,
|
||||
final MultivariateDifferentiableVectorFunction f,
|
||||
final double[] target, final double[] weights,
|
||||
final double[] startPoint) {
|
||||
return optimizeInternal(maxEval, f,
|
||||
new Target(target),
|
||||
new Weight(weights),
|
||||
new InitialGuess(startPoint));
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize an objective function.
|
||||
* Optimization is considered to be a weighted least-squares minimization.
|
||||
* The cost function to be minimized is
|
||||
* <code>∑weight<sub>i</sub>(objective<sub>i</sub> - target<sub>i</sub>)<sup>2</sup></code>
|
||||
*
|
||||
* @param maxEval Allowed number of evaluations of the objective function.
|
||||
* @param f Objective function.
|
||||
* @param optData Optimization data. The following data will be looked for:
|
||||
* <ul>
|
||||
* <li>{@link Target}</li>
|
||||
* <li>{@link Weight}</li>
|
||||
* <li>{@link InitialGuess}</li>
|
||||
* </ul>
|
||||
* @return the point/value pair giving the optimal value of the objective
|
||||
* function.
|
||||
* @throws TooManyEvaluationsException if the maximal number of
|
||||
* evaluations is exceeded.
|
||||
* @throws DimensionMismatchException if the target, and weight arguments
|
||||
* have inconsistent dimensions.
|
||||
* @see BaseAbstractMultivariateVectorOptimizer#optimizeInternal(int,MultivariateVectorFunction,OptimizationData[])
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
protected PointVectorValuePair optimizeInternal(final int maxEval,
|
||||
final MultivariateDifferentiableVectorFunction f,
|
||||
OptimizationData... optData) {
|
||||
// XXX Conversion will be removed when the generic argument of the
|
||||
// base class becomes "MultivariateDifferentiableVectorFunction".
|
||||
return super.optimizeInternal(maxEval, FunctionUtils.toDifferentiableMultivariateVectorFunction(f), optData);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
protected void setUp() {
|
||||
super.setUp();
|
||||
|
||||
// Reset counter.
|
||||
jacobianEvaluations = 0;
|
||||
|
||||
// Store least squares problem characteristics.
|
||||
jF = f;
|
||||
// XXX The conversion won't be necessary when the generic argument of
|
||||
// the base class becomes "MultivariateDifferentiableVectorFunction".
|
||||
// XXX "jF" is not strictly necessary anymore but is currently more
|
||||
// efficient than converting the value returned from "getObjectiveFunction()"
|
||||
// every time it is used.
|
||||
jF = FunctionUtils.toMultivariateDifferentiableVectorFunction((DifferentiableMultivariateVectorFunction) getObjectiveFunction());
|
||||
|
||||
// Arrays shared with the other private methods.
|
||||
point = startPoint.clone();
|
||||
rows = target.length;
|
||||
point = getStartPoint();
|
||||
rows = getTarget().length;
|
||||
cols = point.length;
|
||||
|
||||
weightedResidualJacobian = new double[rows][cols];
|
||||
this.weightedResiduals = new double[rows];
|
||||
|
||||
cost = Double.POSITIVE_INFINITY;
|
||||
|
||||
return optimizeInternal(maxEval, f, target, weights, startPoint);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue