diff --git a/src/main/java/org/apache/commons/math3/optimization/direct/BaseAbstractMultivariateVectorOptimizer.java b/src/main/java/org/apache/commons/math3/optimization/direct/BaseAbstractMultivariateVectorOptimizer.java index 12958f9cb..c807f0cba 100644 --- a/src/main/java/org/apache/commons/math3/optimization/direct/BaseAbstractMultivariateVectorOptimizer.java +++ b/src/main/java/org/apache/commons/math3/optimization/direct/BaseAbstractMultivariateVectorOptimizer.java @@ -23,10 +23,15 @@ import org.apache.commons.math3.exception.TooManyEvaluationsException; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.NullArgumentException; import org.apache.commons.math3.analysis.MultivariateVectorFunction; +import org.apache.commons.math3.optimization.OptimizationData; +import org.apache.commons.math3.optimization.InitialGuess; +import org.apache.commons.math3.optimization.Target; +import org.apache.commons.math3.optimization.Weight; import org.apache.commons.math3.optimization.BaseMultivariateVectorOptimizer; import org.apache.commons.math3.optimization.ConvergenceChecker; import org.apache.commons.math3.optimization.PointVectorValuePair; import org.apache.commons.math3.optimization.SimpleVectorValueChecker; +import org.apache.commons.math3.linear.RealMatrix; /** * Base class for implementing optimizers for multivariate scalar functions. @@ -46,12 +51,16 @@ public abstract class BaseAbstractMultivariateVectorOptimizer checker; /** Target value for the objective functions at optimum. */ private double[] target; - /** Weight for the least squares cost computation. */ + /** Weight matrix. */ + private RealMatrix weightMatrix; + /** Weight for the least squares cost computation. + * @deprecated + */ private double[] weight; /** Initial guess. */ private double[] start; /** Objective function. */ - private MultivariateVectorFunction function; + private FUNC function; /** * Simple constructor with default settings. @@ -101,12 +110,46 @@ public abstract class BaseAbstractMultivariateVectorOptimizer + *
  • {@link Target}
  • + *
  • {@link Weight}
  • + *
  • {@link InitialGuess}
  • + * + * @return the point/value pair giving the optimal value of the objective + * function. + * @throws TooManyEvaluationsException if the maximal number of + * evaluations is exceeded. + * @throws DimensionMismatchException if the initial guess, target, and weight + * arguments have inconsistent dimensions. + * + * @since 3.1 + */ + protected PointVectorValuePair optimize(int maxEval, + FUNC f, + OptimizationData... optData) + throws TooManyEvaluationsException, + DimensionMismatchException { + return optimizeInternal(maxEval, f, optData); + } + /** * Optimize an objective function. * Optimization is considered to be a weighted least-squares minimization. @@ -126,8 +169,12 @@ public abstract class BaseAbstractMultivariateVectorOptimizer + *
  • {@link Target}
  • + *
  • {@link Weight}
  • + *
  • {@link InitialGuess}
  • + * + * @return the point/value pair giving the optimal value of the objective + * function. + * @throws TooManyEvaluationsException if the maximal number of + * evaluations is exceeded. + * @throws DimensionMismatchException if the initial guess, target, and weight + * arguments have inconsistent dimensions. + * + * @since 3.1 + */ + protected PointVectorValuePair optimizeInternal(int maxEval, + FUNC f, + OptimizationData... optData) + throws TooManyEvaluationsException, + DimensionMismatchException { + // Set internal state. + evaluations.setMaximalCount(maxEval); + evaluations.resetCount(); + function = f; + // Retrieve other settings. + parseOptimizationData(optData); + // Check input consistency. + checkParameters(); + // Allow subclasses to reset their own internal state. + setUp(); + // Perform computation. + return doOptimize(); + } + + /** + * Gets the initial values of the optimized parameters. + * * @return the initial guess. */ public double[] getStartPoint() { return start.clone(); } + /** + * Gets the weight matrix of the observations. + * + * @return the weight matrix. + * @since 3.1 + */ + public RealMatrix getWeight() { + return weightMatrix.copy(); + } + /** + * Gets the observed values to be matched by the objective vector + * function. + * + * @return the target values. + * @since 3.1 + */ + public double[] getTarget() { + return target.clone(); + } + + /** + * Computes the residuals. + * The residual is the difference between the observed (target) + * values and the model (objective function) value, for the given + * parameters. + * There is one residual for each element of the vector-valued + * function. + * + * @param point Parameters of the model. + * @return the residuals. + * @throws DimensionMismatchException if {@code point} has a wrong + * length. + * @since 3.1 + */ + protected double[] computeResidual(double[] point) { + if (point.length != start.length) { + throw new DimensionMismatchException(point.length, + start.length); + } + + final double[] objective = computeObjectiveValue(point); + + final double[] residuals = new double[target.length]; + for (int i = 0; i < target.length; i++) { + residuals[i] = target[i] - objective[i]; + } + + return residuals; + } + + + /** + * Gets the objective vector function. + * Note that this access bypasses the evaluation counter. + * + * @return the objective vector function. + * @since 3.1 + */ + protected FUNC getObjectiveFunction() { + return function; + } + /** * Perform the bulk of the optimization algorithm. * @@ -179,14 +320,80 @@ public abstract class BaseAbstractMultivariateVectorOptimizermust override whenever its internal + * state depend on the {@link OptimizationData input} parsed by this base + * class. + * It will be called after the parsing step performed in the + * {@link #optimize(int,MultivariateVectorFunction,OptimizationData[]) + * optimize} method and just before {@link #doOptimize()}. + * + * @since 3.1 + */ + protected void setUp() { + // XXX Temporary code until the new internal data is used everywhere. + final int dim = target.length; + weight = new double[dim]; + for (int i = 0; i < dim; i++) { + weight[i] = weightMatrix.getEntry(i, i); + } + } + + /** + * Scans the list of (required and optional) optimization data that + * characterize the problem. + * + * @param optData Optimization data. The following data will be looked for: + *
      + *
    • {@link Target}
    • + *
    • {@link Weight}
    • + *
    • {@link InitialGuess}
    • + *
    + */ + private void parseOptimizationData(OptimizationData... optData) { + // The existing values (as set by the previous call) are reused if + // not provided in the argument list. + for (OptimizationData data : optData) { + if (data instanceof Target) { + target = ((Target) data).getTarget(); + continue; + } + if (data instanceof Weight) { + weightMatrix = ((Weight) data).getWeight(); + continue; + } + if (data instanceof InitialGuess) { + start = ((InitialGuess) data).getInitialGuess(); + continue; + } + } + } + + /** + * Check parameters consistency. + * + * @throws DimensionMismatchException if {@link #target} and + * {@link #weightMatrix} have inconsistent dimensions. + */ + private void checkParameters() { + if (target.length != weightMatrix.getColumnDimension()) { + throw new DimensionMismatchException(target.length, + weightMatrix.getColumnDimension()); + } + } } diff --git a/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java b/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java index 3af449ce2..1288af11a 100644 --- a/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java +++ b/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java @@ -28,6 +28,10 @@ import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.QRDecomposition; +import org.apache.commons.math3.optimization.OptimizationData; +import org.apache.commons.math3.optimization.InitialGuess; +import org.apache.commons.math3.optimization.Target; +import org.apache.commons.math3.optimization.Weight; import org.apache.commons.math3.optimization.ConvergenceChecker; import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer; import org.apache.commons.math3.optimization.PointVectorValuePair; @@ -308,8 +312,10 @@ public abstract class AbstractLeastSquaresOptimizer } /** {@inheritDoc} - * @deprecated as of 3.1 replaced by {@link #optimize(int, - * MultivariateDifferentiableVectorFunction, double[], double[], double[])} + * @deprecated As of 3.1. Please use + * {@link BaseAbstractMultivariateVectorOptimizer#optimize(int,MultivariateVectorFunction,OptimizationData[]) + * optimize(int,MultivariateDifferentiableVectorFunction,OptimizationData...)} + * instead. */ @Override @Deprecated @@ -317,8 +323,11 @@ public abstract class AbstractLeastSquaresOptimizer final DifferentiableMultivariateVectorFunction f, final double[] target, final double[] weights, final double[] startPoint) { - return optimize(maxEval, FunctionUtils.toMultivariateDifferentiableVectorFunction(f), - target, weights, startPoint); + return optimizeInternal(maxEval, + FunctionUtils.toMultivariateDifferentiableVectorFunction(f), + new Target(target), + new Weight(weights), + new InitialGuess(startPoint)); } /** @@ -340,29 +349,78 @@ public abstract class AbstractLeastSquaresOptimizer * if the maximal number of evaluations is exceeded. * @throws org.apache.commons.math3.exception.NullArgumentException if * any argument is {@code null}. + * @deprecated As of 3.1. Please use + * {@link BaseAbstractMultivariateVectorOptimizer#optimize(int,MultivariateVectorFunction,OptimizationData[]) + * optimize(int,MultivariateDifferentiableVectorFunction,OptimizationData...)} + * instead. */ + @Deprecated public PointVectorValuePair optimize(final int maxEval, final MultivariateDifferentiableVectorFunction f, final double[] target, final double[] weights, final double[] startPoint) { + return optimizeInternal(maxEval, f, + new Target(target), + new Weight(weights), + new InitialGuess(startPoint)); + } + + /** + * Optimize an objective function. + * Optimization is considered to be a weighted least-squares minimization. + * The cost function to be minimized is + * ∑weighti(objectivei - targeti)2 + * + * @param maxEval Allowed number of evaluations of the objective function. + * @param f Objective function. + * @param optData Optimization data. The following data will be looked for: + *
      + *
    • {@link Target}
    • + *
    • {@link Weight}
    • + *
    • {@link InitialGuess}
    • + *
    + * @return the point/value pair giving the optimal value of the objective + * function. + * @throws TooManyEvaluationsException if the maximal number of + * evaluations is exceeded. + * @throws DimensionMismatchException if the target, and weight arguments + * have inconsistent dimensions. + * @see BaseAbstractMultivariateVectorOptimizer#optimizeInternal(int,MultivariateVectorFunction,OptimizationData[]) + * + * @since 3.1 + */ + protected PointVectorValuePair optimizeInternal(final int maxEval, + final MultivariateDifferentiableVectorFunction f, + OptimizationData... optData) { + // XXX Conversion will be removed when the generic argument of the + // base class becomes "MultivariateDifferentiableVectorFunction". + return super.optimizeInternal(maxEval, FunctionUtils.toDifferentiableMultivariateVectorFunction(f), optData); + } + + /** {@inheritDoc} */ + @Override + protected void setUp() { + super.setUp(); // Reset counter. jacobianEvaluations = 0; // Store least squares problem characteristics. - jF = f; + // XXX The conversion won't be necessary when the generic argument of + // the base class becomes "MultivariateDifferentiableVectorFunction". + // XXX "jF" is not strictly necessary anymore but is currently more + // efficient than converting the value returned from "getObjectiveFunction()" + // every time it is used. + jF = FunctionUtils.toMultivariateDifferentiableVectorFunction((DifferentiableMultivariateVectorFunction) getObjectiveFunction()); // Arrays shared with the other private methods. - point = startPoint.clone(); - rows = target.length; + point = getStartPoint(); + rows = getTarget().length; cols = point.length; weightedResidualJacobian = new double[rows][cols]; this.weightedResiduals = new double[rows]; cost = Double.POSITIVE_INFINITY; - - return optimizeInternal(maxEval, f, target, weights, startPoint); } - }