diff --git a/src/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java b/src/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java index 70af06dbc..381af6b08 100644 --- a/src/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java +++ b/src/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java @@ -17,17 +17,18 @@ package org.apache.commons.math.optimization.general; +import org.apache.commons.math.FunctionEvaluationException; import org.apache.commons.math.MaxIterationsExceededException; +import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction; +import org.apache.commons.math.analysis.MultivariateMatrixFunction; import org.apache.commons.math.linear.InvalidMatrixException; import org.apache.commons.math.linear.MatrixUtils; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; -import org.apache.commons.math.optimization.ObjectiveException; import org.apache.commons.math.optimization.OptimizationException; import org.apache.commons.math.optimization.SimpleVectorialValueChecker; import org.apache.commons.math.optimization.VectorialConvergenceChecker; -import org.apache.commons.math.optimization.VectorialDifferentiableObjectiveFunction; -import org.apache.commons.math.optimization.VectorialDifferentiableOptimizer; +import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer; import org.apache.commons.math.optimization.VectorialPointValuePair; /** @@ -38,7 +39,7 @@ import org.apache.commons.math.optimization.VectorialPointValuePair; * @since 1.2 * */ -public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferentiableOptimizer { +public abstract class AbstractLeastSquaresOptimizer implements DifferentiableMultivariateVectorialOptimizer { /** Serializable version identifier */ private static final long 
serialVersionUID = 5413193243329026789L; @@ -77,7 +78,10 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen protected int rows; /** Objective function. */ - private VectorialDifferentiableObjectiveFunction f; + private DifferentiableMultivariateVectorialFunction f; + + /** Objective function derivatives. */ + private MultivariateMatrixFunction jF; /** Target value for the objective functions at optimum. */ protected double[] target; @@ -85,8 +89,8 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen /** Weight for the least squares cost computation. */ protected double[] weights; - /** Current variables set. */ - protected double[] variables; + /** Current point. */ + protected double[] point; /** Current objective function value. */ protected double[] objective; @@ -156,15 +160,15 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen /** * Update the jacobian matrix. - * @exception ObjectiveException if the function jacobian + * @exception FunctionEvaluationException if the function jacobian * cannot be evaluated or its dimension doesn't match problem dimension */ - protected void updateJacobian() throws ObjectiveException { + protected void updateJacobian() throws FunctionEvaluationException { ++jacobianEvaluations; - jacobian = f.jacobian(variables, objective); + jacobian = jF.value(point); if (jacobian.length != rows) { - throw new ObjectiveException("dimension mismatch {0} != {1}", - jacobian.length, rows); + throw new FunctionEvaluationException(point, "dimension mismatch {0} != {1}", + jacobian.length, rows); } for (int i = 0; i < rows; i++) { final double[] ji = jacobian[i]; @@ -177,17 +181,17 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen /** * Update the residuals array and cost function value. 
- * @exception ObjectiveException if the function cannot be evaluated + * @exception FunctionEvaluationException if the function cannot be evaluated * or its dimension doesn't match problem dimension */ protected void updateResidualsAndCost() - throws ObjectiveException { + throws FunctionEvaluationException { ++objectiveEvaluations; - objective = f.objective(variables); + objective = f.value(point); if (objective.length != rows) { - throw new ObjectiveException("dimension mismatch {0} != {1}", - objective.length, rows); + throw new FunctionEvaluationException(point, "dimension mismatch {0} != {1}", + objective.length, rows); } cost = 0; for (int i = 0, index = 0; i < rows; i++, index += cols) { @@ -234,13 +238,13 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen /** * Get the covariance matrix of optimized parameters. * @return covariance matrix - * @exception ObjectiveException if the function jacobian cannot + * @exception FunctionEvaluationException if the function jacobian cannot * be evaluated * @exception OptimizationException if the covariance matrix * cannot be computed (singular problem) */ public double[][] getCovariances() - throws ObjectiveException, OptimizationException { + throws FunctionEvaluationException, OptimizationException { // set up the jacobian updateJacobian(); @@ -273,13 +277,13 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen * Guess the errors in optimized parameters. *
Guessing is covariance-based, it only gives rough order of magnitude.
* @return errors in optimized parameters - * @exception ObjectiveException if the function jacobian cannot b evaluated + * @exception FunctionEvaluationException if the function jacobian cannot b evaluated * @exception OptimizationException if the covariances matrix cannot be computed * or the number of degrees of freedom is not positive (number of measurements * lesser or equal to number of parameters) */ public double[] guessParametersErrors() - throws ObjectiveException, OptimizationException { + throws FunctionEvaluationException, OptimizationException { if (rows <= cols) { throw new OptimizationException( "no degrees of freedom ({0} measurements, {1} parameters)", @@ -295,10 +299,10 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen } /** {@inheritDoc} */ - public VectorialPointValuePair optimize(final VectorialDifferentiableObjectiveFunction f, + public VectorialPointValuePair optimize(final DifferentiableMultivariateVectorialFunction f, final double[] target, final double[] weights, final double[] startPoint) - throws ObjectiveException, OptimizationException, IllegalArgumentException { + throws FunctionEvaluationException, OptimizationException, IllegalArgumentException { if (target.length != weights.length) { throw new OptimizationException("dimension mismatch {0} != {1}", @@ -312,14 +316,15 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen // store least squares problem characteristics this.f = f; + jF = f.jacobian(); this.target = target.clone(); this.weights = weights.clone(); - this.variables = startPoint.clone(); + this.point = startPoint.clone(); this.residuals = new double[target.length]; // arrays shared with the other private methods rows = target.length; - cols = variables.length; + cols = point.length; jacobian = new double[rows][cols]; cost = Double.POSITIVE_INFINITY; @@ -330,12 +335,12 @@ public abstract class AbstractLeastSquaresOptimizer implements VectorialDifferen /** 
Perform the bulk of optimization algorithm. * @return the point/value pair giving the optimal value for objective function - * @exception ObjectiveException if the objective function throws one during + * @exception FunctionEvaluationException if the objective function throws one during * the search * @exception OptimizationException if the algorithm failed to converge * @exception IllegalArgumentException if the start point dimension is wrong */ abstract protected VectorialPointValuePair doOptimize() - throws ObjectiveException, OptimizationException, IllegalArgumentException; + throws FunctionEvaluationException, OptimizationException, IllegalArgumentException; } \ No newline at end of file diff --git a/src/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizer.java b/src/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizer.java index 2d6e3929a..728cf9964 100644 --- a/src/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizer.java +++ b/src/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizer.java @@ -17,13 +17,13 @@ package org.apache.commons.math.optimization.general; +import org.apache.commons.math.FunctionEvaluationException; import org.apache.commons.math.linear.DenseRealMatrix; import org.apache.commons.math.linear.InvalidMatrixException; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.decomposition.DecompositionSolver; import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; import org.apache.commons.math.linear.decomposition.QRDecompositionImpl; -import org.apache.commons.math.optimization.ObjectiveException; import org.apache.commons.math.optimization.OptimizationException; import org.apache.commons.math.optimization.SimpleVectorialValueChecker; import org.apache.commons.math.optimization.VectorialPointValuePair; @@ -63,7 +63,7 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { /** {@inheritDoc} */ public 
VectorialPointValuePair doOptimize() - throws ObjectiveException, OptimizationException, IllegalArgumentException { + throws FunctionEvaluationException, OptimizationException, IllegalArgumentException { // iterate until convergence is reached VectorialPointValuePair current = null; @@ -75,7 +75,7 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { VectorialPointValuePair previous = current; updateResidualsAndCost(); updateJacobian(); - current = new VectorialPointValuePair(variables, objective); + current = new VectorialPointValuePair(point, objective); // build the linear problem final double[] b = new double[cols]; @@ -114,7 +114,7 @@ public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer { // update the estimated parameters for (int i = 0; i < cols; ++i) { - variables[i] += dX[i]; + point[i] += dX[i]; } } catch(InvalidMatrixException e) { diff --git a/src/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java b/src/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java index d36a5be0b..ab58aa4a6 100644 --- a/src/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java +++ b/src/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizer.java @@ -18,7 +18,7 @@ package org.apache.commons.math.optimization.general; import java.util.Arrays; -import org.apache.commons.math.optimization.ObjectiveException; +import org.apache.commons.math.FunctionEvaluationException; import org.apache.commons.math.optimization.OptimizationException; import org.apache.commons.math.optimization.VectorialPointValuePair; @@ -27,8 +27,8 @@ import org.apache.commons.math.optimization.VectorialPointValuePair; * This class solves a least squares problem using the Levenberg-Marquardt algorithm. * *This implementation should work even for over-determined systems - * (i.e. systems having more variables than equations). 
Over-determined systems - * are solved by ignoring the variables which have the smallest impact according + * (i.e. systems having more variables than equations). Over-determined systems + * are solved by ignoring the variables which have the smallest impact according * to their jacobian column norm. Only the rank of the matrix and some loop bounds * are changed to implement this.
* @@ -104,7 +104,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { /** Serializable version identifier */ private static final long serialVersionUID = 8851282236194244323L; - /** Number of solved variables. */ + /** Number of solved variables. */ private int solvedCols; /** Diagonal elements of the R matrix in the Q.R. decomposition. */ @@ -210,7 +210,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { /** {@inheritDoc} */ protected VectorialPointValuePair doOptimize() - throws ObjectiveException, OptimizationException, IllegalArgumentException { + throws FunctionEvaluationException, OptimizationException, IllegalArgumentException { // arrays shared with the other private methods solvedCols = Math.min(rows, cols); @@ -220,7 +220,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { permutation = new int[cols]; lmDir = new double[cols]; - // local variables + // local variables double delta = 0, xNorm = 0; double[] diag = new double[cols]; double[] oldX = new double[cols]; @@ -255,7 +255,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { if (firstIteration) { - // scale the variables according to the norms of the columns + // scale the variables according to the norms of the columns // of the initial jacobian xNorm = 0; for (int k = 0; k < cols; ++k) { @@ -263,7 +263,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { if (dk == 0) { dk = 1.0; } - double xk = dk * variables[k]; + double xk = dk * point[k]; xNorm += xk * xk; diag[k] = dk; } @@ -291,7 +291,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { } if (maxCosine <= orthoTolerance) { // convergence has been reached - return new VectorialPointValuePair(variables, objective); + return new VectorialPointValuePair(point, objective); } // rescale if necessary @@ -305,7 +305,7 @@ public class LevenbergMarquardtOptimizer
extends AbstractLeastSquaresOptimizer { // save the state for (int j = 0; j < solvedCols; ++j) { int pj = permutation[j]; - oldX[pj] = variables[pj]; + oldX[pj] = point[pj]; } double previousCost = cost; double[] tmpVec = residuals; @@ -320,7 +320,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { for (int j = 0; j < solvedCols; ++j) { int pj = permutation[j]; lmDir[pj] = -lmDir[pj]; - variables[pj] = oldX[pj] + lmDir[pj]; + point[pj] = oldX[pj] + lmDir[pj]; double s = diag[pj] * lmDir[pj]; lmNorm += s * s; } @@ -384,7 +384,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { firstIteration = false; xNorm = 0; for (int k = 0; k < cols; ++k) { - double xK = diag[k] * variables[k]; + double xK = diag[k] * point[k]; xNorm += xK * xK; } xNorm = Math.sqrt(xNorm); @@ -393,7 +393,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { cost = previousCost; for (int j = 0; j < solvedCols; ++j) { int pj = permutation[j]; - variables[pj] = oldX[pj]; + point[pj] = oldX[pj]; } tmpVec = residuals; residuals = oldRes; @@ -405,7 +405,7 @@ public class LevenbergMarquardtOptimizer extends AbstractLeastSquaresOptimizer { (preRed <= costRelativeTolerance) && (ratio <= 2.0)) || (delta <= parRelativeTolerance * xNorm)) { - return new VectorialPointValuePair(variables, objective); + return new VectorialPointValuePair(point, objective); } // tests for termination and stringent tolerances diff --git a/src/test/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java b/src/test/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java index ceb49405e..34854cb1c 100644 --- a/src/test/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java +++ b/src/test/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java @@ -25,12 +25,14 @@ import junit.framework.Test; import junit.framework.TestCase; import junit.framework.TestSuite; 
+import org.apache.commons.math.FunctionEvaluationException; +import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction; +import org.apache.commons.math.analysis.MultivariateMatrixFunction; +import org.apache.commons.math.analysis.MultivariateVectorialFunction; import org.apache.commons.math.linear.DenseRealMatrix; import org.apache.commons.math.linear.RealMatrix; -import org.apache.commons.math.optimization.ObjectiveException; import org.apache.commons.math.optimization.OptimizationException; import org.apache.commons.math.optimization.SimpleVectorialValueChecker; -import org.apache.commons.math.optimization.VectorialDifferentiableObjectiveFunction; import org.apache.commons.math.optimization.VectorialPointValuePair; /** @@ -102,7 +104,7 @@ extends TestCase { super(name); } - public void testTrivial() throws ObjectiveException, OptimizationException { + public void testTrivial() throws FunctionEvaluationException, OptimizationException { LinearProblem problem = new LinearProblem(new double[][] { { 2 } }, new double[] { 3 }); GaussNewtonOptimizer optimizer = new GaussNewtonOptimizer(true); @@ -115,7 +117,7 @@ extends TestCase { assertEquals(3.0, optimum.getValue()[0], 1.0e-10); } - public void testColumnsPermutation() throws ObjectiveException, OptimizationException { + public void testColumnsPermutation() throws FunctionEvaluationException, OptimizationException { LinearProblem problem = new LinearProblem(new double[][] { { 1.0, -1.0 }, { 0.0, 2.0 }, { 1.0, -2.0 } }, @@ -135,7 +137,7 @@ extends TestCase { } - public void testNoDependency() throws ObjectiveException, OptimizationException { + public void testNoDependency() throws FunctionEvaluationException, OptimizationException { LinearProblem problem = new LinearProblem(new double[][] { { 2, 0, 0, 0, 0, 0 }, { 0, 2, 0, 0, 0, 0 }, @@ -156,7 +158,7 @@ extends TestCase { } } - public void testOneSet() throws ObjectiveException, OptimizationException { + public void testOneSet() throws 
FunctionEvaluationException, OptimizationException { LinearProblem problem = new LinearProblem(new double[][] { { 1, 0, 0 }, @@ -175,7 +177,7 @@ extends TestCase { } - public void testTwoSets() throws ObjectiveException, OptimizationException { + public void testTwoSets() throws FunctionEvaluationException, OptimizationException { double epsilon = 1.0e-7; LinearProblem problem = new LinearProblem(new double[][] { { 2, 1, 0, 4, 0, 0 }, @@ -222,7 +224,7 @@ extends TestCase { } } - public void testIllConditioned() throws ObjectiveException, OptimizationException { + public void testIllConditioned() throws FunctionEvaluationException, OptimizationException { LinearProblem problem1 = new LinearProblem(new double[][] { { 10.0, 7.0, 8.0, 7.0 }, { 7.0, 5.0, 6.0, 5.0 }, @@ -303,7 +305,7 @@ extends TestCase { } } - public void testRedundantEquations() throws ObjectiveException, OptimizationException { + public void testRedundantEquations() throws FunctionEvaluationException, OptimizationException { LinearProblem problem = new LinearProblem(new double[][] { { 1.0, 1.0 }, { 1.0, -1.0 }, @@ -322,7 +324,7 @@ extends TestCase { } - public void testInconsistentEquations() throws ObjectiveException, OptimizationException { + public void testInconsistentEquations() throws FunctionEvaluationException, OptimizationException { LinearProblem problem = new LinearProblem(new double[][] { { 1.0, 1.0 }, { 1.0, -1.0 }, @@ -337,7 +339,7 @@ extends TestCase { } - public void testInconsistentSizes() throws ObjectiveException, OptimizationException { + public void testInconsistentSizes() throws FunctionEvaluationException, OptimizationException { LinearProblem problem = new LinearProblem(new double[][] { { 1, 0 }, { 0, 1 } }, new double[] { -1, 1 }); GaussNewtonOptimizer optimizer = new GaussNewtonOptimizer(true); @@ -366,7 +368,7 @@ extends TestCase { new double[] { 1 }, new double[] { 0, 0 }); fail("an exception should have been thrown"); - } catch (ObjectiveException oe) { + } catch 
(FunctionEvaluationException oe) { // expected behavior } catch (Exception e) { fail("wrong exception caught"); @@ -396,7 +398,7 @@ extends TestCase { } } - public void testCircleFitting() throws ObjectiveException, OptimizationException { + public void testCircleFitting() throws FunctionEvaluationException, OptimizationException { Circle circle = new Circle(); circle.addPoint( 30.0, 68.0); circle.addPoint( 50.0, -6.0); @@ -417,7 +419,7 @@ extends TestCase { assertEquals(48.135167894714, center.y, 1.0e-10); } - public void testCircleFittingBadInit() throws ObjectiveException, OptimizationException { + public void testCircleFittingBadInit() throws FunctionEvaluationException, OptimizationException { Circle circle = new Circle(); double[][] points = new double[][] { {-0.312967, 0.072366}, {-0.339248, 0.132965}, {-0.379780, 0.202724}, @@ -477,9 +479,9 @@ extends TestCase { } - private static class LinearProblem implements VectorialDifferentiableObjectiveFunction { + private static class LinearProblem implements DifferentiableMultivariateVectorialFunction { - private static final long serialVersionUID = 703247177355019415L; + private static final long serialVersionUID = -8804268799379350190L; final RealMatrix factors; final double[] target; public LinearProblem(double[][] factors, double[] target) { @@ -487,20 +489,42 @@ extends TestCase { this.target = target; } - public double[][] jacobian(double[] variables, double[] value) { - return factors.getData(); + public double[] value(double[] variables) { + return factors.operate(variables); } - public double[] objective(double[] variables) { - return factors.operate(variables); + public MultivariateVectorialFunction partialDerivative(final int i) { + return new MultivariateVectorialFunction() { + private static final long serialVersionUID = 1037082026387842358L; + public double[] value(double[] point) { + return factors.getColumn(i); + } + }; + } + + public MultivariateVectorialFunction gradient(final int i) { + return 
new MultivariateVectorialFunction() { + private static final long serialVersionUID = -3268626996728727146L; + public double[] value(double[] point) { + return factors.getRow(i); + } + }; + } + + public MultivariateMatrixFunction jacobian() { + return new MultivariateMatrixFunction() { + private static final long serialVersionUID = -8387467946663627585L; + public double[][] value(double[] point) { + return factors.getData(); + } + }; } } - private static class Circle implements VectorialDifferentiableObjectiveFunction { - - private static final long serialVersionUID = -4711170319243817874L; + private static class Circle implements DifferentiableMultivariateVectorialFunction { + private static final long serialVersionUID = -7165774454925027042L; private ArrayListSome of the unit tests are re-implementations of the MINPACK file17 and