added an implementation of a non-linear conjugate gradient optimizer
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@758059 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
ed8fc4a03a
commit
2c84116a4c
|
@ -0,0 +1,187 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.optimization.general;
|
||||
|
||||
import org.apache.commons.math.FunctionEvaluationException;
|
||||
import org.apache.commons.math.MaxIterationsExceededException;
|
||||
import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction;
|
||||
import org.apache.commons.math.analysis.MultivariateVectorialFunction;
|
||||
import org.apache.commons.math.optimization.GoalType;
|
||||
import org.apache.commons.math.optimization.OptimizationException;
|
||||
import org.apache.commons.math.optimization.RealConvergenceChecker;
|
||||
import org.apache.commons.math.optimization.DifferentiableMultivariateRealOptimizer;
|
||||
import org.apache.commons.math.optimization.RealPointValuePair;
|
||||
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
|
||||
|
||||
/**
|
||||
* Base class for implementing optimizers for multivariate scalar functions.
|
||||
* <p>This base class handles the boilerplate methods associated to thresholds
|
||||
* settings, iterations and evaluations counting.</p>
|
||||
* @version $Revision$ $Date$
|
||||
* @since 2.0
|
||||
*/
|
||||
public abstract class AbstractScalarDifferentiableOptimizer
|
||||
implements DifferentiableMultivariateRealOptimizer{
|
||||
|
||||
/** Default maximal number of iterations allowed. */
|
||||
public static final int DEFAULT_MAX_ITERATIONS = 100;
|
||||
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = 1357126012308766636L;
|
||||
|
||||
/** Maximal number of iterations allowed. */
|
||||
private int maxIterations;
|
||||
|
||||
/** Number of iterations already performed. */
|
||||
private int iterations;
|
||||
|
||||
/** Number of evaluations already performed. */
|
||||
private int evaluations;
|
||||
|
||||
/** Number of gradient evaluations. */
|
||||
private int gradientEvaluations;
|
||||
|
||||
/** Convergence checker. */
|
||||
protected RealConvergenceChecker checker;
|
||||
|
||||
/** Objective function. */
|
||||
private DifferentiableMultivariateRealFunction f;
|
||||
|
||||
/** Objective function gradient. */
|
||||
private MultivariateVectorialFunction gradient;
|
||||
|
||||
/** Type of optimization. */
|
||||
protected GoalType goalType;
|
||||
|
||||
/** Current point set. */
|
||||
protected double[] point;
|
||||
|
||||
/** Simple constructor with default settings.
|
||||
* <p>The convergence check is set to a {@link SimpleScalarValueChecker}
|
||||
* and the maximal number of evaluation is set to its default value.</p>
|
||||
*/
|
||||
protected AbstractScalarDifferentiableOptimizer() {
|
||||
setConvergenceChecker(new SimpleScalarValueChecker());
|
||||
setMaxIterations(DEFAULT_MAX_ITERATIONS);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public void setMaxIterations(int maxIterations) {
|
||||
this.maxIterations = maxIterations;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public int getMaxIterations() {
|
||||
return maxIterations;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public int getIterations() {
|
||||
return iterations;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public int getEvaluations() {
|
||||
return evaluations;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public int getGradientEvaluations() {
|
||||
return gradientEvaluations;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public void setConvergenceChecker(RealConvergenceChecker checker) {
|
||||
this.checker = checker;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public RealConvergenceChecker getConvergenceChecker() {
|
||||
return checker;
|
||||
}
|
||||
|
||||
/** Increment the iterations counter by 1.
|
||||
* @exception OptimizationException if the maximal number
|
||||
* of iterations is exceeded
|
||||
*/
|
||||
protected void incrementIterationsCounter()
|
||||
throws OptimizationException {
|
||||
if (++iterations > maxIterations) {
|
||||
if (++iterations > maxIterations) {
|
||||
throw new OptimizationException(new MaxIterationsExceededException(maxIterations));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the gradient vector.
|
||||
* @param point point at which the gradient must be evaluated
|
||||
* @return gradient at the specified point
|
||||
* @exception FunctionEvaluationException if the function gradient
|
||||
*/
|
||||
protected double[] computeObjectiveGradient(final double[] point)
|
||||
throws FunctionEvaluationException {
|
||||
++gradientEvaluations;
|
||||
return gradient.value(point);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the objective function value.
|
||||
* @param point point at which the objective function must be evaluated
|
||||
* @return objective function value at specified point
|
||||
* @exception FunctionEvaluationException if the function cannot be evaluated
|
||||
* or its dimension doesn't match problem dimension
|
||||
*/
|
||||
protected double computeObjectiveValue(final double[] point)
|
||||
throws FunctionEvaluationException {
|
||||
++evaluations;
|
||||
return f.value(point);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public RealPointValuePair optimize(final DifferentiableMultivariateRealFunction f,
|
||||
final GoalType goalType,
|
||||
final double[] startPoint)
|
||||
throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
|
||||
|
||||
// reset counters
|
||||
iterations = 0;
|
||||
evaluations = 0;
|
||||
gradientEvaluations = 0;
|
||||
|
||||
// store optimization problem characteristics
|
||||
this.f = f;
|
||||
gradient = f.gradient();
|
||||
this.goalType = goalType;
|
||||
point = startPoint.clone();
|
||||
|
||||
return doOptimize();
|
||||
|
||||
}
|
||||
|
||||
/** Perform the bulk of optimization algorithm.
|
||||
* @return the point/value pair giving the optimal value for objective function
|
||||
* @exception FunctionEvaluationException if the objective function throws one during
|
||||
* the search
|
||||
* @exception OptimizationException if the algorithm failed to converge
|
||||
* @exception IllegalArgumentException if the start point dimension is wrong
|
||||
*/
|
||||
abstract protected RealPointValuePair doOptimize()
|
||||
throws FunctionEvaluationException, OptimizationException, IllegalArgumentException;
|
||||
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.optimization.general;
|
||||
|
||||
/**
 * Available choices of update formulas for the &beta; parameter
 * in {@link NonLinearConjugateGradientOptimizer}.
 * <p>
 * The &beta; parameter is used to compute the successive conjugate
 * search directions. For non-linear conjugate gradients, there are
 * two formulas to compute &beta;:
 * <ul>
 *   <li>Fletcher-Reeves formula</li>
 *   <li>Polak-Ribi&egrave;re formula</li>
 * </ul>
 * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
 * if the start point is close enough to the optimum, whereas the
 * Polak-Ribi&egrave;re formula may not converge in rare cases. On the
 * other hand, the Polak-Ribi&egrave;re formula is often faster when it
 * does converge. Polak-Ribi&egrave;re is often used.
 * </p>
 * @see NonLinearConjugateGradientOptimizer
 * @version $Revision$ $Date$
 * @since 2.0
 */
public enum ConjugateGradientFormula {

    /** Fletcher-Reeves formula. */
    FLETCHER_REEVES,

    /** Polak-Ribi&egrave;re formula. */
    POLAK_RIBIERE

}
|
|
@ -0,0 +1,302 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.optimization.general;
|
||||
|
||||
import org.apache.commons.math.ConvergenceException;
|
||||
import org.apache.commons.math.FunctionEvaluationException;
|
||||
import org.apache.commons.math.analysis.UnivariateRealFunction;
|
||||
import org.apache.commons.math.analysis.solvers.BrentSolver;
|
||||
import org.apache.commons.math.analysis.solvers.UnivariateRealSolver;
|
||||
import org.apache.commons.math.optimization.GoalType;
|
||||
import org.apache.commons.math.optimization.OptimizationException;
|
||||
import org.apache.commons.math.optimization.DifferentiableMultivariateRealOptimizer;
|
||||
import org.apache.commons.math.optimization.RealPointValuePair;
|
||||
import org.apache.commons.math.optimization.SimpleVectorialValueChecker;
|
||||
|
||||
/**
|
||||
* Non-linear conjugate gradient optimizer.
|
||||
* <p>
|
||||
* This class supports both the Fletcher-Reeves and the Polak-Ribière
|
||||
* update formulas for the conjugate search directions. It also supports
|
||||
* optional preconditioning.
|
||||
* </p>
|
||||
*
|
||||
* @version $Revision$ $Date$
|
||||
* @since 2.0
|
||||
*
|
||||
*/
|
||||
|
||||
public class NonLinearConjugateGradientOptimizer
|
||||
extends AbstractScalarDifferentiableOptimizer
|
||||
implements DifferentiableMultivariateRealOptimizer {
|
||||
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = -6545223926568155458L;
|
||||
|
||||
/** Update formula for the beta parameter. */
|
||||
private final ConjugateGradientFormula updateFormula;
|
||||
|
||||
/** Preconditioner (may be null). */
|
||||
private Preconditioner preconditioner;
|
||||
|
||||
/** solver to use in the line search (may be null). */
|
||||
private UnivariateRealSolver solver;
|
||||
|
||||
/** Initial step used to bracket the optimum in line search. */
|
||||
private double initialStep;
|
||||
|
||||
/** Simple constructor with default settings.
|
||||
* <p>The convergence check is set to a {@link SimpleVectorialValueChecker}
|
||||
* and the maximal number of evaluation is set to
|
||||
* {@link AbstractLeastSquaresOptimizer#DEFAULT_MAX_EVALUATIONS}.
|
||||
* @param updateFormula formula to use for updating the β parameter,
|
||||
* must be one of {@link UpdateFormula#FLETCHER_REEVES} or {@link
|
||||
* UpdateFormula#POLAK_RIBIERE}
|
||||
*/
|
||||
public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
|
||||
this.updateFormula = updateFormula;
|
||||
preconditioner = null;
|
||||
solver = null;
|
||||
initialStep = 1.0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the preconditioner.
|
||||
* @param preconditioner preconditioner to use for next optimization,
|
||||
* may be null to remove an already registered preconditioner
|
||||
*/
|
||||
public void setPreconditioner(final Preconditioner preconditioner) {
|
||||
this.preconditioner = preconditioner;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the solver to use during line search.
|
||||
* @param solver solver to use during line search, may be null
|
||||
* to remove an already registered solver and fall back to the
|
||||
* default {@link BrentSolver Brent solver}.
|
||||
*/
|
||||
public void setLineSearchSolver(final UnivariateRealSolver solver) {
|
||||
this.solver = solver;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the initial step used to bracket the optimum in line search.
|
||||
* <p>
|
||||
* The initial step is a factor with respect to the search direction,
|
||||
* which itself is roughly related to the gradient of the function
|
||||
* </p>
|
||||
* @param initialStep initial step used to bracket the optimum in line search,
|
||||
* if a non-positive value is used, the initial step is reset to its
|
||||
* default value of 1.0
|
||||
*/
|
||||
public void setInitialStep(final double initialStep) {
|
||||
if (initialStep <= 0) {
|
||||
this.initialStep = 1.0;
|
||||
} else {
|
||||
this.initialStep = initialStep;
|
||||
}
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
protected RealPointValuePair doOptimize()
|
||||
throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
|
||||
try {
|
||||
|
||||
// initialization
|
||||
if (preconditioner == null) {
|
||||
preconditioner = new IdentityPreconditioner();
|
||||
}
|
||||
if (solver == null) {
|
||||
solver = new BrentSolver();
|
||||
}
|
||||
final int n = point.length;
|
||||
double[] r = computeObjectiveGradient(point);
|
||||
if (goalType == GoalType.MINIMIZE) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
r[i] = -r[i];
|
||||
}
|
||||
}
|
||||
|
||||
// initial search direction
|
||||
double[] steepestDescent = preconditioner.precondition(point, r);
|
||||
double[] searchDirection = steepestDescent.clone();
|
||||
|
||||
double delta = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
delta += r[i] * searchDirection[i];
|
||||
}
|
||||
|
||||
RealPointValuePair current = null;
|
||||
while (true) {
|
||||
|
||||
final double objective = computeObjectiveValue(point);
|
||||
RealPointValuePair previous = current;
|
||||
current = new RealPointValuePair(point, objective);
|
||||
if (previous != null) {
|
||||
if (checker.converged(getIterations(), previous, current)) {
|
||||
// we have found an optimum
|
||||
return current;
|
||||
}
|
||||
}
|
||||
|
||||
incrementIterationsCounter();
|
||||
|
||||
double dTd = 0;
|
||||
for (final double di : searchDirection) {
|
||||
dTd += di * di;
|
||||
}
|
||||
|
||||
// find the optimal step in the search direction
|
||||
final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection);
|
||||
final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep));
|
||||
|
||||
// validate new point
|
||||
for (int i = 0; i < point.length; ++i) {
|
||||
point[i] += step * searchDirection[i];
|
||||
}
|
||||
r = computeObjectiveGradient(point);
|
||||
if (goalType == GoalType.MINIMIZE) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
r[i] = -r[i];
|
||||
}
|
||||
}
|
||||
|
||||
// compute beta
|
||||
final double deltaOld = delta;
|
||||
final double[] newSteepestDescent = preconditioner.precondition(point, r);
|
||||
delta = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
delta += r[i] * newSteepestDescent[i];
|
||||
}
|
||||
|
||||
final double beta;
|
||||
if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
|
||||
beta = delta / deltaOld;
|
||||
} else {
|
||||
double deltaMid = 0;
|
||||
for (int i = 0; i < r.length; ++i) {
|
||||
deltaMid += r[i] * steepestDescent[i];
|
||||
}
|
||||
beta = (delta - deltaMid) / deltaOld;
|
||||
}
|
||||
steepestDescent = newSteepestDescent;
|
||||
|
||||
// compute conjugate search direction
|
||||
if ((getIterations() % n == 0) || (beta < 0)) {
|
||||
// break conjugation: reset search direction
|
||||
searchDirection = steepestDescent.clone();
|
||||
} else {
|
||||
// compute new conjugate search direction
|
||||
for (int i = 0; i < n; ++i) {
|
||||
searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} catch (ConvergenceException ce) {
|
||||
throw new OptimizationException(ce);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the upper bound b ensuring bracketing of a root between a and b
|
||||
* @param f function whose root must be bracketed
|
||||
* @param a lower bound of the interval
|
||||
* @param h initial step to try
|
||||
* @return b such that f(a) and f(b) have opposite signs
|
||||
* @exception FunctionEvaluationException if the function cannot be computed
|
||||
* @exception OptimizationException if no bracket can be found
|
||||
*/
|
||||
private double findUpperBound(final UnivariateRealFunction f,
|
||||
final double a, final double h)
|
||||
throws FunctionEvaluationException, OptimizationException {
|
||||
final double yA = f.value(a);
|
||||
double yB = yA;
|
||||
for (double step = h; step < Double.MAX_VALUE; step *= Math.max(2, yA / yB)) {
|
||||
final double b = a + step;
|
||||
yB = f.value(b);
|
||||
if (yA * yB <= 0) {
|
||||
return b;
|
||||
}
|
||||
}
|
||||
throw new OptimizationException("unable to bracket optimum in line search");
|
||||
}
|
||||
|
||||
/** Default identity preconditioner. */
|
||||
private static class IdentityPreconditioner implements Preconditioner {
|
||||
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = 1868235977809734023L;
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double[] precondition(double[] variables, double[] r) {
|
||||
return r.clone();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Internal class for line search.
|
||||
* <p>
|
||||
* The function represented by this class is the dot product of
|
||||
* the objective function gradient and the search direction. Its
|
||||
* value is zero when the gradient is orthogonal to the search
|
||||
* direction, i.e. when the objective function value is a local
|
||||
* extremum along the search direction.
|
||||
* </p>
|
||||
*/
|
||||
private class LineSearchFunction implements UnivariateRealFunction {
|
||||
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = 8184683950487801424L;
|
||||
|
||||
/** Search direction. */
|
||||
private final double[] searchDirection;
|
||||
|
||||
/** Simple constructor.
|
||||
* @param searchDirection search direction
|
||||
*/
|
||||
public LineSearchFunction(final double[] searchDirection) {
|
||||
this.searchDirection = searchDirection;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double value(double x) throws FunctionEvaluationException {
|
||||
|
||||
// current point in the search direction
|
||||
final double[] shiftedPoint = point.clone();
|
||||
for (int i = 0; i < shiftedPoint.length; ++i) {
|
||||
shiftedPoint[i] += x * searchDirection[i];
|
||||
}
|
||||
|
||||
// gradient of the objective function
|
||||
final double[] gradient = computeObjectiveGradient(shiftedPoint);
|
||||
|
||||
// dot product with the search direction
|
||||
double dotProduct = 0;
|
||||
for (int i = 0; i < gradient.length; ++i) {
|
||||
dotProduct += gradient[i] * searchDirection[i];
|
||||
}
|
||||
|
||||
return dotProduct;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.optimization.general;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import org.apache.commons.math.FunctionEvaluationException;
|
||||
|
||||
/**
|
||||
* This interface represents a preconditioner for differentiable scalar
|
||||
* objective function optimizers.
|
||||
* @version $Revision$ $Date$
|
||||
* @since 2.0
|
||||
*/
|
||||
public interface Preconditioner extends Serializable {
|
||||
|
||||
/**
|
||||
* Precondition a search direction.
|
||||
* <p>
|
||||
* The returned preconditioned search direction must be computed fast or
|
||||
* the algorithm performances will drop drastically. A classical approach
|
||||
* is to compute only the diagonal elements of the hessian and to divide
|
||||
* the raw search direction by these elements if they are all positive.
|
||||
* If at least one of them is negative, it is safer to return a clone of
|
||||
* the raw search direction as if the hessian was the identity matrix. The
|
||||
* rationale for this simplified choice is that a negative diagonal element
|
||||
* means the current point is far from the optimum and preconditioning will
|
||||
* not be efficient anyway in this case.
|
||||
* </p>
|
||||
* @param point current point at which the search direction was computed
|
||||
* @param r raw search direction (i.e. opposite of the gradient)
|
||||
* @return approximation of H<sup>-1</sup>r where H is the objective function hessian
|
||||
* @exception FunctionEvaluationException if no cost can be computed for the parameters
|
||||
* @exception IllegalArgumentException if point dimension is wrong
|
||||
*/
|
||||
double[] precondition(double[] point, double[] r)
|
||||
throws FunctionEvaluationException, IllegalArgumentException;
|
||||
|
||||
}
|
|
@ -41,6 +41,7 @@ The <action> type attribute can be add,update,fix,remove.
|
|||
<release version="2.0" date="TBD" description="TBD">
|
||||
<action dev="luc" type="fix" issue="MATH-177" >
|
||||
Redesigned the optimization framework for a simpler yet more powerful API.
|
||||
Added non-linear conjugate gradient optimizer.
|
||||
</action>
|
||||
<action dev="luc" type="fix" issue="MATH-243" due-to="Christian Semrau">
|
||||
Fixed an error in computing gcd and lcm for some extreme values at integer
|
||||
|
|
|
@ -0,0 +1,505 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.optimization.general;
|
||||
|
||||
import java.awt.geom.Point2D;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import junit.framework.Test;
|
||||
import junit.framework.TestCase;
|
||||
import junit.framework.TestSuite;
|
||||
|
||||
import org.apache.commons.math.FunctionEvaluationException;
|
||||
import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction;
|
||||
import org.apache.commons.math.analysis.MultivariateRealFunction;
|
||||
import org.apache.commons.math.analysis.MultivariateVectorialFunction;
|
||||
import org.apache.commons.math.analysis.solvers.BrentSolver;
|
||||
import org.apache.commons.math.linear.DenseRealMatrix;
|
||||
import org.apache.commons.math.linear.RealMatrix;
|
||||
import org.apache.commons.math.optimization.GoalType;
|
||||
import org.apache.commons.math.optimization.OptimizationException;
|
||||
import org.apache.commons.math.optimization.RealPointValuePair;
|
||||
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
|
||||
|
||||
/**
|
||||
* <p>Some of the unit tests are re-implementations of the MINPACK <a
|
||||
* href="http://www.netlib.org/minpack/ex/file17">file17</a> and <a
|
||||
* href="http://www.netlib.org/minpack/ex/file22">file22</a> test files.
|
||||
* The redistribution policy for MINPACK is available <a
|
||||
* href="http://www.netlib.org/minpack/disclaimer">here</a>, for
|
||||
* convenience, it is reproduced below.</p>
|
||||
|
||||
* <table border="0" width="80%" cellpadding="10" align="center" bgcolor="#E0E0E0">
|
||||
* <tr><td>
|
||||
* Minpack Copyright Notice (1999) University of Chicago.
|
||||
* All rights reserved
|
||||
* </td></tr>
|
||||
* <tr><td>
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions
|
||||
* are met:
|
||||
* <ol>
|
||||
* <li>Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.</li>
|
||||
* <li>Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following
|
||||
* disclaimer in the documentation and/or other materials provided
|
||||
* with the distribution.</li>
|
||||
* <li>The end-user documentation included with the redistribution, if any,
|
||||
* must include the following acknowledgment:
|
||||
* <code>This product includes software developed by the University of
|
||||
* Chicago, as Operator of Argonne National Laboratory.</code>
|
||||
* Alternately, this acknowledgment may appear in the software itself,
|
||||
* if and wherever such third-party acknowledgments normally appear.</li>
|
||||
* <li><strong>WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS"
|
||||
* WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE
|
||||
* UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND
|
||||
* THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES
|
||||
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE
|
||||
* OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY
|
||||
* OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR
|
||||
* USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF
|
||||
* THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4)
|
||||
* DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION
|
||||
* UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL
|
||||
* BE CORRECTED.</strong></li>
|
||||
* <li><strong>LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT
|
||||
* HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF
|
||||
* ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT,
|
||||
* INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF
|
||||
* ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
* PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER
|
||||
* SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT
|
||||
* (INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE,
|
||||
* EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE
|
||||
* POSSIBILITY OF SUCH LOSS OR DAMAGES.</strong></li>
|
||||
 * </ol></td></tr>
|
||||
* </table>
|
||||
|
||||
* @author Argonne National Laboratory. MINPACK project. March 1980 (original fortran minpack tests)
|
||||
* @author Burton S. Garbow (original fortran minpack tests)
|
||||
* @author Kenneth E. Hillstrom (original fortran minpack tests)
|
||||
* @author Jorge J. More (original fortran minpack tests)
|
||||
* @author Luc Maisonobe (non-minpack tests and minpack tests Java translation)
|
||||
*/
|
||||
public class NonLinearConjugateGradientOptimizerTest
|
||||
extends TestCase {
|
||||
|
||||
/**
 * Build a named test case.
 * @param name test name, forwarded unchanged to the superclass constructor
 */
public NonLinearConjugateGradientOptimizerTest(String name) {
    super(name);
}
|
||||
|
||||
public void testTrivial() throws FunctionEvaluationException, OptimizationException {
|
||||
LinearProblem problem =
|
||||
new LinearProblem(new double[][] { { 2 } }, new double[] { 3 });
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 0 });
|
||||
assertEquals(1.5, optimum.getPoint()[0], 1.0e-10);
|
||||
assertEquals(0.0, optimum.getValue(), 1.0e-10);
|
||||
}
|
||||
|
||||
public void testColumnsPermutation() throws FunctionEvaluationException, OptimizationException {
|
||||
|
||||
LinearProblem problem =
|
||||
new LinearProblem(new double[][] { { 1.0, -1.0 }, { 0.0, 2.0 }, { 1.0, -2.0 } },
|
||||
new double[] { 4.0, 6.0, 1.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 0, 0 });
|
||||
assertEquals(7.0, optimum.getPoint()[0], 1.0e-10);
|
||||
assertEquals(3.0, optimum.getPoint()[1], 1.0e-10);
|
||||
assertEquals(0.0, optimum.getValue(), 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
public void testNoDependency() throws FunctionEvaluationException, OptimizationException {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 2, 0, 0, 0, 0, 0 },
|
||||
{ 0, 2, 0, 0, 0, 0 },
|
||||
{ 0, 0, 2, 0, 0, 0 },
|
||||
{ 0, 0, 0, 2, 0, 0 },
|
||||
{ 0, 0, 0, 0, 2, 0 },
|
||||
{ 0, 0, 0, 0, 0, 2 }
|
||||
}, new double[] { 0.0, 1.1, 2.2, 3.3, 4.4, 5.5 });
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 0, 0, 0, 0, 0, 0 });
|
||||
for (int i = 0; i < problem.target.length; ++i) {
|
||||
assertEquals(0.55 * i, optimum.getPoint()[i], 1.0e-10);
|
||||
}
|
||||
}
|
||||
|
||||
public void testOneSet() throws FunctionEvaluationException, OptimizationException {
|
||||
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1, 0, 0 },
|
||||
{ -1, 1, 0 },
|
||||
{ 0, -1, 1 }
|
||||
}, new double[] { 1, 1, 1});
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 0, 0, 0 });
|
||||
assertEquals(1.0, optimum.getPoint()[0], 1.0e-10);
|
||||
assertEquals(2.0, optimum.getPoint()[1], 1.0e-10);
|
||||
assertEquals(3.0, optimum.getPoint()[2], 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
public void testTwoSets() throws FunctionEvaluationException, OptimizationException {
|
||||
final double epsilon = 1.0e-7;
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 2, 1, 0, 4, 0, 0 },
|
||||
{ -4, -2, 3, -7, 0, 0 },
|
||||
{ 4, 1, -2, 8, 0, 0 },
|
||||
{ 0, -3, -12, -1, 0, 0 },
|
||||
{ 0, 0, 0, 0, epsilon, 1 },
|
||||
{ 0, 0, 0, 0, 1, 1 }
|
||||
}, new double[] { 2, -9, 2, 2, 1 + epsilon * epsilon, 2});
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setPreconditioner(new Preconditioner() {
|
||||
private static final long serialVersionUID = -2935127802358453014L;
|
||||
public double[] precondition(double[] point, double[] r) {
|
||||
double[] d = r.clone();
|
||||
d[0] /= 72.0;
|
||||
d[1] /= 30.0;
|
||||
d[2] /= 314.0;
|
||||
d[3] /= 260.0;
|
||||
d[4] /= 2 * (1 + epsilon * epsilon);
|
||||
d[5] /= 4.0;
|
||||
return d;
|
||||
}
|
||||
});
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-13, 1.0e-13));
|
||||
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 0, 0, 0, 0, 0, 0 });
|
||||
assertEquals( 3.0, optimum.getPoint()[0], 1.0e-10);
|
||||
assertEquals( 4.0, optimum.getPoint()[1], 1.0e-10);
|
||||
assertEquals(-1.0, optimum.getPoint()[2], 1.0e-10);
|
||||
assertEquals(-2.0, optimum.getPoint()[3], 1.0e-10);
|
||||
assertEquals( 1.0 + epsilon, optimum.getPoint()[4], 1.0e-10);
|
||||
assertEquals( 1.0 - epsilon, optimum.getPoint()[5], 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
public void testNonInversible() throws FunctionEvaluationException, OptimizationException {
|
||||
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1, 2, -3 },
|
||||
{ 2, 1, 3 },
|
||||
{ -3, 0, -9 }
|
||||
}, new double[] { 1, 1, 1 });
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 0, 0, 0 });
|
||||
assertTrue(optimum.getValue() > 0.5);
|
||||
}
|
||||
|
||||
public void testIllConditioned() throws FunctionEvaluationException, OptimizationException {
|
||||
LinearProblem problem1 = new LinearProblem(new double[][] {
|
||||
{ 10.0, 7.0, 8.0, 7.0 },
|
||||
{ 7.0, 5.0, 6.0, 5.0 },
|
||||
{ 8.0, 6.0, 10.0, 9.0 },
|
||||
{ 7.0, 5.0, 9.0, 10.0 }
|
||||
}, new double[] { 32, 23, 33, 31 });
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-13, 1.0e-13));
|
||||
BrentSolver solver = new BrentSolver();
|
||||
solver.setAbsoluteAccuracy(1.0e-15);
|
||||
solver.setRelativeAccuracy(1.0e-15);
|
||||
optimizer.setLineSearchSolver(solver);
|
||||
RealPointValuePair optimum1 =
|
||||
optimizer.optimize(problem1, GoalType.MINIMIZE, new double[] { 0, 1, 2, 3 });
|
||||
assertEquals(1.0, optimum1.getPoint()[0], 1.0e-5);
|
||||
assertEquals(1.0, optimum1.getPoint()[1], 1.0e-5);
|
||||
assertEquals(1.0, optimum1.getPoint()[2], 1.0e-5);
|
||||
assertEquals(1.0, optimum1.getPoint()[3], 1.0e-5);
|
||||
|
||||
LinearProblem problem2 = new LinearProblem(new double[][] {
|
||||
{ 10.00, 7.00, 8.10, 7.20 },
|
||||
{ 7.08, 5.04, 6.00, 5.00 },
|
||||
{ 8.00, 5.98, 9.89, 9.00 },
|
||||
{ 6.99, 4.99, 9.00, 9.98 }
|
||||
}, new double[] { 32, 23, 33, 31 });
|
||||
RealPointValuePair optimum2 =
|
||||
optimizer.optimize(problem2, GoalType.MINIMIZE, new double[] { 0, 1, 2, 3 });
|
||||
assertEquals(-81.0, optimum2.getPoint()[0], 1.0e-1);
|
||||
assertEquals(137.0, optimum2.getPoint()[1], 1.0e-1);
|
||||
assertEquals(-34.0, optimum2.getPoint()[2], 1.0e-1);
|
||||
assertEquals( 22.0, optimum2.getPoint()[3], 1.0e-1);
|
||||
|
||||
}
|
||||
|
||||
public void testMoreEstimatedParametersSimple()
|
||||
throws FunctionEvaluationException, OptimizationException {
|
||||
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 3.0, 2.0, 0.0, 0.0 },
|
||||
{ 0.0, 1.0, -1.0, 1.0 },
|
||||
{ 2.0, 0.0, 1.0, 0.0 }
|
||||
}, new double[] { 7.0, 3.0, 5.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 7, 6, 5, 4 });
|
||||
assertEquals(0, optimum.getValue(), 1.0e-10);
|
||||
|
||||
}
|
||||
|
||||
public void testMoreEstimatedParametersUnsorted()
|
||||
throws FunctionEvaluationException, OptimizationException {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1.0, 1.0, 0.0, 0.0, 0.0, 0.0 },
|
||||
{ 0.0, 0.0, 1.0, 1.0, 1.0, 0.0 },
|
||||
{ 0.0, 0.0, 0.0, 0.0, 1.0, -1.0 },
|
||||
{ 0.0, 0.0, -1.0, 1.0, 0.0, 1.0 },
|
||||
{ 0.0, 0.0, 0.0, -1.0, 1.0, 0.0 }
|
||||
}, new double[] { 3.0, 12.0, -1.0, 7.0, 1.0 });
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 2, 2, 2, 2, 2, 2 });
|
||||
assertEquals(0, optimum.getValue(), 1.0e-10);
|
||||
}
|
||||
|
||||
public void testRedundantEquations() throws FunctionEvaluationException, OptimizationException {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1.0, 1.0 },
|
||||
{ 1.0, -1.0 },
|
||||
{ 1.0, 3.0 }
|
||||
}, new double[] { 3.0, 1.0, 5.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 1, 1 });
|
||||
assertEquals(2.0, optimum.getPoint()[0], 1.0e-8);
|
||||
assertEquals(1.0, optimum.getPoint()[1], 1.0e-8);
|
||||
|
||||
}
|
||||
|
||||
public void testInconsistentEquations() throws FunctionEvaluationException, OptimizationException {
|
||||
LinearProblem problem = new LinearProblem(new double[][] {
|
||||
{ 1.0, 1.0 },
|
||||
{ 1.0, -1.0 },
|
||||
{ 1.0, 3.0 }
|
||||
}, new double[] { 3.0, 1.0, 4.0 });
|
||||
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6));
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 1, 1 });
|
||||
assertTrue(optimum.getValue() > 0.1);
|
||||
|
||||
}
|
||||
|
||||
public void testCircleFitting() throws FunctionEvaluationException, OptimizationException {
|
||||
Circle circle = new Circle();
|
||||
circle.addPoint( 30.0, 68.0);
|
||||
circle.addPoint( 50.0, -6.0);
|
||||
circle.addPoint(110.0, -20.0);
|
||||
circle.addPoint( 35.0, 15.0);
|
||||
circle.addPoint( 45.0, 97.0);
|
||||
NonLinearConjugateGradientOptimizer optimizer =
|
||||
new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
|
||||
optimizer.setMaxIterations(100);
|
||||
optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-30, 1.0e-30));
|
||||
BrentSolver solver = new BrentSolver();
|
||||
solver.setAbsoluteAccuracy(1.0e-13);
|
||||
solver.setRelativeAccuracy(1.0e-15);
|
||||
optimizer.setLineSearchSolver(solver);
|
||||
RealPointValuePair optimum =
|
||||
optimizer.optimize(circle, GoalType.MINIMIZE, new double[] { 98.680, 47.345 });
|
||||
Point2D.Double center = new Point2D.Double(optimum.getPointRef()[0], optimum.getPointRef()[1]);
|
||||
assertEquals(69.960161753, circle.getRadius(center), 1.0e-8);
|
||||
assertEquals(96.075902096, center.x, 1.0e-8);
|
||||
assertEquals(48.135167894, center.y, 1.0e-8);
|
||||
}
|
||||
|
||||
private static class LinearProblem implements DifferentiableMultivariateRealFunction {
|
||||
|
||||
private static final long serialVersionUID = 703247177355019415L;
|
||||
final RealMatrix factors;
|
||||
final double[] target;
|
||||
public LinearProblem(double[][] factors, double[] target) {
|
||||
this.factors = new DenseRealMatrix(factors);
|
||||
this.target = target;
|
||||
}
|
||||
|
||||
private double[] gradient(double[] point) {
|
||||
double[] r = factors.operate(point);
|
||||
for (int i = 0; i < r.length; ++i) {
|
||||
r[i] -= target[i];
|
||||
}
|
||||
double[] p = factors.transpose().operate(r);
|
||||
for (int i = 0; i < p.length; ++i) {
|
||||
p[i] *= 2;
|
||||
}
|
||||
return p;
|
||||
}
|
||||
|
||||
public double value(double[] variables) throws FunctionEvaluationException {
|
||||
double[] y = factors.operate(variables);
|
||||
double sum = 0;
|
||||
for (int i = 0; i < y.length; ++i) {
|
||||
double ri = y[i] - target[i];
|
||||
sum += ri * ri;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
public MultivariateVectorialFunction gradient() {
|
||||
return new MultivariateVectorialFunction() {
|
||||
private static final long serialVersionUID = 2621997811350805819L;
|
||||
public double[] value(double[] point) {
|
||||
return gradient(point);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public MultivariateRealFunction partialDerivative(final int k) {
|
||||
return new MultivariateRealFunction() {
|
||||
private static final long serialVersionUID = -6186178619133562011L;
|
||||
public double value(double[] point) {
|
||||
return gradient(point)[k];
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static class Circle implements DifferentiableMultivariateRealFunction {
|
||||
|
||||
private static final long serialVersionUID = -4711170319243817874L;
|
||||
|
||||
private ArrayList<Point2D.Double> points;
|
||||
|
||||
public Circle() {
|
||||
points = new ArrayList<Point2D.Double>();
|
||||
}
|
||||
|
||||
public void addPoint(double px, double py) {
|
||||
points.add(new Point2D.Double(px, py));
|
||||
}
|
||||
|
||||
public int getN() {
|
||||
return points.size();
|
||||
}
|
||||
|
||||
public double getRadius(Point2D.Double center) {
|
||||
double r = 0;
|
||||
for (Point2D.Double point : points) {
|
||||
r += point.distance(center);
|
||||
}
|
||||
return r / points.size();
|
||||
}
|
||||
|
||||
private double[] gradient(double[] point) {
|
||||
|
||||
// optimal radius
|
||||
Point2D.Double center = new Point2D.Double(point[0], point[1]);
|
||||
double radius = getRadius(center);
|
||||
|
||||
// gradient of the sum of squared residuals
|
||||
double dJdX = 0;
|
||||
double dJdY = 0;
|
||||
for (Point2D.Double pk : points) {
|
||||
double dk = pk.distance(center);
|
||||
dJdX += (center.x - pk.x) * (dk - radius) / dk;
|
||||
dJdY += (center.y - pk.y) * (dk - radius) / dk;
|
||||
}
|
||||
dJdX *= 2;
|
||||
dJdY *= 2;
|
||||
|
||||
return new double[] { dJdX, dJdY };
|
||||
|
||||
}
|
||||
|
||||
public double value(double[] variables)
|
||||
throws IllegalArgumentException, FunctionEvaluationException {
|
||||
|
||||
Point2D.Double center = new Point2D.Double(variables[0], variables[1]);
|
||||
double radius = getRadius(center);
|
||||
|
||||
double sum = 0;
|
||||
for (Point2D.Double point : points) {
|
||||
double di = point.distance(center) - radius;
|
||||
sum += di * di;
|
||||
}
|
||||
|
||||
return sum;
|
||||
|
||||
}
|
||||
|
||||
public MultivariateVectorialFunction gradient() {
|
||||
return new MultivariateVectorialFunction() {
|
||||
private static final long serialVersionUID = 3174909643301201710L;
|
||||
public double[] value(double[] point) {
|
||||
return gradient(point);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public MultivariateRealFunction partialDerivative(final int k) {
|
||||
return new MultivariateRealFunction() {
|
||||
private static final long serialVersionUID = 3073956364104833888L;
|
||||
public double value(double[] point) {
|
||||
return gradient(point)[k];
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static Test suite() {
|
||||
return new TestSuite(NonLinearConjugateGradientOptimizerTest.class);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue