First steps to upgrade the API of the classes in "o.a.c.m.optimization.general".
Please note the introduction of a "Weight" matrix (whereas the current code
assumes that the weights are given as an array).


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1402607 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2012-10-26 18:19:40 +00:00
parent 2796e77e2e
commit d0d4760c97
2 changed files with 292 additions and 27 deletions

View File

@ -23,10 +23,15 @@ import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NullArgumentException; import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.analysis.MultivariateVectorFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optimization.OptimizationData;
import org.apache.commons.math3.optimization.InitialGuess;
import org.apache.commons.math3.optimization.Target;
import org.apache.commons.math3.optimization.Weight;
import org.apache.commons.math3.optimization.BaseMultivariateVectorOptimizer; import org.apache.commons.math3.optimization.BaseMultivariateVectorOptimizer;
import org.apache.commons.math3.optimization.ConvergenceChecker; import org.apache.commons.math3.optimization.ConvergenceChecker;
import org.apache.commons.math3.optimization.PointVectorValuePair; import org.apache.commons.math3.optimization.PointVectorValuePair;
import org.apache.commons.math3.optimization.SimpleVectorValueChecker; import org.apache.commons.math3.optimization.SimpleVectorValueChecker;
import org.apache.commons.math3.linear.RealMatrix;
/** /**
* Base class for implementing optimizers for multivariate vector functions. * Base class for implementing optimizers for multivariate vector functions.
@ -46,12 +51,16 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
private ConvergenceChecker<PointVectorValuePair> checker; private ConvergenceChecker<PointVectorValuePair> checker;
/** Target value for the objective functions at optimum. */ /** Target value for the objective functions at optimum. */
private double[] target; private double[] target;
/** Weight for the least squares cost computation. */ /** Weight matrix. */
private RealMatrix weightMatrix;
/** Weight for the least squares cost computation.
* @deprecated
*/
private double[] weight; private double[] weight;
/** Initial guess. */ /** Initial guess. */
private double[] start; private double[] start;
/** Objective function. */ /** Objective function. */
private MultivariateVectorFunction function; private FUNC function;
/** /**
* Simple constructor with default settings. * Simple constructor with default settings.
@ -101,12 +110,46 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
return function.value(point); return function.value(point);
} }
/** {@inheritDoc} */ /** {@inheritDoc}
*
* @deprecated As of 3.1. Please use
* {@link #optimize(int,MultivariateVectorFunction,OptimizationData[])}
* instead.
*/
@Deprecated
public PointVectorValuePair optimize(int maxEval, FUNC f, double[] t, double[] w, public PointVectorValuePair optimize(int maxEval, FUNC f, double[] t, double[] w,
double[] startPoint) { double[] startPoint) {
return optimizeInternal(maxEval, f, t, w, startPoint); return optimizeInternal(maxEval, f, t, w, startPoint);
} }
/**
 * Runs the optimization of the given objective function.
 * All arguments are forwarded unchanged to
 * {@link #optimizeInternal(int,MultivariateVectorFunction,OptimizationData[])}.
 *
 * @param maxEval Allowed number of evaluations of the objective function.
 * @param f Objective function.
 * @param optData Optimization data. Instances of the following types are
 * recognized:
 * <ul>
 *  <li>{@link Target}</li>
 *  <li>{@link Weight}</li>
 *  <li>{@link InitialGuess}</li>
 * </ul>
 * @return the point/value pair giving the optimal value of the objective
 * function.
 * @throws TooManyEvaluationsException if the maximal number of
 * evaluations is exceeded.
 * @throws DimensionMismatchException if the dimensions of the supplied
 * optimization data are inconsistent.
 *
 * @since 3.1
 */
protected PointVectorValuePair optimize(int maxEval, FUNC f, OptimizationData... optData)
    throws TooManyEvaluationsException, DimensionMismatchException {
    final PointVectorValuePair optimum = optimizeInternal(maxEval, f, optData);
    return optimum;
}
/** /**
* Optimize an objective function. * Optimize an objective function.
* Optimization is considered to be a weighted least-squares minimization. * Optimization is considered to be a weighted least-squares minimization.
@ -126,8 +169,12 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
* if the maximal number of evaluations is exceeded. * if the maximal number of evaluations is exceeded.
* @throws org.apache.commons.math3.exception.NullArgumentException if * @throws org.apache.commons.math3.exception.NullArgumentException if
* any argument is {@code null}. * any argument is {@code null}.
* @deprecated As of 3.1. Please use
* {@link #optimizeInternal(int,MultivariateVectorFunction,OptimizationData[])}
* instead.
*/ */
protected PointVectorValuePair optimizeInternal(final int maxEval, final MultivariateVectorFunction f, @Deprecated
protected PointVectorValuePair optimizeInternal(final int maxEval, final FUNC f,
final double[] t, final double[] w, final double[] t, final double[] w,
final double[] startPoint) { final double[] startPoint) {
// Checks. // Checks.
@ -147,28 +194,122 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
throw new DimensionMismatchException(t.length, w.length); throw new DimensionMismatchException(t.length, w.length);
} }
// Reset. return optimizeInternal(maxEval, f,
evaluations.setMaximalCount(maxEval); new Target(t),
evaluations.resetCount(); new Weight(w),
new InitialGuess(startPoint));
// Store optimization problem characteristics.
function = f;
target = t.clone();
weight = w.clone();
start = startPoint.clone();
// Perform computation.
return doOptimize();
} }
/** /**
* Optimize an objective function.
*
* @param maxEval Allowed number of evaluations of the objective function.
* @param f Objective function.
* @param optData Optimization data. The following data will be looked for:
* <ul>
* <li>{@link Target}</li>
* <li>{@link Weight}</li>
* <li>{@link InitialGuess}</li>
* </ul>
* @return the point/value pair giving the optimal value of the objective
* function.
* @throws TooManyEvaluationsException if the maximal number of
* evaluations is exceeded.
* @throws DimensionMismatchException if the initial guess, target, and weight
* arguments have inconsistent dimensions.
*
* @since 3.1
*/
protected PointVectorValuePair optimizeInternal(int maxEval,
FUNC f,
OptimizationData... optData)
throws TooManyEvaluationsException,
DimensionMismatchException {
// Set internal state: configure and restart the evaluation counter, then
// record the objective function to be used for this run.
evaluations.setMaximalCount(maxEval);
evaluations.resetCount();
function = f;
// Retrieve other settings (target, weight matrix, initial guess) from
// "optData"; fields for which no item is supplied keep their previous value.
parseOptimizationData(optData);
// Check input consistency. This must follow the parsing step because it
// validates the fields that parsing has just assigned.
checkParameters();
// Allow subclasses to reset their own internal state. It runs last before
// the computation so that overrides see the parsed and checked data.
setUp();
// Perform computation.
return doOptimize();
}
/**
* Gets the initial values of the optimized parameters.
*
* @return the initial guess. * @return the initial guess.
*/ */
public double[] getStartPoint() { public double[] getStartPoint() {
return start.clone(); return start.clone();
} }
/**
 * Provides the weight matrix associated with the observations.
 * The value handed back is the result of calling {@code copy()} on the
 * stored matrix, not the stored instance itself.
 *
 * @return a copy of the weight matrix.
 * @since 3.1
 */
public RealMatrix getWeight() {
    final RealMatrix weightCopy = weightMatrix.copy();
    return weightCopy;
}
/**
 * Provides the observed values that the objective vector function
 * is expected to match.
 * The returned array is a clone, so modifying it does not affect the
 * values stored in this optimizer.
 *
 * @return a copy of the target values.
 * @since 3.1
 */
public double[] getTarget() {
    final double[] targetCopy = target.clone();
    return targetCopy;
}
/**
 * Computes the residuals.
 * The residual is the difference between the observed (target)
 * values and the model (objective function) value, for the given
 * parameters.
 * There is one residual for each element of the vector-valued
 * function.
 *
 * @param point Parameters of the model.
 * @return the residuals, computed as {@code target[i] - objective[i]}.
 * @throws DimensionMismatchException if {@code point} has a wrong
 * length, or if the number of values returned by the objective function
 * differs from the number of target values.
 * @since 3.1
 */
protected double[] computeResidual(double[] point) {
    if (point.length != start.length) {
        throw new DimensionMismatchException(point.length,
                                             start.length);
    }

    final double[] objective = computeObjectiveValue(point);
    // A model returning fewer values than there are observations would
    // otherwise fail with a raw ArrayIndexOutOfBoundsException, and one
    // returning more would be silently truncated: report both explicitly.
    if (objective.length != target.length) {
        throw new DimensionMismatchException(objective.length,
                                             target.length);
    }

    final double[] residuals = new double[target.length];
    for (int i = 0; i < target.length; i++) {
        residuals[i] = target[i] - objective[i];
    }

    return residuals;
}
/**
 * Provides direct access to the objective vector function.
 * Note that evaluating the returned function directly bypasses the
 * evaluation counter.
 *
 * @return the objective vector function.
 * @since 3.1
 */
protected FUNC getObjectiveFunction() {
    final FUNC objective = function;
    return objective;
}
/** /**
* Perform the bulk of the optimization algorithm. * Perform the bulk of the optimization algorithm.
* *
@ -179,14 +320,80 @@ public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends Multi
/** /**
* @return a reference to the {@link #target array}. * @return a reference to the {@link #target array}.
* @deprecated As of 3.1.
*/ */
@Deprecated
protected double[] getTargetRef() { protected double[] getTargetRef() {
return target; return target;
} }
/** /**
* @return a reference to the {@link #weight array}. * @return a reference to the {@link #weight array}.
* @deprecated As of 3.1.
*/ */
@Deprecated
protected double[] getWeightRef() { protected double[] getWeightRef() {
return weight; return weight;
} }
/**
 * Hook that a subclass <em>must</em> override whenever its internal
 * state depends on the {@link OptimizationData input} parsed by this base
 * class.
 * It is invoked once the optimization data passed to the
 * {@link #optimize(int,MultivariateVectorFunction,OptimizationData[])
 * optimize} method has been parsed, and just before {@link #doOptimize()}.
 * <p>
 * This base implementation refreshes the deprecated weight array with the
 * diagonal entries of the weight matrix.
 * </p>
 *
 * @since 3.1
 */
protected void setUp() {
    // XXX Temporary code until the new internal data is used everywhere.
    final int size = target.length;
    final double[] diagonal = new double[size];
    // Publish the array first, then fill it through the local alias.
    weight = diagonal;
    for (int k = 0; k < size; k++) {
        diagonal[k] = weightMatrix.getEntry(k, k);
    }
}
/**
 * Extracts the problem description from the given list of (required and
 * optional) optimization data.
 * A field for which no matching item is supplied keeps the value it had
 * before the call. When several items of the same type are supplied, the
 * last one is retained. Items of any other type are ignored.
 *
 * @param optData Optimization data. Instances of the following types are
 * recognized:
 * <ul>
 *  <li>{@link Target}</li>
 *  <li>{@link Weight}</li>
 *  <li>{@link InitialGuess}</li>
 * </ul>
 */
private void parseOptimizationData(OptimizationData... optData) {
    for (final OptimizationData item : optData) {
        if (item instanceof Target) {
            target = ((Target) item).getTarget();
        } else if (item instanceof Weight) {
            weightMatrix = ((Weight) item).getWeight();
        } else if (item instanceof InitialGuess) {
            start = ((InitialGuess) item).getInitialGuess();
        }
    }
}
/**
 * Verifies that the stored problem data are mutually consistent.
 * Only the number of target values and the number of columns of the
 * weight matrix are compared; no {@code null} check is performed here.
 *
 * @throws DimensionMismatchException if {@link #target} and
 * {@link #weightMatrix} have inconsistent dimensions.
 */
private void checkParameters() {
    final int observations = target.length;
    final int weightColumns = weightMatrix.getColumnDimension();
    if (observations != weightColumns) {
        throw new DimensionMismatchException(observations, weightColumns);
    }
}
} }

View File

@ -28,6 +28,10 @@ import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.optimization.OptimizationData;
import org.apache.commons.math3.optimization.InitialGuess;
import org.apache.commons.math3.optimization.Target;
import org.apache.commons.math3.optimization.Weight;
import org.apache.commons.math3.optimization.ConvergenceChecker; import org.apache.commons.math3.optimization.ConvergenceChecker;
import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer; import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
import org.apache.commons.math3.optimization.PointVectorValuePair; import org.apache.commons.math3.optimization.PointVectorValuePair;
@ -308,8 +312,10 @@ public abstract class AbstractLeastSquaresOptimizer
} }
/** {@inheritDoc} /** {@inheritDoc}
* @deprecated as of 3.1 replaced by {@link #optimize(int, * @deprecated As of 3.1. Please use
* MultivariateDifferentiableVectorFunction, double[], double[], double[])} * {@link BaseAbstractMultivariateVectorOptimizer#optimize(int,MultivariateVectorFunction,OptimizationData[])
* optimize(int,MultivariateDifferentiableVectorFunction,OptimizationData...)}
* instead.
*/ */
@Override @Override
@Deprecated @Deprecated
@ -317,8 +323,11 @@ public abstract class AbstractLeastSquaresOptimizer
final DifferentiableMultivariateVectorFunction f, final DifferentiableMultivariateVectorFunction f,
final double[] target, final double[] weights, final double[] target, final double[] weights,
final double[] startPoint) { final double[] startPoint) {
return optimize(maxEval, FunctionUtils.toMultivariateDifferentiableVectorFunction(f), return optimizeInternal(maxEval,
target, weights, startPoint); FunctionUtils.toMultivariateDifferentiableVectorFunction(f),
new Target(target),
new Weight(weights),
new InitialGuess(startPoint));
} }
/** /**
@ -340,29 +349,78 @@ public abstract class AbstractLeastSquaresOptimizer
* if the maximal number of evaluations is exceeded. * if the maximal number of evaluations is exceeded.
* @throws org.apache.commons.math3.exception.NullArgumentException if * @throws org.apache.commons.math3.exception.NullArgumentException if
* any argument is {@code null}. * any argument is {@code null}.
* @deprecated As of 3.1. Please use
* {@link BaseAbstractMultivariateVectorOptimizer#optimize(int,MultivariateVectorFunction,OptimizationData[])
* optimize(int,MultivariateDifferentiableVectorFunction,OptimizationData...)}
* instead.
*/ */
@Deprecated
public PointVectorValuePair optimize(final int maxEval, public PointVectorValuePair optimize(final int maxEval,
final MultivariateDifferentiableVectorFunction f, final MultivariateDifferentiableVectorFunction f,
final double[] target, final double[] weights, final double[] target, final double[] weights,
final double[] startPoint) { final double[] startPoint) {
return optimizeInternal(maxEval, f,
new Target(target),
new Weight(weights),
new InitialGuess(startPoint));
}
/**
 * Runs a weighted least-squares minimization of the given objective
 * function.
 * The cost function to be minimized is
 * <code>&sum;weight<sub>i</sub>(objective<sub>i</sub> - target<sub>i</sub>)<sup>2</sup></code>
 * <p>
 * The function is converted to the type expected by the base class and
 * the call is then delegated to the base class implementation.
 * </p>
 *
 * @param maxEval Allowed number of evaluations of the objective function.
 * @param f Objective function.
 * @param optData Optimization data. Instances of the following types are
 * recognized:
 * <ul>
 *  <li>{@link Target}</li>
 *  <li>{@link Weight}</li>
 *  <li>{@link InitialGuess}</li>
 * </ul>
 * @return the point/value pair giving the optimal value of the objective
 * function.
 * @throws TooManyEvaluationsException if the maximal number of
 * evaluations is exceeded.
 * @throws DimensionMismatchException if the target, and weight arguments
 * have inconsistent dimensions.
 * @see BaseAbstractMultivariateVectorOptimizer#optimizeInternal(int,MultivariateVectorFunction,OptimizationData[])
 *
 * @since 3.1
 */
protected PointVectorValuePair optimizeInternal(final int maxEval,
                                                final MultivariateDifferentiableVectorFunction f,
                                                OptimizationData... optData) {
    // XXX Conversion will be removed when the generic argument of the
    // base class becomes "MultivariateDifferentiableVectorFunction".
    final DifferentiableMultivariateVectorFunction converted =
        FunctionUtils.toDifferentiableMultivariateVectorFunction(f);
    return super.optimizeInternal(maxEval, converted, optData);
}
/** {@inheritDoc} */
@Override
protected void setUp() {
super.setUp();
// Reset counter. // Reset counter.
jacobianEvaluations = 0; jacobianEvaluations = 0;
// Store least squares problem characteristics. // Store least squares problem characteristics.
jF = f; // XXX The conversion won't be necessary when the generic argument of
// the base class becomes "MultivariateDifferentiableVectorFunction".
// XXX "jF" is not strictly necessary anymore but is currently more
// efficient than converting the value returned from "getObjectiveFunction()"
// every time it is used.
jF = FunctionUtils.toMultivariateDifferentiableVectorFunction((DifferentiableMultivariateVectorFunction) getObjectiveFunction());
// Arrays shared with the other private methods. // Arrays shared with the other private methods.
point = startPoint.clone(); point = getStartPoint();
rows = target.length; rows = getTarget().length;
cols = point.length; cols = point.length;
weightedResidualJacobian = new double[rows][cols]; weightedResidualJacobian = new double[rows][cols];
this.weightedResiduals = new double[rows]; this.weightedResiduals = new double[rows];
cost = Double.POSITIVE_INFINITY; cost = Double.POSITIVE_INFINITY;
return optimizeInternal(maxEval, f, target, weights, startPoint);
} }
} }