diff --git a/src/main/java/org/apache/commons/math/analysis/FunctionUtils.java b/src/main/java/org/apache/commons/math/analysis/FunctionUtils.java index f0d2e881c..371857d57 100644 --- a/src/main/java/org/apache/commons/math/analysis/FunctionUtils.java +++ b/src/main/java/org/apache/commons/math/analysis/FunctionUtils.java @@ -26,6 +26,11 @@ import org.apache.commons.math.analysis.function.Identity; * @since 3.0 */ public class FunctionUtils { + /** + * Class only contains static methods. + */ + private FunctionUtils() {} + /** * Compose functions. * @@ -142,6 +147,7 @@ public class FunctionUtils { * * @param f Binary function. * @param fixed Value to which the first argument of {@code f} is set. + * @return a unary function. */ public static UnivariateRealFunction fix1stArgument(final BivariateRealFunction f, final double fixed) { @@ -157,6 +163,7 @@ public class FunctionUtils { * * @param f Binary function. * @param fixed Value to which the second argument of {@code f} is set. + * @return a unary function. */ public static UnivariateRealFunction fix2ndArgument(final BivariateRealFunction f, final double fixed) { diff --git a/src/main/java/org/apache/commons/math/analysis/polynomials/PolynomialFunctionNewtonForm.java b/src/main/java/org/apache/commons/math/analysis/polynomials/PolynomialFunctionNewtonForm.java index cb96385dd..c0873ee25 100644 --- a/src/main/java/org/apache/commons/math/analysis/polynomials/PolynomialFunctionNewtonForm.java +++ b/src/main/java/org/apache/commons/math/analysis/polynomials/PolynomialFunctionNewtonForm.java @@ -16,7 +16,6 @@ */ package org.apache.commons.math.analysis.polynomials; -import org.apache.commons.math.exception.NullArgumentException; import org.apache.commons.math.exception.NoDataException; import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.analysis.UnivariateRealFunction; @@ -69,7 +68,8 @@ public class PolynomialFunctionNewtonForm implements UnivariateRealFunction { * * @param a Coefficients in Newton form formula. * @param c Centers. - * @throws NullArgumentException if any argument is {@code null}. + * @throws org.apache.commons.math.exception.NullArgumentException if + * any argument is {@code null}. * @throws NoDataException if any array has zero length. * @throws DimensionMismatchException if the size difference between * {@code a} and {@code c} is not equal to 1. @@ -154,7 +154,8 @@ public class PolynomialFunctionNewtonForm implements UnivariateRealFunction { * @param c Centers. * @param z Point at which the function value is to be computed. * @return the function value. - * @throws NullArgumentException if any argument is {@code null}. + * @throws org.apache.commons.math.exception.NullArgumentException if + * any argument is {@code null}. * @throws NoDataException if any array has zero length. * @throws DimensionMismatchException if the size difference between * {@code a} and {@code c} is not equal to 1. @@ -202,7 +203,8 @@ public class PolynomialFunctionNewtonForm implements UnivariateRealFunction { * * @param a the coefficients in Newton form formula * @param c the centers - * @throws NullArgumentException if any argument is {@code null}. + * @throws org.apache.commons.math.exception.NullArgumentException if + * any argument is {@code null}. * @throws NoDataException if any array has zero length. * @throws DimensionMismatchException if the size difference between * {@code a} and {@code c} is not equal to 1. diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/AbstractDifferentiableUnivariateRealSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/AbstractDifferentiableUnivariateRealSolver.java new file mode 100644 index 000000000..f1e82c892 --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/AbstractDifferentiableUnivariateRealSolver.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction; +import org.apache.commons.math.analysis.UnivariateRealFunction; + +/** + * Provide a default implementation for several functions useful to generic + * solvers. + * + * @version $Revision$ $Date$ + * @since 3.0 + */ +public abstract class AbstractDifferentiableUnivariateRealSolver + extends BaseAbstractUnivariateRealSolver + implements DifferentiableUnivariateRealSolver { + /** Derivative of the function to solve. */ + private UnivariateRealFunction functionDerivative; + + /** + * Construct a solver with given absolute accuracy. + * + * @param absoluteAccuracy Maximum absolute error. + */ + protected AbstractDifferentiableUnivariateRealSolver(final double absoluteAccuracy) { + super(absoluteAccuracy); + } + + /** + * Construct a solver with given accuracies. + * + * @param relativeAccuracy Maximum relative error. + * @param absoluteAccuracy Maximum absolute error. + * @param functionValueAccuracy Maximum function value error. + */ + protected AbstractDifferentiableUnivariateRealSolver(final double relativeAccuracy, + final double absoluteAccuracy, + final double functionValueAccuracy) { + super(relativeAccuracy, absoluteAccuracy, functionValueAccuracy); + } + + /** + * Compute the objective function value. + * + * @param point Point at which the objective function must be evaluated. + * @return the objective function value at specified point. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the maximal number of evaluations is exceeded. + */ + protected double computeDerivativeObjectiveValue(double point) { + incrementEvaluationCount(); + return functionDerivative.value(point); + } + + /** + * {@inheritDoc} + */ + @Override + protected void setup(DifferentiableUnivariateRealFunction f, + double min, double max, + double startValue) { + super.setup(f, min, max, startValue); + functionDerivative = f.derivative(); + } +} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/AbstractPolynomialSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/AbstractPolynomialSolver.java new file mode 100644 index 000000000..8961bcd8d --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/AbstractPolynomialSolver.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.polynomials.PolynomialFunction; + +/** + * Base class for solvers. + * + * @version $Revision$ $Date$ + * @since 3.0 + */ +public abstract class AbstractPolynomialSolver + extends BaseAbstractUnivariateRealSolver + implements PolynomialSolver { + /** Function. */ + private PolynomialFunction polynomialFunction; + + /** + * Construct a solver with given absolute accuracy. + * + * @param absoluteAccuracy Maximum absolute error. + */ + protected AbstractPolynomialSolver(final double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver with given accuracies. + * + * @param relativeAccuracy Maximum relative error. + * @param absoluteAccuracy Maximum absolute error. + */ + protected AbstractPolynomialSolver(final double relativeAccuracy, + final double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); + } + /** + * Construct a solver with given accuracies. + * + * @param relativeAccuracy Maximum relative error. + * @param absoluteAccuracy Maximum absolute error. + * @param functionValueAccuracy Maximum function value error. + */ + protected AbstractPolynomialSolver(final double relativeAccuracy, + final double absoluteAccuracy, + final double functionValueAccuracy) { + super(relativeAccuracy, absoluteAccuracy, functionValueAccuracy); + } + + /** + * {@inheritDoc} + */ + @Override + protected void setup(PolynomialFunction f, + double min, double max, + double startValue) { + super.setup(f, min, max, startValue); + polynomialFunction = f; + } + + /** + * @return the coefficients of the polynomial function. + */ + protected double[] getCoefficients() { + return polynomialFunction.getCoefficients(); + } +} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/AbstractUnivariateRealSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/AbstractUnivariateRealSolver.java new file mode 100644 index 000000000..318ec689f --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/AbstractUnivariateRealSolver.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.UnivariateRealFunction; + +/** + * Base class for solvers. + * + * @version $Revision$ $Date$ + * @since 3.0 + */ +public abstract class AbstractUnivariateRealSolver + extends BaseAbstractUnivariateRealSolver + implements UnivariateRealSolver { + /** + * Construct a solver with given absolute accuracy. + * + * @param absoluteAccuracy Maximum absolute error. + */ + protected AbstractUnivariateRealSolver(final double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver with given accuracies. + * + * @param relativeAccuracy Maximum relative error. + * @param absoluteAccuracy Maximum absolute error. + */ + protected AbstractUnivariateRealSolver(final double relativeAccuracy, + final double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); + } + /** + * Construct a solver with given accuracies. + * + * @param relativeAccuracy Maximum relative error. + * @param absoluteAccuracy Maximum absolute error. + * @param functionValueAccuracy Maximum function value error. + */ + protected AbstractUnivariateRealSolver(final double relativeAccuracy, + final double absoluteAccuracy, + final double functionValueAccuracy) { + super(relativeAccuracy, absoluteAccuracy, functionValueAccuracy); + } +} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/BaseAbstractUnivariateRealSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/BaseAbstractUnivariateRealSolver.java new file mode 100644 index 000000000..39b6cde52 --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/BaseAbstractUnivariateRealSolver.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.util.Incrementor; +import org.apache.commons.math.exception.MaxCountExceededException; +import org.apache.commons.math.exception.TooManyEvaluationsException; +import org.apache.commons.math.exception.NullArgumentException; +import org.apache.commons.math.analysis.UnivariateRealFunction; + +/** + * Provide a default implementation for several functions useful to generic + * solvers. + * + * @param Type of function to solve. + * + * @version $Revision: 1030464 $ $Date: 2010-11-03 14:46:04 +0100 (Wed, 03 Nov 2010) $ + * @since 2.0 + */ +public abstract class BaseAbstractUnivariateRealSolver + implements BaseUnivariateRealSolver { + /** Default absolute accuracy */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; + /** Default relative accuracy. */ + public static final double DEFAULT_RELATIVE_ACCURACY = 1e-14; + /** Default function value accuracy. */ + public static final double DEFAULT_FUNCTION_VALUE_ACCURACY = 1e-15; + /** Function value accuracy. */ + private final double functionValueAccuracy; + /** Absolute accuracy. */ + private final double absoluteAccuracy; + /** Relative accuracy. */ + private final double relativeAccuracy; + /** Evaluations counter. */ + private final Incrementor evaluations = new Incrementor(); + /** Lower end of search interval. */ + private double searchMin; + /** Higher end of search interval. */ + private double searchMax; + /** Initial guess. */ + private double searchStart; + /** Function to solve. */ + private FUNC function; + + /** + * Construct a solver with given absolute accuracy. + * + * @param absoluteAccuracy Maximum absolute error. + */ + protected BaseAbstractUnivariateRealSolver(final double absoluteAccuracy) { + this(DEFAULT_RELATIVE_ACCURACY, + absoluteAccuracy, + DEFAULT_FUNCTION_VALUE_ACCURACY); + } + + /** + * Construct a solver with given accuracies. + * + * @param relativeAccuracy Maximum relative error. + * @param absoluteAccuracy Maximum absolute error. + */ + protected BaseAbstractUnivariateRealSolver(final double relativeAccuracy, + final double absoluteAccuracy) { + this(relativeAccuracy, + absoluteAccuracy, + DEFAULT_FUNCTION_VALUE_ACCURACY); + } + + /** + * Construct a solver with given accuracies. + * + * @param relativeAccuracy Maximum relative error. + * @param absoluteAccuracy Maximum absolute error. + * @param functionValueAccuracy Maximum function value error. + */ + protected BaseAbstractUnivariateRealSolver(final double relativeAccuracy, + final double absoluteAccuracy, + final double functionValueAccuracy) { + this.absoluteAccuracy = absoluteAccuracy; + this.relativeAccuracy = relativeAccuracy; + this.functionValueAccuracy = functionValueAccuracy; + } + + /** {@inheritDoc} */ + public void setMaxEvaluations(int maxEvaluations) { + evaluations.setMaximalCount(maxEvaluations); + } + /** {@inheritDoc} */ + public int getMaxEvaluations() { + return evaluations.getMaximalCount(); + } + /** {@inheritDoc} */ + public int getEvaluations() { + return evaluations.getCount(); + } + /** + * @return the lower end of the search interval. + */ + public double getMin() { + return searchMin; + } + /** + * @return the higher end of the search interval. + */ + public double getMax() { + return searchMax; + } + /** + * @return the initial guess. + */ + public double getStartValue() { + return searchStart; + } + /** + * {@inheritDoc} + */ + public double getAbsoluteAccuracy() { + return absoluteAccuracy; + } + /** + * {@inheritDoc} + */ + public double getRelativeAccuracy() { + return relativeAccuracy; + } + /** + * {@inheritDoc} + */ + public double getFunctionValueAccuracy() { + return functionValueAccuracy; + } + + /** + * Compute the objective function value. + * + * @param point Point at which the objective function must be evaluated. + * @return the objective function value at specified point. + * @throws TooManyEvaluationsException if the maximal number of evaluations + * is exceeded. + */ + protected double computeObjectiveValue(double point) { + incrementEvaluationCount(); + return function.value(point); + } + + /** + * Prepare for computation. + * Subclasses must call this method if they override any of the + * {@code solve} methods. + * + * @param f Function to solve. + * @param min Lower bound for the interval. + * @param max Upper bound for the interval. + * @param startValue Start value to use. + */ + protected void setup(FUNC f, + double min, double max, + double startValue) { + // Checks. + if (f == null) { + throw new NullArgumentException(); + } + + // Reset. + searchMin = min; + searchMax = max; + searchStart = startValue; + function = f; + evaluations.resetCount(); + } + + /** {@inheritDoc} */ + public double solve(FUNC f, double min, double max, double startValue) { + // Initialization. + setup(f, min, max, startValue); + + // Perform computation. + return doSolve(); + } + + /** {@inheritDoc} */ + public double solve(FUNC f, double min, double max) { + return solve(f, min, max, min + 0.5 * (max - min)); + } + + /** {@inheritDoc} */ + public double solve(FUNC f, double startValue) { + return solve(f, Double.NaN, Double.NaN, startValue); + } + + /** + * Method for implementing actual optimization algorithms in derived + * classes. + * + * @return the root. + * @throws TooManyEvaluationsException if the maximal number of evaluations + * is exceeded. + */ + protected abstract double doSolve(); + + /** + * Check whether the function takes opposite signs at the endpoints. + * + * @param lower Lower endpoint. + * @param upper Upper endpoint. + * @return {@code true} if the function values have opposite signs at the + * given points. + */ + protected boolean isBracketing(final double lower, + final double upper) { + return UnivariateRealSolverUtils.isBracketing(function, lower, upper); + } + + /** + * Check whether the arguments form a (strictly) increasing sequence. + * + * @param start First number. + * @param mid Second number. + * @param end Third number. + * @return {@code true} if the arguments form an increasing sequence. + */ + protected boolean isSequence(final double start, + final double mid, + final double end) { + return UnivariateRealSolverUtils.isSequence(start, mid, end); + } + + /** + * Check that the endpoints specify an interval. + * + * @param lower Lower endpoint. + * @param upper Upper endpoint. + * @throws org.apache.commons.math.exception.NumberIsTooLargeException + * if {@code lower >= upper}. + */ + protected void verifyInterval(final double lower, + final double upper) { + UnivariateRealSolverUtils.verifyInterval(lower, upper); + } + + /** + * Check that {@code lower < initial < upper}. + * + * @param lower Lower endpoint. + * @param initial Initial value. + * @param upper Upper endpoint. + * @throws org.apache.commons.math.exception.NumberIsTooLargeException + * if {@code lower >= initial} or {@code initial >= upper}. + */ + protected void verifySequence(final double lower, + final double initial, + final double upper) { + UnivariateRealSolverUtils.verifySequence(lower, initial, upper); + } + + /** + * Check that the endpoints specify an interval and the function takes + * opposite signs at the endpoints. + * + * @param lower Lower endpoint. + * @param upper Upper endpoint. + * @throws org.apache.commons.math.exception.NoBracketingException if + * the function has the same sign at the endpoints. + */ + protected void verifyBracketing(final double lower, + final double upper) { + UnivariateRealSolverUtils.verifyBracketing(function, lower, upper); + } + + /** + * Increment the evaluation count by one. + * Method {@link #computeObjectiveValue(double)} calls this method internally. + * It is provided for subclasses that do not exclusively use + * {@code computeObjectiveValue} to solve the function. + * See e.g. {@link AbstractDifferentiableUnivariateRealSolver}. + */ + protected void incrementEvaluationCount() { + try { + evaluations.incrementCount(); + } catch (MaxCountExceededException e) { + throw new TooManyEvaluationsException(e.getMax()); + } + } +} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/BaseUnivariateRealSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/BaseUnivariateRealSolver.java new file mode 100644 index 000000000..c224ea997 --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/BaseUnivariateRealSolver.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.UnivariateRealFunction; + + +/** + * Interface for (univariate real) rootfinding algorithms. + * Implementations will search for only one zero in the given interval. + * + * @param Type of function to solve. + * + * @version $Revision$ $Date$ + * @since 3.0 + */ +public interface BaseUnivariateRealSolver { + /** + * Set the maximal number of function evaluations. + * + * @param maxEvaluations Maximal number of function evaluations. + */ + void setMaxEvaluations(int maxEvaluations); + + /** + * Get the maximal number of function evaluations. + * + * @return the maximal number of function evaluations. + */ + int getMaxEvaluations(); + + /** + * Get the number of evaluations of the objective function. + * The number of evaluations corresponds to the last call to the + * {@code optimize} method. It is 0 if the method has not been + * called yet. + * + * @return the number of evaluations of the objective function. + */ + int getEvaluations(); + + /** + * @return the absolute accuracy. + */ + double getAbsoluteAccuracy(); + /** + * @return the relative accuracy. + */ + double getRelativeAccuracy(); + /** + * @return the function value accuracy. + */ + double getFunctionValueAccuracy(); + + /** + * Solve for a zero root in the given interval. + * A solver may require that the interval brackets a single zero root. + * Solvers that do require bracketing should be able to handle the case + * where one of the endpoints is itself a root. + * + * @param f Function to solve. + * @param min Lower bound for the interval. + * @param max Upper bound for the interval. + * @return a value where the function is zero. + * @throws IllegalArgumentException if {@code min > max} or the endpoints + * do not satisfy the requirements specified by the solver. + * @since 2.0 + */ + double solve(FUNC f, double min, double max); + + /** + * Solve for a zero in the given interval, start at {@code startValue}. + * A solver may require that the interval brackets a single zero root. + * Solvers that do require bracketing should be able to handle the case + * where one of the endpoints is itself a root. + * + * @param f Function to solve. + * @param min Lower bound for the interval. + * @param max Upper bound for the interval. + * @param startValue Start value to use. + * @return a value where the function is zero. + * @throws IllegalArgumentException if {@code min > max} or the arguments + * do not satisfy the requirements specified by the solver. + * @since 2.0 + */ + double solve(FUNC f, double min, double max, double startValue); + + /** + * Solve for a zero in the vicinity of {@code startValue}. + * A solver may require that the interval brackets a single zero root. + * + * @param f Function to solve. + * @param startValue Start value to use. + * @return a value where the function is zero. + * @throws IllegalArgumentException if {@code min > max} or the arguments + * do not satisfy the requirements specified by the solver. + * @since 2.0 + */ + double solve(FUNC f, double startValue); +} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/BisectionSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/BisectionSolver.java index 12187cb27..93c936e1f 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/BisectionSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/BisectionSolver.java @@ -16,9 +16,6 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.MaxIterationsExceededException; -import org.apache.commons.math.analysis.UnivariateRealFunction; -import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.util.FastMath; /** @@ -29,39 +26,54 @@ import org.apache.commons.math.util.FastMath; * * @version $Revision$ $Date$ */ -public class BisectionSolver extends UnivariateRealSolverImpl { +public class BisectionSolver extends AbstractUnivariateRealSolver { + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; + /** + * Construct a solver with default accuracy. + */ + public BisectionSolver() { + this(DEFAULT_ABSOLUTE_ACCURACY); + } /** * Construct a solver. * + * @param absoluteAccuracy Absolute accuracy. */ - public BisectionSolver() { - super(100, 1E-6); + public BisectionSolver(double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + */ + public BisectionSolver(double relativeAccuracy, + double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); } - /** {@inheritDoc} */ - public double solve(final UnivariateRealFunction f, double min, double max, double initial) - throws MaxIterationsExceededException, MathUserException { - return solve(f, min, max); - } - - /** {@inheritDoc} */ - public double solve(final UnivariateRealFunction f, double min, double max) - throws MaxIterationsExceededException, MathUserException { - - clearResult(); - verifyInterval(min,max); + /** + * {@inheritDoc} + */ + @Override + protected double doSolve() { + double min = getMin(); + double max = getMax(); + verifyInterval(min, max); + final double absoluteAccuracy = getAbsoluteAccuracy(); double m; double fm; double fmin; - int i = 0; - while (i < maximalIterationCount) { + while (true) { m = UnivariateRealSolverUtils.midpoint(min, max); - fmin = f.value(min); - fm = f.value(m); + fmin = computeObjectiveValue(min); + fm = computeObjectiveValue(m); - if (fm * fmin > 0.0) { + if (fm * fmin > 0) { // max and m bracket the root. min = m; } else { @@ -71,12 +83,8 @@ public class BisectionSolver extends UnivariateRealSolverImpl { if (FastMath.abs(max - min) <= absoluteAccuracy) { m = UnivariateRealSolverUtils.midpoint(min, max); - setResult(m, i); return m; } - ++i; } - - throw new MaxIterationsExceededException(maximalIterationCount); } } diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/BrentSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/BrentSolver.java index 73738eb74..169d937d7 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/BrentSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/BrentSolver.java @@ -17,297 +17,214 @@ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.MathRuntimeException; -import org.apache.commons.math.MaxIterationsExceededException; -import org.apache.commons.math.analysis.UnivariateRealFunction; -import org.apache.commons.math.exception.MathUserException; -import org.apache.commons.math.exception.util.LocalizedFormats; +import org.apache.commons.math.exception.NoBracketingException; import org.apache.commons.math.util.FastMath; +import org.apache.commons.math.util.MathUtils; /** - * Implements the - * Brent algorithm for finding zeros of real univariate functions. - *

- * The function should be continuous but not necessarily smooth.

+ * This class implements the + * Brent algorithm for finding zeros of real univariate functions. + * The function should be continuous but not necessarily smooth. + * The {@code solve} method returns a zero {@code x} of the function {@code f} + * in the given interval {@code [a, b]} to within a tolerance + * {@code 6 eps abs(x) + t} where {@code eps} is the relative accuracy and + * {@code t} is the absolute accuracy. + * The given interval must bracket the root. * * @version $Revision:670469 $ $Date:2008-06-23 10:01:38 +0200 (lun., 23 juin 2008) $ */ -public class BrentSolver extends UnivariateRealSolverImpl { - - /** - * Default absolute accuracy - * @since 2.1 - */ - public static final double DEFAULT_ABSOLUTE_ACCURACY = 1E-6; - - /** Default maximum number of iterations - * @since 2.1 - */ - public static final int DEFAULT_MAXIMUM_ITERATIONS = 100; - +public class BrentSolver extends AbstractUnivariateRealSolver { /** Serializable version identifier */ private static final long serialVersionUID = 7694577816772532779L; + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; /** - * Construct a solver with default properties. + * Construct a solver with default accuracies. */ public BrentSolver() { - super(DEFAULT_MAXIMUM_ITERATIONS, DEFAULT_ABSOLUTE_ACCURACY); + this(DEFAULT_ABSOLUTE_ACCURACY); } - /** - * Construct a solver with the given absolute accuracy. + * Construct a solver. * - * @param absoluteAccuracy lower bound for absolute accuracy of solutions returned by the solver - * @since 2.1 + * @param absoluteAccuracy Absolute accuracy. */ public BrentSolver(double absoluteAccuracy) { - super(DEFAULT_MAXIMUM_ITERATIONS, absoluteAccuracy); + super(absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + */ + public BrentSolver(double relativeAccuracy, + double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + * @param functionValueAccuracy Function value accuracy. + */ + public BrentSolver(double relativeAccuracy, + double absoluteAccuracy, + double functionValueAccuracy) { + super(relativeAccuracy, absoluteAccuracy, functionValueAccuracy); } /** - * Contstruct a solver with the given maximum iterations and absolute accuracy. - * - * @param maximumIterations maximum number of iterations - * @param absoluteAccuracy lower bound for absolute accuracy of solutions returned by the solver - * @since 2.1 + * {@inheritDoc} */ - public BrentSolver(int maximumIterations, double absoluteAccuracy) { - super(maximumIterations, absoluteAccuracy); - } + @Override + protected double doSolve() { + double min = getMin(); + double max = getMax(); + final double initial = getStartValue(); + final double functionValueAccuracy = getFunctionValueAccuracy(); - /** - * Find a zero in the given interval with an initial guess. - *

Throws IllegalArgumentException if the values of the - * function at the three points have the same sign (note that it is - * allowed to have endpoints with the same sign if the initial point has - * opposite sign function-wise).

- * - * @param f function to solve. - * @param min the lower bound for the interval. - * @param max the upper bound for the interval. - * @param initial the start value to use (must be set to min if no - * initial point is known). - * @return the value where the function is zero - * @throws MaxIterationsExceededException the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if initial is not between min and max - * (even if it is a root) - */ - public double solve(final UnivariateRealFunction f, - final double min, final double max, final double initial) - throws MaxIterationsExceededException, MathUserException { + verifySequence(min, initial, max); - clearResult(); - if ((initial < min) || (initial > max)) { - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.INVALID_INTERVAL_INITIAL_VALUE_PARAMETERS, - min, initial, max); - } - - // return the initial guess if it is good enough - double yInitial = f.value(initial); + // Return the initial guess if it is good enough. + double yInitial = computeObjectiveValue(initial); if (FastMath.abs(yInitial) <= functionValueAccuracy) { - setResult(initial, 0); - return result; + return initial; } - // return the first endpoint if it is good enough - double yMin = f.value(min); + // Return the first endpoint if it is good enough. + double yMin = computeObjectiveValue(min); if (FastMath.abs(yMin) <= functionValueAccuracy) { - setResult(min, 0); - return result; + return min; } - // reduce interval if min and initial bracket the root + // Reduce interval if min and initial bracket the root. if (yInitial * yMin < 0) { - return solve(f, min, yMin, initial, yInitial, min, yMin); + return brent(min, initial, yMin, yInitial); } - // return the second endpoint if it is good enough - double yMax = f.value(max); + // Return the second endpoint if it is good enough. + double yMax = computeObjectiveValue(max); if (FastMath.abs(yMax) <= functionValueAccuracy) { - setResult(max, 0); - return result; + return max; } - // reduce interval if initial and max bracket the root + // Reduce interval if initial and max bracket the root. if (yInitial * yMax < 0) { - return solve(f, initial, yInitial, max, yMax, initial, yInitial); + return brent(initial, max, yInitial, yMax); } - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.SAME_SIGN_AT_ENDPOINTS, min, max, yMin, yMax); - + throw new NoBracketingException(min, max, yMin, yMax); } /** - * Find a zero in the given interval. - *

- * Requires that the values of the function at the endpoints have opposite - * signs. An IllegalArgumentException is thrown if this is not - * the case.

+ * Search for a zero inside the provided interval. + * This implemenation is based on the algorithm described at page 58 of + * the book + * + * Algorithms for Minimization Without Derivatives + * Richard P. Brent + * Dover 0-486-41998-3 + * * - * @param f the function to solve - * @param min the lower bound for the interval. - * @param max the upper bound for the interval. - * @return the value where the function is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if min is not less than max or the - * signs of the values of the function at the endpoints are not opposites + * @param lo Lower bound of the search interval. + * @param hi Higher bound of the search interval. + * @param fLo Function value at the lower bound of the search interval. + * @param fHi Function value at the higher bound of the search interval. + * @return the value where the function is zero. */ - public double solve(final UnivariateRealFunction f, - final double min, final double max) - throws MaxIterationsExceededException, MathUserException { + private double brent(double lo, double hi, + double fLo, double fHi) { + double a = lo; + double fa = fLo; + double b = hi; + double fb = fHi; + double c = a; + double fc = fa; + double d = b - a; + double e = d; - clearResult(); - verifyInterval(min, max); + final double t = getAbsoluteAccuracy(); + final double eps = getRelativeAccuracy(); - double ret = Double.NaN; - - double yMin = f.value(min); - double yMax = f.value(max); - - // Verify bracketing - double sign = yMin * yMax; - if (sign > 0) { - // check if either value is close to a zero - if (FastMath.abs(yMin) <= functionValueAccuracy) { - setResult(min, 0); - ret = min; - } else if (FastMath.abs(yMax) <= functionValueAccuracy) { - setResult(max, 0); - ret = max; - } else { - // neither value is close to zero and min and max do not bracket root. - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.SAME_SIGN_AT_ENDPOINTS, min, max, yMin, yMax); + while (true) { + if (FastMath.abs(fc) < FastMath.abs(fb)) { + a = b; + b = c; + c = a; + fa = fb; + fb = fc; + fc = fa; } - } else if (sign < 0){ - // solve using only the first endpoint as initial guess - ret = solve(f, min, yMin, max, yMax, min, yMin); - } else { - // either min or max is a root - if (yMin == 0.0) { - ret = min; - } else { - ret = max; - } - } - return ret; - } + final double tol = 2 * eps * FastMath.abs(b) + t; + final double m = 0.5 * (c - b); - /** - * Find a zero starting search according to the three provided points. - * @param f the function to solve - * @param x0 old approximation for the root - * @param y0 function value at the approximation for the root - * @param x1 last calculated approximation for the root - * @param y1 function value at the last calculated approximation - * for the root - * @param x2 bracket point (must be set to x0 if no bracket point is - * known, this will force starting with linear interpolation) - * @param y2 function value at the bracket point. - * @return the value where the function is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - */ - private double solve(final UnivariateRealFunction f, - double x0, double y0, - double x1, double y1, - double x2, double y2) - throws MaxIterationsExceededException, MathUserException { - - double delta = x1 - x0; - double oldDelta = delta; - - int i = 0; - while (i < maximalIterationCount) { - if (FastMath.abs(y2) < FastMath.abs(y1)) { - // use the bracket point if is better than last approximation - x0 = x1; - x1 = x2; - x2 = x0; - y0 = y1; - y1 = y2; - y2 = y0; + if (FastMath.abs(m) <= tol || + MathUtils.equals(fb, 0)) { + return b; } - if (FastMath.abs(y1) <= functionValueAccuracy) { - // Avoid division by very small values. Assume - // the iteration has converged (the problem may - // still be ill conditioned) - setResult(x1, i); - return result; - } - double dx = x2 - x1; - double tolerance = - FastMath.max(relativeAccuracy * FastMath.abs(x1), absoluteAccuracy); - if (FastMath.abs(dx) <= tolerance) { - setResult(x1, i); - return result; - } - if ((FastMath.abs(oldDelta) < tolerance) || - (FastMath.abs(y0) <= FastMath.abs(y1))) { + if (FastMath.abs(e) < tol || + FastMath.abs(fa) <= FastMath.abs(fb)) { // Force bisection. - delta = 0.5 * dx; - oldDelta = delta; + d = m; + e = d; } else { - double r3 = y1 / y0; + double s = fb / fa; double p; - double p1; - // the equality test (x0 == x2) is intentional, - // it is part of the original Brent's method, - // it should NOT be replaced by proximity test - if (x0 == x2) { + double q; + // The equality test (a == c) is intentional, + // it is part of the original Brent's method and + // it should NOT be replaced by proximity test. + if (a == c) { // Linear interpolation. - p = dx * r3; - p1 = 1.0 - r3; + p = 2 * m * s; + q = 1 - s; } else { // Inverse quadratic interpolation. - double r1 = y0 / y2; - double r2 = y1 / y2; - p = r3 * (dx * r1 * (r1 - r2) - (x1 - x0) * (r2 - 1.0)); - p1 = (r1 - 1.0) * (r2 - 1.0) * (r3 - 1.0); + q = fa / fc; + final double r = fb / fc; + p = s * (2 * m * q * (q - r) - (b - a) * (r - 1)); + q = (q - 1) * (r - 1) * (s - 1); } - if (p > 0.0) { - p1 = -p1; + if (p > 0) { + q = -q; } else { p = -p; } - if (2.0 * p >= 1.5 * dx * p1 - FastMath.abs(tolerance * p1) || - p >= FastMath.abs(0.5 * oldDelta * p1)) { + s = e; + e = d; + if (p >= 1.5 * m * q - FastMath.abs(tol * q) || + p >= FastMath.abs(0.5 * s * q)) { // Inverse quadratic interpolation gives a value // in the wrong direction, or progress is slow. // Fall back to bisection. - delta = 0.5 * dx; - oldDelta = delta; + d = m; + e = d; } else { - oldDelta = delta; - delta = p / p1; + d = p / q; } } - // Save old X1, Y1 - x0 = x1; - y0 = y1; - // Compute new X1, Y1 - if (FastMath.abs(delta) > tolerance) { - x1 = x1 + delta; - } else if (dx > 0.0) { - x1 = x1 + 0.5 * tolerance; - } else if (dx <= 0.0) { - x1 = x1 - 0.5 * tolerance; + a = b; + fa = fb; + + if (FastMath.abs(d) > tol) { + b += d; + } else if (m > 0) { + b += tol; + } else { + b -= tol; } - y1 = f.value(x1); - if ((y1 > 0) == (y2 > 0)) { - x2 = x0; - y2 = y0; - delta = x1 - x0; - oldDelta = delta; + fb = computeObjectiveValue(b); + if ((fb > 0 && fc > 0) || + (fb <= 0 && fc <= 0)) { + c = a; + fc = fa; + d = b - a; + e = d; } - i++; } - throw new MaxIterationsExceededException(maximalIterationCount); } } diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/DifferentiableUnivariateRealSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/DifferentiableUnivariateRealSolver.java new file mode 100644 index 000000000..b494e6a43 --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/DifferentiableUnivariateRealSolver.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction; + + +/** + * Interface for (univariate real) rootfinding algorithms. + * Implementations will search for only one zero in the given interval. + * + * @version $Revision: 1034896 $ $Date: 2010-11-13 23:27:34 +0100 (Sat, 13 Nov 2010) $ + */ +public interface DifferentiableUnivariateRealSolver + extends BaseUnivariateRealSolver {} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/LaguerreSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/LaguerreSolver.java index 56764de9c..47db8b050 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/LaguerreSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/LaguerreSolver.java @@ -16,304 +16,335 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.MathRuntimeException; -import org.apache.commons.math.MaxIterationsExceededException; -import org.apache.commons.math.analysis.UnivariateRealFunction; -import org.apache.commons.math.analysis.polynomials.PolynomialFunction; import org.apache.commons.math.complex.Complex; -import org.apache.commons.math.exception.MathUserException; +import org.apache.commons.math.exception.NoBracketingException; +import org.apache.commons.math.exception.NullArgumentException; +import org.apache.commons.math.exception.NoDataException; import org.apache.commons.math.exception.util.LocalizedFormats; import org.apache.commons.math.util.FastMath; /** * Implements the * Laguerre's Method for root finding of real coefficient polynomials. - * For reference, see A First Course in Numerical Analysis, - * ISBN 048641454X, chapter 8. - *

+ * For reference, see + * + * A First Course in Numerical Analysis + * ISBN 048641454X, chapter 8. + * * Laguerre's method is global in the sense that it can start with any initial - * approximation and be able to solve all roots from that point.

+ * approximation and be able to solve all roots from that point. + * The algorithm requires a bracketing condition. * * @version $Revision$ $Date$ * @since 1.2 */ -public class LaguerreSolver extends UnivariateRealSolverImpl { +public class LaguerreSolver extends AbstractPolynomialSolver { + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; + /** Complex solver. */ + protected ComplexSolver complexSolver = new ComplexSolver(); /** - * Construct a solver. + * Construct a solver with default accuracies. */ public LaguerreSolver() { - super(100, 1E-6); + this(DEFAULT_ABSOLUTE_ACCURACY); + } + /** + * Construct a solver. + * + * @param absoluteAccuracy Absolute accuracy. + */ + public LaguerreSolver(double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + */ + public LaguerreSolver(double relativeAccuracy, + double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + * @param functionValueAccuracy Function value accuracy. + */ + public LaguerreSolver(double relativeAccuracy, + double absoluteAccuracy, + double functionValueAccuracy) { + super(relativeAccuracy, absoluteAccuracy, functionValueAccuracy); } /** - * Find a real root in the given interval with initial value. - *

- * Requires bracketing condition.

- * - * @param f function to solve (must be polynomial) - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @param initial the start value to use - * @return the point at which the function value is zero - * @throws ConvergenceException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid + * {@inheritDoc} */ - public double solve(final UnivariateRealFunction f, - final double min, final double max, final double initial) - throws ConvergenceException, MathUserException { + @Override + public double doSolve() { + double min = getMin(); + double max = getMax(); + double initial = getStartValue(); + final double functionValueAccuracy = getFunctionValueAccuracy(); - // check for zeros before verifying bracketing - if (f.value(min) == 0.0) { - return min; - } - if (f.value(max) == 0.0) { - return max; - } - if (f.value(initial) == 0.0) { + verifySequence(min, initial, max); + + // Return the initial guess if it is good enough. + double yInitial = computeObjectiveValue(initial); + if (FastMath.abs(yInitial) <= functionValueAccuracy) { return initial; } - verifyBracketing(min, max, f); - verifySequence(min, initial, max); - if (isBracketing(min, initial, f)) { - return solve(f, min, initial); - } else { - return solve(f, initial, max); + // Return the first endpoint if it is good enough. + double yMin = computeObjectiveValue(min); + if (FastMath.abs(yMin) <= functionValueAccuracy) { + return min; } + // Reduce interval if min and initial bracket the root. + if (yInitial * yMin < 0) { + return laguerre(min, initial, yMin, yInitial); + } + + // Return the second endpoint if it is good enough. + double yMax = computeObjectiveValue(max); + if (FastMath.abs(yMax) <= functionValueAccuracy) { + return max; + } + + // Reduce interval if initial and max bracket the root. + if (yInitial * yMax < 0) { + return laguerre(initial, max, yInitial, yMax); + } + + throw new NoBracketingException(min, max, yMin, yMax); } /** * Find a real root in the given interval. - *

- * Despite the bracketing condition, the root returned by solve(Complex[], - * Complex) may not be a real zero inside [min, max]. For example, - * p(x) = x^3 + 1, min = -2, max = 2, initial = 0. We can either try - * another initial value, or, as we did here, call solveAll() to obtain - * all roots and pick up the one that we're looking for.

* - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @return the point at which the function value is zero - * @throws ConvergenceException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid + * Despite the bracketing condition, the root returned by + * {@link LaguerreSolver.ComplexSolver#solve(Complex[],Complex)} may + * not be a real zero inside {@code [min, max]}. + * For example, p(x) = x3 + 1, + * with {@code min = -2}, {@code max = 2}, {@code initial = 0}. + * When it occurs, this code calls + * {@link LaguerreSolver.ComplexSolver#solveAll(Complex[],Complex)} + * in order to obtain all roots and picks up one real root. + * + * @param lo Lower bound of the search interval. + * @param hi Higher bound of the search interval. + * @param fLo Function value at the lower bound of the search interval. + * @param fHi Function value at the higher bound of the search interval. + * @return the point at which the function value is zero. */ - public double solve(final UnivariateRealFunction f, - final double min, final double max) - throws ConvergenceException, MathUserException { - - // check function type - if (!(f instanceof PolynomialFunction)) { - throw MathRuntimeException.createIllegalArgumentException(LocalizedFormats.FUNCTION_NOT_POLYNOMIAL); - } - - // check for zeros before verifying bracketing - if (f.value(min) == 0.0) { return min; } - if (f.value(max) == 0.0) { return max; } - verifyBracketing(min, max, f); - - double coefficients[] = ((PolynomialFunction) f).getCoefficients(); + public double laguerre(double lo, double hi, + double fLo, double fHi) { + double result = Double.NaN; + double coefficients[] = getCoefficients(); Complex c[] = new Complex[coefficients.length]; for (int i = 0; i < coefficients.length; i++) { - c[i] = new Complex(coefficients[i], 0.0); + c[i] = new Complex(coefficients[i], 0); } - Complex initial = new Complex(0.5 * (min + max), 0.0); - Complex z = solve(c, initial); - if (isRootOK(min, max, z)) { - setResult(z.getReal(), iterationCount); - return result; - } - - // solve all roots and select the one we're seeking - Complex[] root = solveAll(c, initial); - for (int i = 0; i < root.length; i++) { - if (isRootOK(min, max, root[i])) { - setResult(root[i].getReal(), iterationCount); - return result; + Complex initial = new Complex(0.5 * (lo + hi), 0); + Complex z = complexSolver.solve(c, initial); + if (complexSolver.isRoot(lo, hi, z)) { + return z.getReal(); + } else { + double r = Double.NaN; + // Solve all roots and select the one we are seeking. + Complex[] root = complexSolver.solveAll(c, initial); + for (int i = 0; i < root.length; i++) { + if (complexSolver.isRoot(lo, hi, root[i])) { + r = root[i].getReal(); + break; + } } + return r; } - - // should never happen - throw new ConvergenceException(); } /** - * Returns true iff the given complex root is actually a real zero - * in the given interval, within the solver tolerance level. - * - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @param z the complex root - * @return true iff z is the sought-after real zero + * Class for searching all (complex) roots. */ - protected boolean isRootOK(double min, double max, Complex z) { - double tolerance = FastMath.max(relativeAccuracy * z.abs(), absoluteAccuracy); - return (isSequence(min, z.getReal(), max)) && - (FastMath.abs(z.getImaginary()) <= tolerance || - z.abs() <= functionValueAccuracy); - } - - /** - * Find all complex roots for the polynomial with the given coefficients, - * starting from the given initial value. - * - * @param coefficients the polynomial coefficients array - * @param initial the start value to use - * @return the point at which the function value is zero - * @throws ConvergenceException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid - */ - public Complex[] solveAll(double coefficients[], double initial) throws - ConvergenceException, MathUserException { - - Complex c[] = new Complex[coefficients.length]; - Complex z = new Complex(initial, 0.0); - for (int i = 0; i < c.length; i++) { - c[i] = new Complex(coefficients[i], 0.0); - } - return solveAll(c, z); - } - - /** - * Find all complex roots for the polynomial with the given coefficients, - * starting from the given initial value. - * - * @param coefficients the polynomial coefficients array - * @param initial the start value to use - * @return the point at which the function value is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid - */ - public Complex[] solveAll(Complex coefficients[], Complex initial) throws - MaxIterationsExceededException, MathUserException { - - int n = coefficients.length - 1; - int iterationCount = 0; - if (n < 1) { - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.NON_POSITIVE_POLYNOMIAL_DEGREE, n); - } - Complex c[] = new Complex[n+1]; // coefficients for deflated polynomial - for (int i = 0; i <= n; i++) { - c[i] = coefficients[i]; + private class ComplexSolver { + /** + * Check whether the given complex root is actually a real zero + * in the given interval, within the solver tolerance level. + * + * @param min Lower bound for the interval. + * @param max Upper bound for the interval. + * @param z Complex root. + * @return {@code true} if z is a real zero. + */ + public boolean isRoot(double min, double max, Complex z) { + double tolerance = FastMath.max(getRelativeAccuracy() * z.abs(), getAbsoluteAccuracy()); + return (isSequence(min, z.getReal(), max)) && + (FastMath.abs(z.getImaginary()) <= tolerance || + z.abs() <= getFunctionValueAccuracy()); } - // solve individual root successively - Complex root[] = new Complex[n]; - for (int i = 0; i < n; i++) { - Complex subarray[] = new Complex[n-i+1]; - System.arraycopy(c, 0, subarray, 0, subarray.length); - root[i] = solve(subarray, initial); - // polynomial deflation using synthetic division - Complex newc = c[n-i]; - Complex oldc = null; - for (int j = n-i-1; j >= 0; j--) { - oldc = c[j]; - c[j] = newc; - newc = oldc.add(newc.multiply(root[i])); + /** + * Find all complex roots for the polynomial with the given + * coefficients, starting from the given initial value. + * + * @param coefficients Polynomial coefficients. + * @param initial Start value. + * @return the point at which the function value is zero. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the maximum number of evaluations is exceeded. + * @throws NullArgumentException if the {@code coefficients} is + * {@code null}. + * @throws NoDataException if the {@code coefficients} array is empty. + */ + public Complex[] solveAll(double coefficients[], double initial) { + if (coefficients == null) { + throw new NullArgumentException(); } - iterationCount += this.iterationCount; + Complex c[] = new Complex[coefficients.length]; + Complex z = new Complex(initial, 0); + for (int i = 0; i < c.length; i++) { + c[i] = new Complex(coefficients[i], 0); + } + return solveAll(c, z); } - resultComputed = true; - this.iterationCount = iterationCount; - return root; - } - - /** - * Find a complex root for the polynomial with the given coefficients, - * starting from the given initial value. - * - * @param coefficients the polynomial coefficients array - * @param initial the start value to use - * @return the point at which the function value is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid - */ - public Complex solve(Complex coefficients[], Complex initial) throws - MaxIterationsExceededException, MathUserException { - - int n = coefficients.length - 1; - if (n < 1) { - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.NON_POSITIVE_POLYNOMIAL_DEGREE, n); - } - Complex N = new Complex(n, 0.0); - Complex N1 = new Complex(n - 1, 0.0); - - int i = 1; - Complex pv = null; - Complex dv = null; - Complex d2v = null; - Complex G = null; - Complex G2 = null; - Complex H = null; - Complex delta = null; - Complex denominator = null; - Complex z = initial; - Complex oldz = new Complex(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY); - while (i <= maximalIterationCount) { - // Compute pv (polynomial value), dv (derivative value), and - // d2v (second derivative value) simultaneously. - pv = coefficients[n]; - dv = Complex.ZERO; - d2v = Complex.ZERO; - for (int j = n-1; j >= 0; j--) { - d2v = dv.add(z.multiply(d2v)); - dv = pv.add(z.multiply(dv)); - pv = coefficients[j].add(z.multiply(pv)); + /** + * Find all complex roots for the polynomial with the given + * coefficients, starting from the given initial value. + * + * @param coefficients Polynomial coefficients. + * @param initial Start value. + * @return the point at which the function value is zero. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the maximum number of evaluations is exceeded. + * @throws NullArgumentException if the {@code coefficients} is + * {@code null}. + * @throws NoDataException if the {@code coefficients} array is empty. + */ + public Complex[] solveAll(Complex coefficients[], Complex initial) { + if (coefficients == null) { + throw new NullArgumentException(); } - d2v = d2v.multiply(new Complex(2.0, 0.0)); - - // check for convergence - double tolerance = FastMath.max(relativeAccuracy * z.abs(), - absoluteAccuracy); - if ((z.subtract(oldz)).abs() <= tolerance) { - resultComputed = true; - iterationCount = i; - return z; + int n = coefficients.length - 1; + if (n == 0) { + throw new NoDataException(LocalizedFormats.POLYNOMIAL); } - if (pv.abs() <= functionValueAccuracy) { - resultComputed = true; - iterationCount = i; - return z; + // Coefficients for deflated polynomial. + Complex c[] = new Complex[n + 1]; + for (int i = 0; i <= n; i++) { + c[i] = coefficients[i]; } - // now pv != 0, calculate the new approximation - G = dv.divide(pv); - G2 = G.multiply(G); - H = G2.subtract(d2v.divide(pv)); - delta = N1.multiply((N.multiply(H)).subtract(G2)); - // choose a denominator larger in magnitude - Complex deltaSqrt = delta.sqrt(); - Complex dplus = G.add(deltaSqrt); - Complex dminus = G.subtract(deltaSqrt); - denominator = dplus.abs() > dminus.abs() ? dplus : dminus; - // Perturb z if denominator is zero, for instance, - // p(x) = x^3 + 1, z = 0. - if (denominator.equals(new Complex(0.0, 0.0))) { - z = z.add(new Complex(absoluteAccuracy, absoluteAccuracy)); - oldz = new Complex(Double.POSITIVE_INFINITY, - Double.POSITIVE_INFINITY); - } else { - oldz = z; - z = z.subtract(N.divide(denominator)); + // Solve individual roots successively. + Complex root[] = new Complex[n]; + for (int i = 0; i < n; i++) { + Complex subarray[] = new Complex[n - i + 1]; + System.arraycopy(c, 0, subarray, 0, subarray.length); + root[i] = solve(subarray, initial); + // Polynomial deflation using synthetic division. + Complex newc = c[n - i]; + Complex oldc = null; + for (int j = n - i - 1; j >= 0; j--) { + oldc = c[j]; + c[j] = newc; + newc = oldc.add(newc.multiply(root[i])); + } + } + + return root; + } + + /** + * Find a complex root for the polynomial with the given coefficients, + * starting from the given initial value. + * + * @param coefficients Polynomial coefficients. + * @param initial Start value. + * @return the point at which the function value is zero. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the maximum number of evaluations is exceeded. + * @throws NullArgumentException if the {@code coefficients} is + * {@code null}. + * @throws NoDataException if the {@code coefficients} array is empty. + */ + public Complex solve(Complex coefficients[], Complex initial) { + if (coefficients == null) { + throw new NullArgumentException(); + } + + int n = coefficients.length - 1; + if (n == 0) { + throw new NoDataException(LocalizedFormats.POLYNOMIAL); + } + + final double absoluteAccuracy = getAbsoluteAccuracy(); + final double relativeAccuracy = getRelativeAccuracy(); + final double functionValueAccuracy = getFunctionValueAccuracy(); + + Complex N = new Complex(n, 0.0); + Complex N1 = new Complex(n - 1, 0.0); + + Complex pv = null; + Complex dv = null; + Complex d2v = null; + Complex G = null; + Complex G2 = null; + Complex H = null; + Complex delta = null; + Complex denominator = null; + Complex z = initial; + Complex oldz = new Complex(Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY); + while (true) { + // Compute pv (polynomial value), dv (derivative value), and + // d2v (second derivative value) simultaneously. + pv = coefficients[n]; + dv = Complex.ZERO; + d2v = Complex.ZERO; + for (int j = n-1; j >= 0; j--) { + d2v = dv.add(z.multiply(d2v)); + dv = pv.add(z.multiply(dv)); + pv = coefficients[j].add(z.multiply(pv)); + } + d2v = d2v.multiply(new Complex(2.0, 0.0)); + + // check for convergence + double tolerance = FastMath.max(relativeAccuracy * z.abs(), + absoluteAccuracy); + if ((z.subtract(oldz)).abs() <= tolerance) { + return z; + } + if (pv.abs() <= functionValueAccuracy) { + return z; + } + + // now pv != 0, calculate the new approximation + G = dv.divide(pv); + G2 = G.multiply(G); + H = G2.subtract(d2v.divide(pv)); + delta = N1.multiply((N.multiply(H)).subtract(G2)); + // choose a denominator larger in magnitude + Complex deltaSqrt = delta.sqrt(); + Complex dplus = G.add(deltaSqrt); + Complex dminus = G.subtract(deltaSqrt); + denominator = dplus.abs() > dminus.abs() ? dplus : dminus; + // Perturb z if denominator is zero, for instance, + // p(x) = x^3 + 1, z = 0. + if (denominator.equals(new Complex(0.0, 0.0))) { + z = z.add(new Complex(absoluteAccuracy, absoluteAccuracy)); + oldz = new Complex(Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY); + } else { + oldz = z; + z = z.subtract(N.divide(denominator)); + } + incrementEvaluationCount(); } - i++; } - throw new MaxIterationsExceededException(maximalIterationCount); } } diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/MullerSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/MullerSolver.java index 6dd994723..232272d14 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/MullerSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/MullerSolver.java @@ -16,95 +16,116 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.MaxIterationsExceededException; -import org.apache.commons.math.analysis.UnivariateRealFunction; -import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.util.FastMath; import org.apache.commons.math.util.MathUtils; /** - * Implements the + * This class implements the * Muller's Method for root finding of real univariate functions. For * reference, see Elementary Numerical Analysis, ISBN 0070124477, * chapter 3. *

* Muller's method applies to both real and complex functions, but here we - * restrict ourselves to real functions. Methods solve() and solve2() find - * real zeros, using different ways to bypass complex arithmetics.

+ * restrict ourselves to real functions. + * This class differs from {@link MullerSolver} in the way it avoids complex + * operations.

+ * Muller's original method would have function evaluation at complex point. + * Since our f(x) is real, we have to find ways to avoid that. Bracketing + * condition is one way to go: by requiring bracketing in every iteration, + * the newly computed approximation is guaranteed to be real.

+ *

+ * Normally Muller's method converges quadratically in the vicinity of a + * zero, however it may be very slow in regions far away from zeros. For + * example, f(x) = exp(x) - 1, min = -50, max = 100. In such case we use + * bisection as a safety backup if it performs very poorly.

+ *

+ * The formulas here use divided differences directly.

* * @version $Revision$ $Date$ * @since 1.2 + * @see MullerSolver2 */ -public class MullerSolver extends UnivariateRealSolverImpl { +public class MullerSolver extends AbstractUnivariateRealSolver { + /** Serializable version identifier */ + private static final long serialVersionUID = 7694577816772532779L; + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; /** - * Construct a solver. + * Construct a solver with default accuracies. */ public MullerSolver() { - super(100, 1E-6); + this(DEFAULT_ABSOLUTE_ACCURACY); + } + /** + * Construct a solver. + * + * @param absoluteAccuracy Absolute accuracy. + */ + public MullerSolver(double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + */ + public MullerSolver(double relativeAccuracy, + double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); } /** - * Find a real root in the given interval with initial value. - *

- * Requires bracketing condition.

- * - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @param initial the start value to use - * @return the point at which the function value is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid + * {@inheritDoc} */ - public double solve(final UnivariateRealFunction f, - final double min, final double max, final double initial) - throws MaxIterationsExceededException, MathUserException { + @Override + protected double doSolve() { + final double min = getMin(); + final double max = getMax(); + final double initial = getStartValue(); + + final double functionValueAccuracy = getFunctionValueAccuracy(); + + verifySequence(min, initial, max); // check for zeros before verifying bracketing - if (f.value(min) == 0.0) { return min; } - if (f.value(max) == 0.0) { return max; } - if (f.value(initial) == 0.0) { return initial; } + final double fMin = computeObjectiveValue(min); + if (FastMath.abs(fMin) < functionValueAccuracy) { + return min; + } + final double fMax = computeObjectiveValue(max); + if (FastMath.abs(fMax) < functionValueAccuracy) { + return max; + } + final double fInitial = computeObjectiveValue(initial); + if (FastMath.abs(fInitial) < functionValueAccuracy) { + return initial; + } - verifyBracketing(min, max, f); - verifySequence(min, initial, max); - if (isBracketing(min, initial, f)) { - return solve(f, min, initial); + verifyBracketing(min, max); + + if (isBracketing(min, initial)) { + return solve(min, initial, fMin, fInitial); } else { - return solve(f, initial, max); + return solve(initial, max, fInitial, fMax); } } /** * Find a real root in the given interval. - *

- * Original Muller's method would have function evaluation at complex point. - * Since our f(x) is real, we have to find ways to avoid that. Bracketing - * condition is one way to go: by requiring bracketing in every iteration, - * the newly computed approximation is guaranteed to be real.

- *

- * Normally Muller's method converges quadratically in the vicinity of a - * zero, however it may be very slow in regions far away from zeros. For - * example, f(x) = exp(x) - 1, min = -50, max = 100. In such case we use - * bisection as a safety backup if it performs very poorly.

- *

- * The formulas here use divided differences directly.

* - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @return the point at which the function value is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid + * @param min Lower bound for the interval. + * @param max Upper bound for the interval. + * @param fMin function value at the lower bound. + * @param fMax function value at the upper bound. + * @return the point at which the function value is zero. */ - public double solve(final UnivariateRealFunction f, - final double min, final double max) - throws MaxIterationsExceededException, MathUserException { + private double solve(double min, double max, + double fMin, double fMax) { + final double relativeAccuracy = getRelativeAccuracy(); + final double absoluteAccuracy = getAbsoluteAccuracy(); + final double functionValueAccuracy = getFunctionValueAccuracy(); // [x0, x2] is the bracketing interval in each iteration // x1 is the last approximation and an interpolation point in (x0, x2) @@ -112,23 +133,14 @@ public class MullerSolver extends UnivariateRealSolverImpl { // d01, d12, d012 are divided differences double x0 = min; - double y0 = f.value(x0); + double y0 = fMin; double x2 = max; - double y2 = f.value(x2); + double y2 = fMax; double x1 = 0.5 * (x0 + x2); - double y1 = f.value(x1); - - // check for zeros before verifying bracketing - if (y0 == 0.0) { - return min; - } - if (y2 == 0.0) { - return max; - } - verifyBracketing(min, max, f); + double y1 = computeObjectiveValue(x1); double oldx = Double.POSITIVE_INFINITY; - for (int i = 1; i <= maximalIterationCount; ++i) { + while (true) { // Muller's method employs quadratic interpolation through // x0, x1, x2 and x is the zero of the interpolating parabola. // Due to bracketing condition, this parabola must have two @@ -143,17 +155,13 @@ public class MullerSolver extends UnivariateRealSolverImpl { // xplus and xminus are two roots of parabola and at least // one of them should lie in (x0, x2) final double x = isSequence(x0, xplus, x2) ? xplus : xminus; - final double y = f.value(x); + final double y = computeObjectiveValue(x); // check for convergence final double tolerance = FastMath.max(relativeAccuracy * FastMath.abs(x), absoluteAccuracy); - if (FastMath.abs(x - oldx) <= tolerance) { - setResult(x, i); - return result; - } - if (FastMath.abs(y) <= functionValueAccuracy) { - setResult(x, i); - return result; + if (FastMath.abs(x - oldx) <= tolerance || + FastMath.abs(y) <= functionValueAccuracy) { + return x; } // Bisect if convergence is too slow. Bisection would waste @@ -173,118 +181,16 @@ public class MullerSolver extends UnivariateRealSolverImpl { oldx = x; } else { double xm = 0.5 * (x0 + x2); - double ym = f.value(xm); + double ym = computeObjectiveValue(xm); if (MathUtils.sign(y0) + MathUtils.sign(ym) == 0.0) { x2 = xm; y2 = ym; } else { x0 = xm; y0 = ym; } x1 = 0.5 * (x0 + x2); - y1 = f.value(x1); + y1 = computeObjectiveValue(x1); oldx = Double.POSITIVE_INFINITY; } } - throw new MaxIterationsExceededException(maximalIterationCount); - } - - /** - * Find a real root in the given interval. - *

- * solve2() differs from solve() in the way it avoids complex operations. - * Except for the initial [min, max], solve2() does not require bracketing - * condition, e.g. f(x0), f(x1), f(x2) can have the same sign. If complex - * number arises in the computation, we simply use its modulus as real - * approximation.

- *

- * Because the interval may not be bracketing, bisection alternative is - * not applicable here. However in practice our treatment usually works - * well, especially near real zeros where the imaginary part of complex - * approximation is often negligible.

- *

- * The formulas here do not use divided differences directly.

- * - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @return the point at which the function value is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid - */ - public double solve2(final UnivariateRealFunction f, - final double min, final double max) - throws MaxIterationsExceededException, MathUserException { - - // x2 is the last root approximation - // x is the new approximation and new x2 for next round - // x0 < x1 < x2 does not hold here - - double x0 = min; - double y0 = f.value(x0); - double x1 = max; - double y1 = f.value(x1); - double x2 = 0.5 * (x0 + x1); - double y2 = f.value(x2); - - // check for zeros before verifying bracketing - if (y0 == 0.0) { return min; } - if (y1 == 0.0) { return max; } - verifyBracketing(min, max, f); - - double oldx = Double.POSITIVE_INFINITY; - for (int i = 1; i <= maximalIterationCount; ++i) { - // quadratic interpolation through x0, x1, x2 - final double q = (x2 - x1) / (x1 - x0); - final double a = q * (y2 - (1 + q) * y1 + q * y0); - final double b = (2 * q + 1) * y2 - (1 + q) * (1 + q) * y1 + q * q * y0; - final double c = (1 + q) * y2; - final double delta = b * b - 4 * a * c; - double x; - final double denominator; - if (delta >= 0.0) { - // choose a denominator larger in magnitude - double dplus = b + FastMath.sqrt(delta); - double dminus = b - FastMath.sqrt(delta); - denominator = FastMath.abs(dplus) > FastMath.abs(dminus) ? dplus : dminus; - } else { - // take the modulus of (B +/- FastMath.sqrt(delta)) - denominator = FastMath.sqrt(b * b - delta); - } - if (denominator != 0) { - x = x2 - 2.0 * c * (x2 - x1) / denominator; - // perturb x if it exactly coincides with x1 or x2 - // the equality tests here are intentional - while (x == x1 || x == x2) { - x += absoluteAccuracy; - } - } else { - // extremely rare case, get a random number to skip it - x = min + FastMath.random() * (max - min); - oldx = Double.POSITIVE_INFINITY; - } - final double y = f.value(x); - - // check for convergence - final double tolerance = FastMath.max(relativeAccuracy * FastMath.abs(x), absoluteAccuracy); - if (FastMath.abs(x - oldx) <= tolerance) { - setResult(x, i); - return result; - } - if (FastMath.abs(y) <= functionValueAccuracy) { - setResult(x, i); - return result; - } - - // prepare the next iteration - x0 = x1; - y0 = y1; - x1 = x2; - y1 = y2; - x2 = x; - y2 = y; - oldx = x; - } - throw new MaxIterationsExceededException(maximalIterationCount); } } diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/MullerSolver2.java b/src/main/java/org/apache/commons/math/analysis/solvers/MullerSolver2.java new file mode 100644 index 000000000..b76c86d99 --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/MullerSolver2.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.exception.NoBracketingException; +import org.apache.commons.math.util.FastMath; + +/** + * This class implements the + * Muller's Method for root finding of real univariate functions. For + * reference, see Elementary Numerical Analysis, ISBN 0070124477, + * chapter 3. + *

+ * Muller's method applies to both real and complex functions, but here we + * restrict ourselves to real functions.< + * This class differs from {@link MullerSolver} in the way it avoids complex + * operations.

+ * Except for the initial [min, max], it does not require bracketing + * condition, e.g. f(x0), f(x1), f(x2) can have the same sign. If complex + * number arises in the computation, we simply use its modulus as real + * approximation.

+ *

+ * Because the interval may not be bracketing, bisection alternative is + * not applicable here. However in practice our treatment usually works + * well, especially near real zeroes where the imaginary part of complex + * approximation is often negligible.

+ *

+ * The formulas here do not use divided differences directly.

+ * + * @version $Revision: 1034896 $ $Date: 2010-11-13 23:27:34 +0100 (Sat, 13 Nov 2010) $ + * @since 1.2 + * @see MullerSolver + */ +public class MullerSolver2 extends AbstractUnivariateRealSolver { + /** Serializable version identifier */ + private static final long serialVersionUID = 7694577816772532779L; + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; + + /** + * Construct a solver with default accuracies. + */ + public MullerSolver2() { + this(DEFAULT_ABSOLUTE_ACCURACY); + } + /** + * Construct a solver. + * + * @param absoluteAccuracy Absolute accuracy. + */ + public MullerSolver2(double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + */ + public MullerSolver2(double relativeAccuracy, + double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); + } + + /** + * {@inheritDoc} + */ + @Override + protected double doSolve() { + final double min = getMin(); + final double max = getMax(); + + verifyInterval(min, max); + + final double relativeAccuracy = getRelativeAccuracy(); + final double absoluteAccuracy = getAbsoluteAccuracy(); + final double functionValueAccuracy = getFunctionValueAccuracy(); + + // x2 is the last root approximation + // x is the new approximation and new x2 for next round + // x0 < x1 < x2 does not hold here + + double x0 = min; + double y0 = computeObjectiveValue(x0); + if (FastMath.abs(y0) < functionValueAccuracy) { + return x0; + } + double x1 = max; + double y1 = computeObjectiveValue(x1); + if (FastMath.abs(y1) < functionValueAccuracy) { + return x1; + } + + if(y0 * y1 > 0) { + throw new NoBracketingException(x0, x1, y0, y1); + } + + double x2 = 0.5 * (x0 + x1); + double y2 = computeObjectiveValue(x2); + + double oldx = Double.POSITIVE_INFINITY; + while (true) { + // quadratic interpolation through x0, x1, x2 + final double q = (x2 - x1) / (x1 - x0); + final double a = q * (y2 - (1 + q) * y1 + q * y0); + final double b = (2 * q + 1) * y2 - (1 + q) * (1 + q) * y1 + q * q * y0; + final double c = (1 + q) * y2; + final double delta = b * b - 4 * a * c; + double x; + final double denominator; + if (delta >= 0.0) { + // choose a denominator larger in magnitude + double dplus = b + FastMath.sqrt(delta); + double dminus = b - FastMath.sqrt(delta); + denominator = FastMath.abs(dplus) > FastMath.abs(dminus) ? dplus : dminus; + } else { + // take the modulus of (B +/- FastMath.sqrt(delta)) + denominator = FastMath.sqrt(b * b - delta); + } + if (denominator != 0) { + x = x2 - 2.0 * c * (x2 - x1) / denominator; + // perturb x if it exactly coincides with x1 or x2 + // the equality tests here are intentional + while (x == x1 || x == x2) { + x += absoluteAccuracy; + } + } else { + // extremely rare case, get a random number to skip it + x = min + FastMath.random() * (max - min); + oldx = Double.POSITIVE_INFINITY; + } + final double y = computeObjectiveValue(x); + + // check for convergence + final double tolerance = FastMath.max(relativeAccuracy * FastMath.abs(x), absoluteAccuracy); + if (FastMath.abs(x - oldx) <= tolerance || + FastMath.abs(y) <= functionValueAccuracy) { + return x; + } + + // prepare the next iteration + x0 = x1; + y0 = y1; + x1 = x2; + y1 = y2; + x2 = x; + y2 = y; + oldx = x; + } + } +} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/NewtonSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/NewtonSolver.java index ed34f2efc..563bce114 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/NewtonSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/NewtonSolver.java @@ -17,12 +17,7 @@ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.exception.MathUserException; -import org.apache.commons.math.MathRuntimeException; -import org.apache.commons.math.MaxIterationsExceededException; import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction; -import org.apache.commons.math.analysis.UnivariateRealFunction; -import org.apache.commons.math.exception.util.LocalizedFormats; import org.apache.commons.math.util.FastMath; /** @@ -33,76 +28,58 @@ import org.apache.commons.math.util.FastMath; * * @version $Revision$ $Date$ */ -public class NewtonSolver extends UnivariateRealSolverImpl { +public class NewtonSolver extends AbstractDifferentiableUnivariateRealSolver { + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; /** * Construct a solver. */ public NewtonSolver() { - super(100, 1E-6); + this(DEFAULT_ABSOLUTE_ACCURACY); + } + /** + * Construct a solver. + * + * @param absoluteAccuracy Absolute accuracy. + */ + public NewtonSolver(double absoluteAccuracy) { + super(absoluteAccuracy); } /** - * Find a zero near the midpoint of min and max. + * Find a zero near the midpoint of {@code min} and {@code max}. * - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @return the value where the function is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function or derivative - * @throws IllegalArgumentException if min is not less than max + * @param f Function to solve. + * @param min Lower bound for the interval? + * @param max Upper bound for the interval. + * @return the value where the function is zero. + * @throws org.apache.commons.math.exception.TooManyEvaluationsException + * if the maximum evaluation count is exceeded. + * @throws IllegalArgumentException if {@code min >= max}. */ - public double solve(final UnivariateRealFunction f, - final double min, final double max) - throws MaxIterationsExceededException, MathUserException { - return solve(f, min, max, UnivariateRealSolverUtils.midpoint(min, max)); + public double solve(final DifferentiableUnivariateRealFunction f, + final double min, final double max) { + return super.solve(f, UnivariateRealSolverUtils.midpoint(min, max)); } /** - * Find a zero near the value startValue. - * - * @param f the function to solve - * @param min the lower bound for the interval (ignored). - * @param max the upper bound for the interval (ignored). - * @param startValue the start value to use. - * @return the value where the function is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function or derivative - * @throws IllegalArgumentException if startValue is not between min and max or - * if function is not a {@link DifferentiableUnivariateRealFunction} instance + * {@inheritDoc} */ - public double solve(final UnivariateRealFunction f, - final double min, final double max, final double startValue) - throws MaxIterationsExceededException, MathUserException { + @Override + protected double doSolve() { + final double startValue = getStartValue(); + final double absoluteAccuracy = getAbsoluteAccuracy(); - try { - - final UnivariateRealFunction derivative = - ((DifferentiableUnivariateRealFunction) f).derivative(); - clearResult(); - verifySequence(min, startValue, max); - - double x0 = startValue; - double x1; - - int i = 0; - while (i < maximalIterationCount) { - - x1 = x0 - (f.value(x0) / derivative.value(x0)); - if (FastMath.abs(x1 - x0) <= absoluteAccuracy) { - setResult(x1, i); - return x1; - } - - x0 = x1; - ++i; + double x0 = startValue; + double x1; + while (true) { + x1 = x0 - (computeObjectiveValue(x0) / computeDerivativeObjectiveValue(x0)); + if (FastMath.abs(x1 - x0) <= absoluteAccuracy) { + return x1; } - throw new MaxIterationsExceededException(maximalIterationCount); - } catch (ClassCastException cce) { - throw MathRuntimeException.createIllegalArgumentException(LocalizedFormats.FUNCTION_NOT_DIFFERENTIABLE); + x0 = x1; } } - } diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/PolynomialSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/PolynomialSolver.java new file mode 100644 index 000000000..8bf8827c0 --- /dev/null +++ b/src/main/java/org/apache/commons/math/analysis/solvers/PolynomialSolver.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.polynomials.PolynomialFunction; + +/** + * Interface for (polynomial) root-finding algorithms. + * Implementations will search for only one zero in the given interval. + * + * @version $Revision$ $Date$ + * @since 3.0 + */ +public interface PolynomialSolver + extends BaseUnivariateRealSolver {} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/RiddersSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/RiddersSolver.java index 551ab596b..6ae8621e5 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/RiddersSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/RiddersSolver.java @@ -16,10 +16,6 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.exception.MathUserException; -import org.apache.commons.math.MaxIterationsExceededException; -import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.util.FastMath; import org.apache.commons.math.util.MathUtils; @@ -35,106 +31,84 @@ import org.apache.commons.math.util.MathUtils; * @version $Revision$ $Date$ * @since 1.2 */ -public class RiddersSolver extends UnivariateRealSolverImpl { +public class RiddersSolver extends AbstractUnivariateRealSolver { + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; /** - * Construct a solver. + * Construct a solver with default accuracy. */ public RiddersSolver() { - super(100, 1E-6); + this(DEFAULT_ABSOLUTE_ACCURACY); + } + /** + * Construct a solver. + * + * @param absoluteAccuracy Absolute accuracy. + */ + public RiddersSolver(double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + */ + public RiddersSolver(double relativeAccuracy, + double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); } /** - * Find a root in the given interval with initial value. - *

- * Requires bracketing condition.

- * - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @param initial the start value to use - * @return the point at which the function value is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid + * {@inheritDoc} */ - public double solve(final UnivariateRealFunction f, - final double min, final double max, final double initial) - throws MaxIterationsExceededException, MathUserException { - - // check for zeros before verifying bracketing - if (f.value(min) == 0.0) { return min; } - if (f.value(max) == 0.0) { return max; } - if (f.value(initial) == 0.0) { return initial; } - - verifyBracketing(min, max, f); - verifySequence(min, initial, max); - if (isBracketing(min, initial, f)) { - return solve(f, min, initial); - } else { - return solve(f, initial, max); - } - } - - /** - * Find a root in the given interval. - *

- * Requires bracketing condition.

- * - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @return the point at which the function value is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if any parameters are invalid - */ - public double solve(final UnivariateRealFunction f, - final double min, final double max) - throws MaxIterationsExceededException, MathUserException { - + @Override + protected double doSolve() { + double min = getMin(); + double max = getMax(); // [x1, x2] is the bracketing interval in each iteration // x3 is the midpoint of [x1, x2] // x is the new root approximation and an endpoint of the new interval double x1 = min; - double y1 = f.value(x1); + double y1 = computeObjectiveValue(x1); double x2 = max; - double y2 = f.value(x2); + double y2 = computeObjectiveValue(x2); // check for zeros before verifying bracketing - if (y1 == 0.0) { + if (y1 == 0) { return min; } - if (y2 == 0.0) { + if (y2 == 0) { return max; } - verifyBracketing(min, max, f); + verifyBracketing(min, max); + + final double absoluteAccuracy = getAbsoluteAccuracy(); + final double functionValueAccuracy = getFunctionValueAccuracy(); + final double relativeAccuracy = getRelativeAccuracy(); - int i = 1; double oldx = Double.POSITIVE_INFINITY; - while (i <= maximalIterationCount) { + while (true) { // calculate the new root approximation final double x3 = 0.5 * (x1 + x2); - final double y3 = f.value(x3); + final double y3 = computeObjectiveValue(x3); if (FastMath.abs(y3) <= functionValueAccuracy) { - setResult(x3, i); - return result; + return x3; } final double delta = 1 - (y1 * y2) / (y3 * y3); // delta > 1 due to bracketing final double correction = (MathUtils.sign(y2) * MathUtils.sign(y3)) * (x3 - x1) / FastMath.sqrt(delta); final double x = x3 - correction; // correction != 0 - final double y = f.value(x); + final double y = computeObjectiveValue(x); // check for convergence final double tolerance = FastMath.max(relativeAccuracy * FastMath.abs(x), absoluteAccuracy); if (FastMath.abs(x - oldx) <= tolerance) { - setResult(x, i); - return result; + return x; } if (FastMath.abs(y) <= functionValueAccuracy) { - setResult(x, i); - return result; + return x; } // prepare the new interval for next iteration @@ -161,8 +135,6 @@ public class RiddersSolver extends UnivariateRealSolverImpl { } } oldx = x; - i++; } - throw new MaxIterationsExceededException(maximalIterationCount); } } diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/SecantSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/SecantSolver.java index bf063a174..4770563fb 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/SecantSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/SecantSolver.java @@ -16,12 +16,7 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.exception.MathUserException; -import org.apache.commons.math.MathRuntimeException; -import org.apache.commons.math.MaxIterationsExceededException; -import org.apache.commons.math.analysis.UnivariateRealFunction; -import org.apache.commons.math.exception.util.LocalizedFormats; +import org.apache.commons.math.exception.NoBracketingException; import org.apache.commons.math.util.FastMath; @@ -41,52 +36,46 @@ import org.apache.commons.math.util.FastMath; * * @version $Revision$ $Date$ */ -public class SecantSolver extends UnivariateRealSolverImpl { +public class SecantSolver extends AbstractUnivariateRealSolver { + /** Default absolute accuracy. */ + public static final double DEFAULT_ABSOLUTE_ACCURACY = 1e-6; /** - * Construct a solver. + * Construct a solver with default accuracy. */ public SecantSolver() { - super(100, 1E-6); + this(DEFAULT_ABSOLUTE_ACCURACY); } - /** - * Find a zero in the given interval. + * Construct a solver. * - * @param f the function to solve - * @param min the lower bound for the interval - * @param max the upper bound for the interval - * @param initial the start value to use (ignored) - * @return the value where the function is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if min is not less than max or the - * signs of the values of the function at the endpoints are not opposites + * @param absoluteAccuracy Absolute accuracy. */ - public double solve(final UnivariateRealFunction f, - final double min, final double max, final double initial) - throws MaxIterationsExceededException, MathUserException { - return solve(f, min, max); + public SecantSolver(double absoluteAccuracy) { + super(absoluteAccuracy); + } + /** + * Construct a solver. + * + * @param relativeAccuracy Relative accuracy. + * @param absoluteAccuracy Absolute accuracy. + */ + public SecantSolver(double relativeAccuracy, + double absoluteAccuracy) { + super(relativeAccuracy, absoluteAccuracy); } /** - * Find a zero in the given interval. - * @param f the function to solve - * @param min the lower bound for the interval. - * @param max the upper bound for the interval. - * @return the value where the function is zero - * @throws MaxIterationsExceededException if the maximum iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if min is not less than max or the - * signs of the values of the function at the endpoints are not opposites + * {@inheritDoc} */ - public double solve(final UnivariateRealFunction f, - final double min, final double max) - throws MaxIterationsExceededException, MathUserException { - - clearResult(); + @Override + protected double doSolve() { + double min = getMin(); + double max = getMax(); verifyInterval(min, max); + final double functionValueAccuracy = getFunctionValueAccuracy(); + // Index 0 is the old approximation for the root. // Index 1 is the last calculated approximation for the root. // Index 2 is a bracket for the root with respect to x0. @@ -94,20 +83,31 @@ public class SecantSolver extends UnivariateRealSolverImpl { // iteration. double x0 = min; double x1 = max; - double y0 = f.value(x0); - double y1 = f.value(x1); + + double y0 = computeObjectiveValue(x0); + // return the first endpoint if it is good enough + if (FastMath.abs(y0) <= functionValueAccuracy) { + return x0; + } + + // return the second endpoint if it is good enough + double y1 = computeObjectiveValue(x1); + if (FastMath.abs(y1) <= functionValueAccuracy) { + return x1; + } // Verify bracketing if (y0 * y1 >= 0) { - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.SAME_SIGN_AT_ENDPOINTS, min, max, y0, y1); + throw new NoBracketingException(min, max, y0, y1); } + final double absoluteAccuracy = getAbsoluteAccuracy(); + final double relativeAccuracy = getRelativeAccuracy(); + double x2 = x0; double y2 = y0; double oldDelta = x2 - x1; - int i = 0; - while (i < maximalIterationCount) { + while (true) { if (FastMath.abs(y2) < FastMath.abs(y1)) { x0 = x1; x1 = x2; @@ -117,13 +117,11 @@ public class SecantSolver extends UnivariateRealSolverImpl { y2 = y0; } if (FastMath.abs(y1) <= functionValueAccuracy) { - setResult(x1, i); - return result; + return x1; } - if (FastMath.abs(oldDelta) < - FastMath.max(relativeAccuracy * FastMath.abs(x1), absoluteAccuracy)) { - setResult(x1, i); - return result; + if (FastMath.abs(oldDelta) < FastMath.max(relativeAccuracy * FastMath.abs(x1), + absoluteAccuracy)) { + return x1; } double delta; if (FastMath.abs(y1) > FastMath.abs(y0)) { @@ -140,16 +138,13 @@ public class SecantSolver extends UnivariateRealSolverImpl { x0 = x1; y0 = y1; x1 = x1 + delta; - y1 = f.value(x1); + y1 = computeObjectiveValue(x1); if ((y1 > 0) == (y2 > 0)) { // New bracket is (x0,x1). x2 = x0; y2 = y0; } oldDelta = x2 - x1; - i++; } - throw new MaxIterationsExceededException(maximalIterationCount); } - } diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolver.java b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolver.java index 209098cfc..9f787a762 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolver.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolver.java @@ -16,104 +16,14 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.ConvergingAlgorithm; -import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.analysis.UnivariateRealFunction; /** - * Interface for (univariate real) rootfinding algorithms. - *

- * Implementations will search for only one zero in the given interval.

+ * Interface for (univariate real) root-finding algorithms. + * Implementations will search for only one zero in the given interval. * * @version $Revision$ $Date$ */ -public interface UnivariateRealSolver extends ConvergingAlgorithm { - - /** - * Set the function value accuracy. - *

- * This is used to determine when an evaluated function value or some other - * value which is used as divisor is zero.

- *

- * This is a safety guard and it shouldn't be necessary to change this in - * general.

- * - * @param accuracy the accuracy. - * @throws IllegalArgumentException if the accuracy can't be achieved by - * the solver or is otherwise deemed unreasonable. - */ - void setFunctionValueAccuracy(double accuracy); - - /** - * Get the actual function value accuracy. - * @return the accuracy - */ - double getFunctionValueAccuracy(); - - /** - * Reset the actual function accuracy to the default. - * The default value is provided by the solver implementation. - */ - void resetFunctionValueAccuracy(); - - /** - * Solve for a zero root in the given interval. - *

A solver may require that the interval brackets a single zero root. - * Solvers that do require bracketing should be able to handle the case - * where one of the endpoints is itself a root.

- * - * @param f the function to solve. - * @param min the lower bound for the interval. - * @param max the upper bound for the interval. - * @return a value where the function is zero - * @throws ConvergenceException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise. - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if min > max or the endpoints do not - * satisfy the requirements specified by the solver - * @since 2.0 - */ - double solve(UnivariateRealFunction f, double min, double max) - throws ConvergenceException, MathUserException; - - /** - * Solve for a zero in the given interval, start at startValue. - *

A solver may require that the interval brackets a single zero root. - * Solvers that do require bracketing should be able to handle the case - * where one of the endpoints is itself a root.

- * - * @param f the function to solve. - * @param min the lower bound for the interval. - * @param max the upper bound for the interval. - * @param startValue the start value to use - * @return a value where the function is zero - * @throws ConvergenceException if the maximum iteration count is exceeded - * or the solver detects convergence problems otherwise. - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if min > max or the arguments do not - * satisfy the requirements specified by the solver - * @since 2.0 - */ - double solve(UnivariateRealFunction f, double min, double max, double startValue) - throws ConvergenceException, MathUserException, IllegalArgumentException; - - /** - * Get the result of the last run of the solver. - * - * @return the last result. - * @throws IllegalStateException if there is no result available, either - * because no result was yet computed or the last attempt failed. - */ - double getResult(); - - /** - * Get the result of the last run of the solver. - * - * @return the value of the function at the last result. - * @throws IllegalStateException if there is no result available, either - * because no result was yet computed or the last attempt failed. - */ - double getFunctionValue(); -} +public interface UnivariateRealSolver + extends BaseUnivariateRealSolver {} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactory.java b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactory.java deleted file mode 100644 index 2e7a8bed4..000000000 --- a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactory.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.commons.math.analysis.solvers; - -/** - * Abstract factory class used to create {@link UnivariateRealSolver} instances. - *

- * Solvers implementing the following algorithms are supported: - *

    - *
  • Bisection
  • - *
  • Brent's method
  • - *
  • Secant method
  • - *
- * Concrete factories extending this class also specify a default solver, instances of which - * are returned by newDefaultSolver().

- *

- * Common usage:

- * SolverFactory factory = UnivariateRealSolverFactory.newInstance();

- * - * // create a Brent solver to use - * BrentSolver solver = factory.newBrentSolver(); - *
- * - * @version $Revision$ $Date$ - */ -public abstract class UnivariateRealSolverFactory { - /** - * Default constructor. - */ - protected UnivariateRealSolverFactory() { - } - - /** - * Create a new factory. - * @return a new factory. - */ - public static UnivariateRealSolverFactory newInstance() { - return new UnivariateRealSolverFactoryImpl(); - } - - /** - * Create a new {@link UnivariateRealSolver}. The - * actual solver returned is determined by the underlying factory. - * @return the new solver. - */ - public abstract UnivariateRealSolver newDefaultSolver(); - - /** - * Create a new {@link UnivariateRealSolver}. The - * solver is an implementation of the bisection method. - * @return the new solver. - */ - public abstract UnivariateRealSolver newBisectionSolver(); - - /** - * Create a new {@link UnivariateRealSolver}. The - * solver is an implementation of the Brent method. - * @return the new solver. - */ - public abstract UnivariateRealSolver newBrentSolver(); - - /** - * Create a new {@link UnivariateRealSolver}. The - * solver is an implementation of Newton's Method. - * @return the new solver. - */ - public abstract UnivariateRealSolver newNewtonSolver(); - - /** - * Create a new {@link UnivariateRealSolver}. The - * solver is an implementation of the secant method. - * @return the new solver. - */ - public abstract UnivariateRealSolver newSecantSolver(); - -} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactoryImpl.java b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactoryImpl.java deleted file mode 100644 index e473d48ac..000000000 --- a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactoryImpl.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.commons.math.analysis.solvers; - -/** - * A concrete {@link UnivariateRealSolverFactory}. This is the default solver factory - * used by commons-math. - *

- * The default solver returned by this factory is a {@link BrentSolver}.

- * - * @version $Revision$ $Date$ - */ -public class UnivariateRealSolverFactoryImpl extends UnivariateRealSolverFactory { - - /** - * Default constructor. - */ - public UnivariateRealSolverFactoryImpl() { - } - - /** {@inheritDoc} */ - @Override - public UnivariateRealSolver newDefaultSolver() { - return newBrentSolver(); - } - - /** {@inheritDoc} */ - @Override - public UnivariateRealSolver newBisectionSolver() { - return new BisectionSolver(); - } - - /** {@inheritDoc} */ - @Override - public UnivariateRealSolver newBrentSolver() { - return new BrentSolver(); - } - - /** {@inheritDoc} */ - @Override - public UnivariateRealSolver newNewtonSolver() { - return new NewtonSolver(); - } - - /** {@inheritDoc} */ - @Override - public UnivariateRealSolver newSecantSolver() { - return new SecantSolver(); - } -} diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverImpl.java b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverImpl.java index df1fb0d5f..e812abc2b 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverImpl.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverImpl.java @@ -22,16 +22,17 @@ import org.apache.commons.math.MathRuntimeException; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.exception.util.LocalizedFormats; import org.apache.commons.math.exception.MathUserException; -import org.apache.commons.math.exception.NullArgumentException; /** * Provide a default implementation for several functions useful to generic * solvers. * * @version $Revision$ $Date$ + * @deprecated in 2.2 (to be removed in 3.0). Please use + * {@link AbstractUnivariateRealSolver} instead. */ -public abstract class UnivariateRealSolverImpl - extends ConvergingAlgorithmImpl implements UnivariateRealSolver { +@Deprecated +public abstract class UnivariateRealSolverImpl extends ConvergingAlgorithmImpl { /** Maximum error of function. */ protected double functionValueAccuracy; diff --git a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtils.java b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtils.java index ce6e94f74..cdd4ff547 100644 --- a/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtils.java +++ b/src/main/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtils.java @@ -16,12 +16,12 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.MathRuntimeException; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.exception.util.LocalizedFormats; -import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.exception.NullArgumentException; +import org.apache.commons.math.exception.NoBracketingException; +import org.apache.commons.math.exception.NumberIsTooLargeException; +import org.apache.commons.math.exception.NotStrictlyPositiveException; import org.apache.commons.math.util.FastMath; /** @@ -30,56 +30,53 @@ import org.apache.commons.math.util.FastMath; * @version $Revision$ $Date$ */ public class UnivariateRealSolverUtils { - /** - * Default constructor. + * Class contains only static methods. */ - private UnivariateRealSolverUtils() { - super(); - } + private UnivariateRealSolverUtils() {} /** * Convenience method to find a zero of a univariate real function. A default * solver is used. * - * @param f the function. - * @param x0 the lower bound for the interval. - * @param x1 the upper bound for the interval. + * @param function Function. + * @param x0 Lower bound for the interval. + * @param x1 Upper bound for the interval. * @return a value where the function is zero. - * @throws ConvergenceException if the iteration count was exceeded - * @throws MathUserException if an error occurs evaluating the function * @throws IllegalArgumentException if f is null or the endpoints do not - * specify a valid interval + * specify a valid interval. */ - public static double solve(UnivariateRealFunction f, double x0, double x1) - throws ConvergenceException, MathUserException { - setup(f); - return LazyHolder.FACTORY.newDefaultSolver().solve(f, x0, x1); + public static double solve(UnivariateRealFunction function, double x0, double x1) { + if (function == null) { + throw new NullArgumentException(LocalizedFormats.FUNCTION); + } + final UnivariateRealSolver solver = new BrentSolver(); + solver.setMaxEvaluations(Integer.MAX_VALUE); + return solver.solve(function, x0, x1); } /** * Convenience method to find a zero of a univariate real function. A default * solver is used. * - * @param f the function - * @param x0 the lower bound for the interval - * @param x1 the upper bound for the interval - * @param absoluteAccuracy the accuracy to be used by the solver - * @return a value where the function is zero - * @throws ConvergenceException if the iteration count is exceeded - * @throws MathUserException if an error occurs evaluating the function - * @throws IllegalArgumentException if f is null, the endpoints do not - * specify a valid interval, or the absoluteAccuracy is not valid for the - * default solver + * @param function Function. + * @param x0 Lower bound for the interval. + * @param x1 Upper bound for the interval. + * @param absoluteAccuracy Accuracy to be used by the solver. + * @return a value where the function is zero. + * @throws IllegalArgumentException if {@code function} is {@code null}, + * the endpoints do not specify a valid interval, or the absolute accuracy + * is not valid for the default solver. */ - public static double solve(UnivariateRealFunction f, double x0, double x1, - double absoluteAccuracy) throws ConvergenceException, - MathUserException { - - setup(f); - UnivariateRealSolver solver = LazyHolder.FACTORY.newDefaultSolver(); - solver.setAbsoluteAccuracy(absoluteAccuracy); - return solver.solve(f, x0, x1); + public static double solve(UnivariateRealFunction function, + double x0, double x1, + double absoluteAccuracy) { + if (function == null) { + throw new NullArgumentException(LocalizedFormats.FUNCTION); + } + final UnivariateRealSolver solver = new BrentSolver(absoluteAccuracy); + solver.setMaxEvaluations(Integer.MAX_VALUE); + return solver.solve(function, x0, x1); } /** @@ -110,23 +107,21 @@ public class UnivariateRealSolverUtils { * {@link #bracket(UnivariateRealFunction, double, double, double, int)}, * explicitly specifying the maximum number of iterations.

* - * @param function the function - * @param initial initial midpoint of interval being expanded to - * bracket a root - * @param lowerBound lower bound (a is never lower than this value) - * @param upperBound upper bound (b never is greater than this - * value) - * @return a two element array holding {a, b} - * @throws ConvergenceException if a root can not be bracketted - * @throws MathUserException if an error occurs evaluating the function + * @param function Function. + * @param initial Initial midpoint of interval being expanded to + * bracket a root. + * @param lowerBound Lower bound (a is never lower than this value) + * @param upperBound Upper bound (b never is greater than this + * value). + * @return a two-element array holding a and b. + * @throws NoBracketingException if a root cannot be bracketted. * @throws IllegalArgumentException if function is null, maximumIterations - * is not positive, or initial is not between lowerBound and upperBound + * is not positive, or initial is not between lowerBound and upperBound. */ public static double[] bracket(UnivariateRealFunction function, - double initial, double lowerBound, double upperBound) - throws ConvergenceException, MathUserException { - return bracket( function, initial, lowerBound, upperBound, - Integer.MAX_VALUE ) ; + double initial, + double lowerBound, double upperBound) { + return bracket(function, initial, lowerBound, upperBound, Integer.MAX_VALUE); } /** @@ -148,42 +143,36 @@ public class UnivariateRealSolverUtils { *
  • maximumIterations iterations elapse * -- ConvergenceException
  • * - * @param function the function - * @param initial initial midpoint of interval being expanded to - * bracket a root - * @param lowerBound lower bound (a is never lower than this value) - * @param upperBound upper bound (b never is greater than this - * value) - * @param maximumIterations maximum number of iterations to perform - * @return a two element array holding {a, b}. - * @throws ConvergenceException if the algorithm fails to find a and b - * satisfying the desired conditions - * @throws MathUserException if an error occurs evaluating the function + * @param function Function. + * @param initial Initial midpoint of interval being expanded to + * bracket a root. + * @param lowerBound Lower bound (a is never lower than this value). + * @param upperBound Upper bound (b never is greater than this + * value). + * @param maximumIterations Maximum number of iterations to perform + * @return a two element array holding a and b. + * @throws NoBracketingException if the algorithm fails to find a and b + * satisfying the desired conditions. * @throws IllegalArgumentException if function is null, maximumIterations - * is not positive, or initial is not between lowerBound and upperBound + * is not positive, or initial is not between lowerBound and upperBound. */ public static double[] bracket(UnivariateRealFunction function, - double initial, double lowerBound, double upperBound, - int maximumIterations) throws ConvergenceException, - MathUserException { - + double initial, + double lowerBound, double upperBound, + int maximumIterations) { if (function == null) { throw new NullArgumentException(LocalizedFormats.FUNCTION); } if (maximumIterations <= 0) { - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations); - } - if (initial < lowerBound || initial > upperBound || lowerBound >= upperBound) { - throw MathRuntimeException.createIllegalArgumentException( - LocalizedFormats.INVALID_BRACKETING_PARAMETERS, - lowerBound, initial, upperBound); + throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations); } + verifySequence(lowerBound, initial, upperBound); + double a = initial; double b = initial; double fa; double fb; - int numIterations = 0 ; + int numIterations = 0; do { a = FastMath.max(a - 1.0, lowerBound); @@ -191,18 +180,18 @@ public class UnivariateRealSolverUtils { fa = function.value(a); fb = function.value(b); - numIterations++ ; + ++numIterations; } while ((fa * fb > 0.0) && (numIterations < maximumIterations) && ((a > lowerBound) || (b < upperBound))); - if (fa * fb > 0.0 ) { - throw new ConvergenceException( - LocalizedFormats.FAILED_BRACKETING, - numIterations, maximumIterations, initial, - lowerBound, upperBound, a, b, fa, fb); + if (fa * fb > 0.0) { + throw new NoBracketingException(LocalizedFormats.FAILED_BRACKETING, + a, b, fa, fb, + numIterations, maximumIterations, initial, + lowerBound, upperBound); } - return new double[]{a, b}; + return new double[] {a, b}; } /** @@ -213,28 +202,95 @@ public class UnivariateRealSolverUtils { * @return the midpoint. */ public static double midpoint(double a, double b) { - return (a + b) * .5; + return (a + b) * 0.5; } /** - * Checks to see if f is null, throwing IllegalArgumentException if so. - * @param f input function - * @throws IllegalArgumentException if f is null + * Check whether the function takes opposite signs at the endpoints. + * + * @param function Function. + * @param lower Lower endpoint. + * @param upper Upper endpoint. + * @return {@code true} if the function values have opposite signs at the + * given points. */ - private static void setup(UnivariateRealFunction f) { - if (f == null) { + public static boolean isBracketing(UnivariateRealFunction function, + final double lower, + final double upper) { + if (function == null) { throw new NullArgumentException(LocalizedFormats.FUNCTION); } + final double fLo = function.value(lower); + final double fHi = function.value(upper); + return (fLo > 0 && fHi < 0) || (fLo < 0 && fHi > 0); + } + + /** + * Check whether the arguments form a (strictly) increasing sequence. + * + * @param start First number. + * @param mid Second number. + * @param end Third number. + * @return {@code true} if the arguments form an increasing sequence. + */ + public static boolean isSequence(final double start, + final double mid, + final double end) { + return (start < mid) && (mid < end); + } + + /** + * Check that the endpoints specify an interval. + * + * @param lower Lower endpoint. + * @param upper Upper endpoint. + * @throws NumberIsTooLargeException if {@code lower >= upper}. + */ + public static void verifyInterval(final double lower, + final double upper) { + if (lower >= upper) { + throw new NumberIsTooLargeException(LocalizedFormats.ENDPOINTS_NOT_AN_INTERVAL, + lower, upper, false); + } } - // CHECKSTYLE: stop HideUtilityClassConstructor - /** Holder for the factory. - *

    We use here the Initialization On Demand Holder Idiom.

    + /** + * Check that {@code lower < initial < upper}. + * + * @param lower Lower endpoint. + * @param initial Initial value. + * @param upper Upper endpoint. + * @throws NumberIsTooLargeException if {@code lower >= initial} or + * {@code initial >= upper}. */ - private static class LazyHolder { - /** Cached solver factory */ - private static final UnivariateRealSolverFactory FACTORY = UnivariateRealSolverFactory.newInstance(); + public static void verifySequence(final double lower, + final double initial, + final double upper) { + verifyInterval(lower, initial); + verifyInterval(initial, upper); } - // CHECKSTYLE: resume HideUtilityClassConstructor + /** + * Check that the endpoints specify an interval and the function takes + * opposite signs at the endpoints. + * + * @param function Function. + * @param lower Lower endpoint. + * @param upper Upper endpoint. + * @throws NoBracketingException if function has the same sign at the + * endpoints. + */ + public static void verifyBracketing(UnivariateRealFunction function, + final double lower, + final double upper) { + if (function == null) { + throw new NullArgumentException(LocalizedFormats.FUNCTION); + } + verifyInterval(lower, upper); + if (!isBracketing(function, lower, upper)) { + throw new NoBracketingException(lower, upper, + function.value(lower), + function.value(upper)); + } + } } diff --git a/src/main/java/org/apache/commons/math/distribution/AbstractContinuousDistribution.java b/src/main/java/org/apache/commons/math/distribution/AbstractContinuousDistribution.java index 2f418736c..c38ded084 100644 --- a/src/main/java/org/apache/commons/math/distribution/AbstractContinuousDistribution.java +++ b/src/main/java/org/apache/commons/math/distribution/AbstractContinuousDistribution.java @@ -18,7 +18,6 @@ package org.apache.commons.math.distribution; import java.io.Serializable; -import org.apache.commons.math.ConvergenceException; import org.apache.commons.math.MathException; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.analysis.solvers.BrentSolver; @@ -27,6 +26,7 @@ import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.exception.NotStrictlyPositiveException; import org.apache.commons.math.exception.OutOfRangeException; import org.apache.commons.math.exception.util.LocalizedFormats; +import org.apache.commons.math.exception.NumberIsTooLargeException; import org.apache.commons.math.random.RandomDataImpl; import org.apache.commons.math.util.FastMath; @@ -106,7 +106,7 @@ public abstract class AbstractContinuousDistribution bracket = UnivariateRealSolverUtils.bracket( rootFindingFunction, getInitialDomain(p), lowerBound, upperBound); - } catch (ConvergenceException ex) { + } catch (NumberIsTooLargeException ex) { /* * Check domain endpoints to see if one gives value that is within * the default solver's defaultAbsoluteAccuracy of 0 (will be the diff --git a/src/main/java/org/apache/commons/math/exception/NoBracketingException.java b/src/main/java/org/apache/commons/math/exception/NoBracketingException.java new file mode 100644 index 000000000..f5b0db11c --- /dev/null +++ b/src/main/java/org/apache/commons/math/exception/NoBracketingException.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.exception; + +import org.apache.commons.math.exception.util.Localizable; +import org.apache.commons.math.exception.util.LocalizedFormats; + +/** + * Exception to be thrown when function values have the same sign at both + * ends of an interval. + * + * @since 3.0 + * @version $Revision$ $Date$ + */ +public class NoBracketingException extends MathIllegalArgumentException { + /** Serializable version Id. */ + private static final long serialVersionUID = -3629324471511904459L; + /** Lower end of the interval. */ + private final double lo; + /** Higher end of the interval. */ + private final double hi; + /** Value at lower end of the interval. */ + private final double fLo; + /** Value at higher end of the interval. */ + private final double fHi; + + /** + * Construct the exception. + * + * @param lo Lower end of the interval. + * @param hi Higher end of the interval. + * @param fLo Value at lower end of the interval. + * @param fHi Value at higher end of the interval. + */ + public NoBracketingException(double lo, double hi, + double fLo, double fHi) { + this(null, lo, hi, fLo, fHi); + } + /** + * Construct the exception with a specific context. + * + * @param specific Contextual information on what caused the exception. + * @param lo Lower end of the interval. + * @param hi Higher end of the interval. + * @param fLo Value at lower end of the interval. + * @param fHi Value at higher end of the interval. + */ + public NoBracketingException(Localizable specific, + double lo, double hi, + double fLo, double fHi) { + super(specific, LocalizedFormats.SAME_SIGN_AT_ENDPOINTS, lo, hi, fLo, fHi); + this.lo = lo; + this.hi = hi; + this.fLo = fLo; + this.fHi = fHi; + } + /** + * Construct the exception with a specific context. + * + * @param specific Contextual information on what caused the exception. + * @param lo Lower end of the interval. + * @param hi Higher end of the interval. + * @param fLo Value at lower end of the interval. + * @param fHi Value at higher end of the interval. + * @param args Additional arguments. + */ + public NoBracketingException(Localizable specific, + double lo, double hi, + double fLo, double fHi, + Object ... args) { + super(specific, LocalizedFormats.SAME_SIGN_AT_ENDPOINTS, lo, hi, fLo, fHi, args); + this.lo = lo; + this.hi = hi; + this.fLo = fLo; + this.fHi = fHi; + } + + /** + * Get the lower end of the interval. + * + * @return the lower end. + */ + public double getLo() { + return lo; + } + /** + * Get the higher end of the interval. + * + * @return the higher end. + */ + public double getHi() { + return hi; + } + /** + * Get the value at the lower end of the interval. + * + * @return the value at the lower end. + */ + public double getFLo() { + return fLo; + } + /** + * Get the value at the higher end of the interval. + * + * @return the value at the higher end. + */ + public double getFHi() { + return fHi; + } +} diff --git a/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java b/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java index 8cb9ab1cc..f030ab803 100644 --- a/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java +++ b/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java @@ -98,7 +98,7 @@ public enum LocalizedFormats implements Localizable { EVALUATION("evaluation"), /* keep */ EXPANSION_FACTOR_SMALLER_THAN_ONE("expansion factor smaller than one ({0})"), FACTORIAL_NEGATIVE_PARAMETER("must have n >= 0 for n!, got n = {0}"), - FAILED_BRACKETING("number of iterations={0}, maximum iterations={1}, initial={2}, lower bound={3}, upper bound={4}, final a value={5}, final b value={6}, f(a)={7}, f(b)={8}"), + FAILED_BRACKETING("number of iterations={4}, maximum iterations={5}, initial={6}, lower bound={7}, upper bound={8}, final a value={0}, final b value={1}, f(a)={2}, f(b)={3}"), FAILED_FRACTION_CONVERSION("Unable to convert {0} to fraction after {1} iterations"), FIRST_COLUMNS_NOT_INITIALIZED_YET("first {0} columns are not initialized yet"), FIRST_ELEMENT_NOT_ZERO("first element is not 0: {0}"), @@ -263,6 +263,7 @@ public enum LocalizedFormats implements Localizable { PERCENTILE_IMPLEMENTATION_CANNOT_ACCESS_METHOD("cannot access {0} method in percentile implementation {1}"), PERCENTILE_IMPLEMENTATION_UNSUPPORTED_METHOD("percentile implementation {0} does not support {1}"), PERMUTATION_EXCEEDS_N("permutation size ({0}) exceeds permuation domain ({1})"), /* keep */ + POLYNOMIAL("polynomial"), /* keep */ POLYNOMIAL_INTERPOLANTS_MISMATCH_SEGMENTS("number of polynomial interpolants must match the number of segments ({0} != {1} - 1)"), POPULATION_LIMIT_NOT_POSITIVE("population limit has to be positive"), POSITION_SIZE_MISMATCH_INPUT_ARRAY("position {0} and size {1} don't fit to the size of the input array {2}"), diff --git a/src/main/java/org/apache/commons/math/ode/events/EventState.java b/src/main/java/org/apache/commons/math/ode/events/EventState.java index 5e9d7bd8a..930b5136c 100644 --- a/src/main/java/org/apache/commons/math/ode/events/EventState.java +++ b/src/main/java/org/apache/commons/math/ode/events/EventState.java @@ -252,7 +252,8 @@ public class EventState { } } }; - final BrentSolver solver = new BrentSolver(maxIterationCount, convergence); + final BrentSolver solver = new BrentSolver(convergence); + solver.setMaxEvaluations(maxIterationCount); final double root = (ta <= tb) ? solver.solve(f, ta, tb) : solver.solve(f, tb, ta); if ((FastMath.abs(root - ta) <= convergence) && (FastMath.abs(root - previousEventTime) <= convergence)) { diff --git a/src/main/java/org/apache/commons/math/optimization/LeastSquaresConverter.java b/src/main/java/org/apache/commons/math/optimization/LeastSquaresConverter.java index a3376b72f..923f0eb8f 100644 --- a/src/main/java/org/apache/commons/math/optimization/LeastSquaresConverter.java +++ b/src/main/java/org/apache/commons/math/optimization/LeastSquaresConverter.java @@ -109,8 +109,7 @@ public class LeastSquaresConverter implements MultivariateRealFunction { * the {@link #value(double[])} method is called) */ public LeastSquaresConverter(final MultivariateVectorialFunction function, - final double[] observations, final double[] weights) - throws IllegalArgumentException { + final double[] observations, final double[] weights) { if (observations.length != weights.length) { throw new DimensionMismatchException(observations.length, weights.length); } diff --git a/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java b/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java index 89cf7b424..9d1907eae 100644 --- a/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java +++ b/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java @@ -18,7 +18,6 @@ package org.apache.commons.math.optimization.fitting; import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer; -import org.apache.commons.math.optimization.OptimizationException; import org.apache.commons.math.optimization.fitting.CurveFitter; import org.apache.commons.math.optimization.fitting.WeightedObservedPoint; @@ -86,7 +85,8 @@ public class GaussianFitter { /** * Fits Gaussian function to the observed points. - * It will call {@link CurveFitter#fit()}. + * It will call the base class + * {@link CurveFitter#fit(ParametricRealFunction,double[]) fit} method. * * @return the Gaussian function that best fits the observed points. * @see CurveFitter diff --git a/src/main/java/org/apache/commons/math/optimization/fitting/HarmonicFitter.java b/src/main/java/org/apache/commons/math/optimization/fitting/HarmonicFitter.java index b41afcfe0..009eba4b2 100644 --- a/src/main/java/org/apache/commons/math/optimization/fitting/HarmonicFitter.java +++ b/src/main/java/org/apache/commons/math/optimization/fitting/HarmonicFitter.java @@ -77,6 +77,7 @@ public class HarmonicFitter { * @return harmonic Function that best fits the observed points. * @throws NumberIsTooSmallException if the sample is too short or if * the first guess cannot be computed. + * @throws OptimizationException */ public HarmonicFunction fit() throws OptimizationException { // shall we compute the first guess of the parameters ourselves ? @@ -93,7 +94,7 @@ public class HarmonicFitter { guesser.getGuessedAmplitude(), guesser.getGuessedPulsation(), guesser.getGuessedPhase() - }; + }; } double[] fitted = fitter.fit(new ParametricHarmonicFunction(), parameters); diff --git a/src/main/java/org/apache/commons/math/optimization/fitting/ParametricGaussianFunction.java b/src/main/java/org/apache/commons/math/optimization/fitting/ParametricGaussianFunction.java index 58eb206f4..1dc9bae68 100644 --- a/src/main/java/org/apache/commons/math/optimization/fitting/ParametricGaussianFunction.java +++ b/src/main/java/org/apache/commons/math/optimization/fitting/ParametricGaussianFunction.java @@ -97,11 +97,9 @@ public class ParametricGaussianFunction implements ParametricRealFunction, Seria * respect to {@code c}, and the partial derivative of {@code f(a, b, c, * d)} with respect to {@code d}. * - * @param x {@code x} value to be used as constant in {@code f(a, b, c, d)}. - * @param parameters values of {@code a}, {@code b}, {@code c}, and - * {@code d} for computation of gradient vector of {@code f(a, b, c, d)}. - * @return the gradient vector of {@code f(a, b, c, d)}. + * @param x Value to be used as constant in {@code f(x, a, b, c, d)}. * @param parameters Values of {@code a}, {@code b}, {@code c}, and {@code d}. + * @return the gradient vector of {@code f(a, b, c, d)}. * @throws NullArgumentException if {@code parameters} is {@code null}. * @throws DimensionMismatchException if the size of {@code parameters} is * not 4. diff --git a/src/main/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java b/src/main/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java index bb4ab6dc7..8a50970c1 100644 --- a/src/main/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java +++ b/src/main/java/org/apache/commons/math/optimization/general/AbstractLeastSquaresOptimizer.java @@ -17,7 +17,6 @@ package org.apache.commons.math.optimization.general; -import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.exception.NumberIsTooSmallException; import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction; @@ -98,9 +97,10 @@ public abstract class AbstractLeastSquaresOptimizer * * @throws DimensionMismatchException if the Jacobian dimension does not * match problem dimension. - * @throws MathUserException if users jacobian function throws one + * @throws org.apache.commons.math.exception.MathUserException if the jacobian + * function throws one. */ - protected void updateJacobian() throws MathUserException { + protected void updateJacobian() { ++jacobianEvaluations; weightedResidualJacobian = jF.value(point); if (weightedResidualJacobian.length != rows) { @@ -126,7 +126,7 @@ public abstract class AbstractLeastSquaresOptimizer * @throws org.apache.commons.math.exception.TooManyEvaluationsException * if the maximal number of evaluations is exceeded. */ - protected void updateResidualsAndCost() throws MathUserException { + protected void updateResidualsAndCost() { objective = computeObjectiveValue(point); if (objective.length != rows) { throw new DimensionMismatchException(objective.length, rows); @@ -176,9 +176,10 @@ public abstract class AbstractLeastSquaresOptimizer * @return the covariance matrix. * @throws org.apache.commons.math.exception.SingularMatrixException * if the covariance matrix cannot be computed (singular problem). - * @throws MathUserException if jacobian function throws one + * @throws org.apache.commons.math.exception.MathUserException if the jacobian + * function throws one. */ - public double[][] getCovariances() throws MathUserException { + public double[][] getCovariances() { // set up the jacobian updateJacobian(); @@ -211,9 +212,10 @@ public abstract class AbstractLeastSquaresOptimizer * @throws NumberIsTooSmallException if the number of degrees of freedom is not * positive, i.e. the number of measurements is less or equal to the number of * parameters. - * @throws MathUserException if jacobian function throws one + * @throws org.apache.commons.math.exception.MathUserException if the jacobian + * function throws one. */ - public double[] guessParametersErrors() throws MathUserException { + public double[] guessParametersErrors() { if (rows <= cols) { throw new NumberIsTooSmallException(LocalizedFormats.NO_DEGREES_OF_FREEDOM, rows, cols, false); @@ -231,7 +233,7 @@ public abstract class AbstractLeastSquaresOptimizer @Override public VectorialPointValuePair optimize(final DifferentiableMultivariateVectorialFunction f, final double[] target, final double[] weights, - final double[] startPoint) throws MathUserException { + final double[] startPoint) { // Reset counter. jacobianEvaluations = 0; diff --git a/src/main/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizer.java b/src/main/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizer.java index 854f44713..8fe478cab 100644 --- a/src/main/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizer.java +++ b/src/main/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizer.java @@ -18,7 +18,6 @@ package org.apache.commons.math.optimization.general; import org.apache.commons.math.exception.MathIllegalStateException; -import org.apache.commons.math.exception.ConvergenceException; import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.analysis.solvers.BrentSolver; @@ -40,7 +39,6 @@ import org.apache.commons.math.util.FastMath; * @since 2.0 * */ - public class NonLinearConjugateGradientOptimizer extends AbstractScalarDifferentiableOptimizer { /** Update formula for the beta parameter. */ @@ -86,7 +84,8 @@ public class NonLinearConjugateGradientOptimizer * default {@link BrentSolver Brent solver}. */ public void setLineSearchSolver(final UnivariateRealSolver lineSearchSolver) { - this.solver = lineSearchSolver; + solver = lineSearchSolver; + solver.setMaxEvaluations(getMaxEvaluations()); } /** @@ -116,6 +115,7 @@ public class NonLinearConjugateGradientOptimizer } if (solver == null) { solver = new BrentSolver(); + solver.setMaxEvaluations(getMaxEvaluations()); } point = getStartPoint(); final GoalType goal = getGoalType(); @@ -158,15 +158,15 @@ public class NonLinearConjugateGradientOptimizer // Find the optimal step in the search direction. final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection); - try { - final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep)); + final double uB = findUpperBound(lsf, 0, initialStep); + // XXX Last parameters is set to a value clode to zero in order to + // work around the divergence problem in the "testCircleFitting" + // unit test (see MATH-439). + final double step = solver.solve(lsf, 0, uB, 1e-15); - // Validate new point. - for (int i = 0; i < point.length; ++i) { - point[i] += step * searchDirection[i]; - } - } catch (org.apache.commons.math.ConvergenceException e) { - throw new ConvergenceException(); // XXX ugly workaround. + // Validate new point. + for (int i = 0; i < point.length; ++i) { + point[i] += step * searchDirection[i]; } r = computeObjectiveGradient(point); @@ -242,7 +242,6 @@ public class NonLinearConjugateGradientOptimizer public double[] precondition(double[] variables, double[] r) { return r.clone(); } - } /** Internal class for line search. @@ -267,7 +266,6 @@ public class NonLinearConjugateGradientOptimizer /** {@inheritDoc} */ public double value(double x) throws MathUserException { - // current point in the search direction final double[] shiftedPoint = point.clone(); for (int i = 0; i < shiftedPoint.length; ++i) { @@ -284,9 +282,6 @@ public class NonLinearConjugateGradientOptimizer } return dotProduct; - } - } - } diff --git a/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties b/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties index e2b9d12d9..cf3150984 100644 --- a/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties +++ b/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties @@ -70,7 +70,7 @@ EVALUATION_FAILED = erreur d''\u00e9valuation pour l''argument {0} EVALUATION = \u00e9valuation EXPANSION_FACTOR_SMALLER_THAN_ONE = facteur d''extension inf\u00e9rieur \u00e0 un ({0}) FACTORIAL_NEGATIVE_PARAMETER = n doit \u00eatre positif pour le calcul de n!, or n = {0} -FAILED_BRACKETING = nombre d''it\u00e9rations = {0}, it\u00e9rations maximum = {1}, valeur initiale = {2}, borne inf\u00e9rieure = {3}, borne sup\u00e9rieure = {4}, valeur a finale = {5}, valeur b finale = {6}, f(a) = {7}, f(b) = {8} +FAILED_BRACKETING = nombre d''it\u00e9rations = {4}, it\u00e9rations maximum = {5}, valeur initiale = {6}, borne inf\u00e9rieure = {7}, borne sup\u00e9rieure = {8}, valeur a finale = {0}, valeur b finale = {1}, f(a) = {2}, f(b) = {3} FAILED_FRACTION_CONVERSION = Impossible de convertir {0} en fraction apr\u00e8s {1} it\u00e9rations FIRST_COLUMNS_NOT_INITIALIZED_YET = les {0} premi\u00e8res colonnes ne sont pas encore initialis\u00e9es FIRST_ELEMENT_NOT_ZERO = le premier \u00e9l\u00e9ment n''est pas nul : {0} @@ -235,6 +235,7 @@ OVERFLOW_IN_SUBTRACTION = d\u00e9passement de capacit\u00e9 pour la soustraction PERCENTILE_IMPLEMENTATION_CANNOT_ACCESS_METHOD = acc\u00e8s impossible \u00e0 la m\u00e9thode {0} PERCENTILE_IMPLEMENTATION_UNSUPPORTED_METHOD = l''implantation de pourcentage {0} ne dispose pas de la m\u00e9thode {1} PERMUTATION_EXCEEDS_N = la taille de la permutation ({0}) d\u00e9passe le domaine de la permutation ({1}) +POLYNOMIAL = polyn\u00f4me POLYNOMIAL_INTERPOLANTS_MISMATCH_SEGMENTS = le nombre d''interpolants polyn\u00f4miaux doit correspondre au nombre de segments ({0} != {1} - 1) POPULATION_LIMIT_NOT_POSITIVE = la limite de population doit \u00eatre positive POSITION_SIZE_MISMATCH_INPUT_ARRAY = la position {0} et la taille {1} sont incompatibles avec la taille du tableau d''entr\u00e9e {2} diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/BisectionSolverTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/BisectionSolverTest.java index 7fbfa3d7d..abbf22761 100644 --- a/src/test/java/org/apache/commons/math/analysis/solvers/BisectionSolverTest.java +++ b/src/test/java/org/apache/commons/math/analysis/solvers/BisectionSolverTest.java @@ -16,166 +16,90 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.MathException; import org.apache.commons.math.analysis.QuinticFunction; import org.apache.commons.math.analysis.SinFunction; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.util.FastMath; - -import junit.framework.TestCase; +import org.junit.Assert; +import org.junit.Test; /** * @version $Revision$ $Date$ */ -public final class BisectionSolverTest extends TestCase { - - public void testSinZero() throws MathException { +public final class BisectionSolverTest { + @Test + public void testSinZero() { UnivariateRealFunction f = new SinFunction(); double result; - UnivariateRealSolver solver = new BisectionSolver(); + BisectionSolver solver = new BisectionSolver(); + solver.setMaxEvaluations(50); result = solver.solve(f, 3, 4); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); result = solver.solve(f, 1, 4); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); } - public void testQuinticZero() throws MathException { + @Test + public void testQuinticZero() { UnivariateRealFunction f = new QuinticFunction(); double result; - UnivariateRealSolver solver = new BisectionSolver(); + BisectionSolver solver = new BisectionSolver(); + solver.setMaxEvaluations(50); result = solver.solve(f, -0.2, 0.2); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); result = solver.solve(f, -0.1, 0.3); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); result = solver.solve(f, -0.3, 0.45); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.3, 0.7); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.2, 0.6); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.05, 0.95); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.85, 1.25); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.8, 1.2); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.85, 1.75); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.55, 1.45); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.85, 5); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - assertEquals(result, solver.getResult(), 0); - assertTrue(solver.getIterationCount() > 0); + Assert.assertTrue(solver.getEvaluations() > 0); } + @Test public void testMath369() throws Exception { UnivariateRealFunction f = new SinFunction(); - UnivariateRealSolver solver = new BisectionSolver(); - assertEquals(FastMath.PI, solver.solve(f, 3.0, 3.2, 3.1), solver.getAbsoluteAccuracy()); + BisectionSolver solver = new BisectionSolver(); + solver.setMaxEvaluations(40); + Assert.assertEquals(FastMath.PI, solver.solve(f, 3.0, 3.2, 3.1), solver.getAbsoluteAccuracy()); } /** * */ - public void testSetFunctionValueAccuracy(){ - double expected = 1.0e-2; - UnivariateRealSolver solver = new BisectionSolver(); - solver.setFunctionValueAccuracy(expected); - assertEquals(expected, solver.getFunctionValueAccuracy(), 1.0e-2); - } - - /** - * - */ - public void testResetFunctionValueAccuracy(){ - double newValue = 1.0e-2; - UnivariateRealSolver solver = new BisectionSolver(); - double oldValue = solver.getFunctionValueAccuracy(); - solver.setFunctionValueAccuracy(newValue); - solver.resetFunctionValueAccuracy(); - assertEquals(oldValue, solver.getFunctionValueAccuracy(), 1.0e-2); - } - - /** - * - */ - public void testSetAbsoluteAccuracy(){ - double expected = 1.0e-2; - UnivariateRealSolver solver = new BisectionSolver(); - solver.setAbsoluteAccuracy(expected); - assertEquals(expected, solver.getAbsoluteAccuracy(), 1.0e-2); - } - - /** - * - */ - public void testResetAbsoluteAccuracy(){ - double newValue = 1.0e-2; - UnivariateRealSolver solver = new BisectionSolver(); - double oldValue = solver.getAbsoluteAccuracy(); - solver.setAbsoluteAccuracy(newValue); - solver.resetAbsoluteAccuracy(); - assertEquals(oldValue, solver.getAbsoluteAccuracy(), 1.0e-2); - } - - /** - * - */ - public void testSetMaximalIterationCount(){ + @Test + public void testSetMaximalEvaluationCount(){ int expected = 100; - UnivariateRealSolver solver = new BisectionSolver(); - solver.setMaximalIterationCount(expected); - assertEquals(expected, solver.getMaximalIterationCount()); + BisectionSolver solver = new BisectionSolver(); + solver.setMaxEvaluations(expected); + Assert.assertEquals(expected, solver.getMaxEvaluations()); } - - /** - * - */ - public void testResetMaximalIterationCount(){ - int newValue = 10000; - UnivariateRealSolver solver = new BisectionSolver(); - int oldValue = solver.getMaximalIterationCount(); - solver.setMaximalIterationCount(newValue); - solver.resetMaximalIterationCount(); - assertEquals(oldValue, solver.getMaximalIterationCount()); - } - - /** - * - */ - public void testSetRelativeAccuracy(){ - double expected = 1.0e-2; - UnivariateRealSolver solver = new BisectionSolver(); - solver.setRelativeAccuracy(expected); - assertEquals(expected, solver.getRelativeAccuracy(), 1.0e-2); - } - - /** - * - */ - public void testResetRelativeAccuracy(){ - double newValue = 1.0e-2; - UnivariateRealSolver solver = new BisectionSolver(); - double oldValue = solver.getRelativeAccuracy(); - solver.setRelativeAccuracy(newValue); - solver.resetRelativeAccuracy(); - assertEquals(oldValue, solver.getRelativeAccuracy(), 1.0e-2); - } - - } diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/BrentSolverTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/BrentSolverTest.java index 3654fb40a..74cca8afc 100644 --- a/src/test/java/org/apache/commons/math/analysis/solvers/BrentSolverTest.java +++ b/src/test/java/org/apache/commons/math/analysis/solvers/BrentSolverTest.java @@ -16,17 +16,16 @@ */ package org.apache.commons.math.analysis.solvers; -import junit.framework.TestCase; - -import org.apache.commons.math.MathException; import org.apache.commons.math.analysis.MonitoredFunction; import org.apache.commons.math.analysis.QuinticFunction; import org.apache.commons.math.analysis.SinFunction; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.util.FastMath; +import org.junit.Assert; +import org.junit.Test; /** - * Testcase for UnivariateRealSolver. + * Testcase for {@link BrentSolver Brent} solver. * Because Brent-Dekker is guaranteed to converge in less than the default * maximum iteration count due to bisection fallback, it is quite hard to * debug. I include measured iteration counts plus one in order to detect @@ -36,50 +35,32 @@ import org.apache.commons.math.util.FastMath; * * @version $Revision:670469 $ $Date:2008-06-23 10:01:38 +0200 (lun., 23 juin 2008) $ */ -public final class BrentSolverTest extends TestCase { - - public BrentSolverTest(String name) { - super(name); - } - - public void testSinZero() throws MathException { - // The sinus function is behaved well around the root at #pi. The second +public final class BrentSolverTest { + @Test + public void testSinZero() { + // The sinus function is behaved well around the root at pi. The second // order derivative is zero, which means linar approximating methods will // still converge quadratically. UnivariateRealFunction f = new SinFunction(); double result; UnivariateRealSolver solver = new BrentSolver(); + solver.setMaxEvaluations(10); // Somewhat benign interval. The function is monotone. result = solver.solve(f, 3, 4); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); - // 4 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 5); + // System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 7); // Larger and somewhat less benign interval. The function is grows first. result = solver.solve(f, 1, 4); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); - // 5 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 6); - solver = new SecantSolver(); - result = solver.solve(f, 3, 4); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); - // 4 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 5); - result = solver.solve(f, 1, 4); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); - // 5 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 6); - assertEquals(result, solver.getResult(), 0); + // System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 8); } - public void testQuinticZero() throws MathException { + @Test + public void testQuinticZero() { // The quintic function has zeros at 0, +-0.5 and +-1. // Around the root of 0 the function is well behaved, with a second derivative // of zero a 0. @@ -91,238 +72,144 @@ public final class BrentSolverTest extends TestCase { double result; // Brent-Dekker solver. UnivariateRealSolver solver = new BrentSolver(); + solver.setMaxEvaluations(20); // Symmetric bracket around 0. Test whether solvers can handle hitting // the root in the first iteration. result = solver.solve(f, -0.2, 0.2); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); - assertTrue(solver.getIterationCount() <= 2); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 3); // 1 iterations on i586 JDK 1.4.1. // Asymmetric bracket around 0, just for fun. Contains extremum. result = solver.solve(f, -0.1, 0.3); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); // 5 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 6); + Assert.assertTrue(solver.getEvaluations() <= 7); // Large bracket around 0. Contains two extrema. result = solver.solve(f, -0.3, 0.45); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); // 6 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 7); + Assert.assertTrue(solver.getEvaluations() <= 8); // Benign bracket around 0.5, function is monotonous. result = solver.solve(f, 0.3, 0.7); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); // 6 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 7); + Assert.assertTrue(solver.getEvaluations() <= 9); // Less benign bracket around 0.5, contains one extremum. result = solver.solve(f, 0.2, 0.6); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); - // 6 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 7); + // System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 10); // Large, less benign bracket around 0.5, contains both extrema. result = solver.solve(f, 0.05, 0.95); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); - // 8 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 9); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 11); // Relatively benign bracket around 1, function is monotonous. Fast growth for x>1 // is still a problem. result = solver.solve(f, 0.85, 1.25); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 8 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 9); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 11); // Less benign bracket around 1 with extremum. result = solver.solve(f, 0.8, 1.2); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 8 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 9); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 11); // Large bracket around 1. Monotonous. result = solver.solve(f, 0.85, 1.75); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 10 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 11); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 13); // Large bracket around 1. Interval contains extremum. result = solver.solve(f, 0.55, 1.45); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 7 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 8); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 10); // Very large bracket around 1 for testing fast growth behaviour. result = solver.solve(f, 0.85, 5); //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 12 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 13); - // Secant solver. - solver = new SecantSolver(); - result = solver.solve(f, -0.2, 0.2); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); - // 1 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 2); - result = solver.solve(f, -0.1, 0.3); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); - // 5 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 6); - result = solver.solve(f, -0.3, 0.45); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); - // 6 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 7); - result = solver.solve(f, 0.3, 0.7); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); - // 7 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 8); - result = solver.solve(f, 0.2, 0.6); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); - // 6 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 7); - result = solver.solve(f, 0.05, 0.95); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); - // 8 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 9); - result = solver.solve(f, 0.85, 1.25); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 10 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 11); - result = solver.solve(f, 0.8, 1.2); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 8 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 9); - result = solver.solve(f, 0.85, 1.75); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 14 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 15); - // The followig is especially slow because the solver first has to reduce - // the bracket to exclude the extremum. After that, convergence is rapide. - result = solver.solve(f, 0.55, 1.45); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 7 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 8); - result = solver.solve(f, 0.85, 5); - //System.out.println( - // "Root: " + result + " Iterations: " + solver.getIterationCount()); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - // 14 iterations on i586 JDK 1.4.1. - assertTrue(solver.getIterationCount() <= 15); - // Static solve method - result = UnivariateRealSolverUtils.solve(f, -0.2, 0.2); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); - result = UnivariateRealSolverUtils.solve(f, -0.1, 0.3); - assertEquals(result, 0, 1E-8); - result = UnivariateRealSolverUtils.solve(f, -0.3, 0.45); - assertEquals(result, 0, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.3, 0.7); - assertEquals(result, 0.5, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.2, 0.6); - assertEquals(result, 0.5, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.05, 0.95); - assertEquals(result, 0.5, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.85, 1.25); - assertEquals(result, 1.0, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.8, 1.2); - assertEquals(result, 1.0, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.85, 1.75); - assertEquals(result, 1.0, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.55, 1.45); - assertEquals(result, 1.0, 1E-6); - result = UnivariateRealSolverUtils.solve(f, 0.85, 5); - assertEquals(result, 1.0, 1E-6); + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 15); } - public void testRootEndpoints() throws Exception { + @Test + public void testRootEndpoints() { UnivariateRealFunction f = new SinFunction(); - UnivariateRealSolver solver = new BrentSolver(); + BrentSolver solver = new BrentSolver(); + solver.setMaxEvaluations(10); // endpoint is root double result = solver.solve(f, FastMath.PI, 4); - assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); result = solver.solve(f, 3, FastMath.PI); - assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); result = solver.solve(f, FastMath.PI, 4, 3.5); - assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); result = solver.solve(f, 3, FastMath.PI, 3.07); - assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); } - public void testBadEndpoints() throws Exception { + @Test + public void testBadEndpoints() { UnivariateRealFunction f = new SinFunction(); - UnivariateRealSolver solver = new BrentSolver(); + BrentSolver solver = new BrentSolver(); + solver.setMaxEvaluations(10); try { // bad interval solver.solve(f, 1, -1); - fail("Expecting IllegalArgumentException - bad interval"); + Assert.fail("Expecting IllegalArgumentException - bad interval"); } catch (IllegalArgumentException ex) { // expected } try { // no bracket solver.solve(f, 1, 1.5); - fail("Expecting IllegalArgumentException - non-bracketing"); + Assert.fail("Expecting IllegalArgumentException - non-bracketing"); } catch (IllegalArgumentException ex) { // expected } try { // no bracket solver.solve(f, 1, 1.5, 1.2); - fail("Expecting IllegalArgumentException - non-bracketing"); + Assert.fail("Expecting IllegalArgumentException - non-bracketing"); } catch (IllegalArgumentException ex) { // expected } } - public void testInitialGuess() throws MathException { - + @Test + public void testInitialGuess() { MonitoredFunction f = new MonitoredFunction(new QuinticFunction()); - UnivariateRealSolver solver = new BrentSolver(); + BrentSolver solver = new BrentSolver(); + solver.setMaxEvaluations(20); double result; // no guess result = solver.solve(f, 0.6, 7.0); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); int referenceCallsCount = f.getCallsCount(); - assertTrue(referenceCallsCount >= 13); + Assert.assertTrue(referenceCallsCount >= 13); // invalid guess (it *is* a root, but outside of the range) try { result = solver.solve(f, 0.6, 7.0, 0.0); - fail("an IllegalArgumentException was expected"); + Assert.fail("an IllegalArgumentException was expected"); } catch (IllegalArgumentException iae) { // expected behaviour } @@ -330,22 +217,20 @@ public final class BrentSolverTest extends TestCase { // bad guess f.setCallsCount(0); result = solver.solve(f, 0.6, 7.0, 0.61); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - assertTrue(f.getCallsCount() > referenceCallsCount); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(f.getCallsCount() > referenceCallsCount); // good guess f.setCallsCount(0); result = solver.solve(f, 0.6, 7.0, 0.999999); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - assertTrue(f.getCallsCount() < referenceCallsCount); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(f.getCallsCount() < referenceCallsCount); // perfect guess f.setCallsCount(0); result = solver.solve(f, 0.6, 7.0, 1.0); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); - assertEquals(0, solver.getIterationCount()); - assertEquals(1, f.getCallsCount()); - + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(1, solver.getEvaluations()); + Assert.assertEquals(1, f.getCallsCount()); } - } diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/LaguerreSolverTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/LaguerreSolverTest.java index 9e9aa8c6f..b6ef3f76a 100644 --- a/src/test/java/org/apache/commons/math/analysis/solvers/LaguerreSolverTest.java +++ b/src/test/java/org/apache/commons/math/analysis/solvers/LaguerreSolverTest.java @@ -22,8 +22,9 @@ import org.apache.commons.math.analysis.SinFunction; import org.apache.commons.math.analysis.polynomials.PolynomialFunction; import org.apache.commons.math.complex.Complex; import org.apache.commons.math.util.FastMath; - -import junit.framework.TestCase; +import org.junit.Assert; +import org.junit.Test; +import org.junit.Ignore; /** * Testcase for Laguerre solver. @@ -35,144 +36,145 @@ import junit.framework.TestCase; * * @version $Revision$ $Date$ */ -public final class LaguerreSolverTest extends TestCase { - +public final class LaguerreSolverTest { /** * Test of solver for the linear function. */ - public void testLinearFunction() throws MathException { + @Test + public void testLinearFunction() { double min, max, expected, result, tolerance; // p(x) = 4x - 1 double coefficients[] = { -1.0, 4.0 }; PolynomialFunction f = new PolynomialFunction(coefficients); - UnivariateRealSolver solver = new LaguerreSolver(); + LaguerreSolver solver = new LaguerreSolver(); + solver.setMaxEvaluations(10); min = 0.0; max = 1.0; expected = 0.25; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of solver for the quadratic function. */ - public void testQuadraticFunction() throws MathException { + @Test + public void testQuadraticFunction() { double min, max, expected, result, tolerance; // p(x) = 2x^2 + 5x - 3 = (x+3)(2x-1) double coefficients[] = { -3.0, 5.0, 2.0 }; PolynomialFunction f = new PolynomialFunction(coefficients); - UnivariateRealSolver solver = new LaguerreSolver(); + LaguerreSolver solver = new LaguerreSolver(); + solver.setMaxEvaluations(10); min = 0.0; max = 2.0; expected = 0.5; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -4.0; max = -1.0; expected = -3.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of solver for the quintic function. */ - public void testQuinticFunction() throws MathException { + @Test + public void testQuinticFunction() { double min, max, expected, result, tolerance; // p(x) = x^5 - x^4 - 12x^3 + x^2 - x - 12 = (x+1)(x+3)(x-4)(x^2-x+1) double coefficients[] = { -12.0, -1.0, 1.0, -12.0, -1.0, 1.0 }; PolynomialFunction f = new PolynomialFunction(coefficients); - UnivariateRealSolver solver = new LaguerreSolver(); + LaguerreSolver solver = new LaguerreSolver(); + solver.setMaxEvaluations(10); min = -2.0; max = 2.0; expected = -1.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -5.0; max = -2.5; expected = -3.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = 3.0; max = 6.0; expected = 4.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of solver for the quintic function using solveAll(). + * XXX commented out because "solveAll" is not part of the API. */ - public void testQuinticFunction2() throws MathException { - double initial = 0.0, tolerance; - Complex expected, result[]; + // public void testQuinticFunction2() { + // double initial = 0.0, tolerance; + // Complex expected, result[]; - // p(x) = x^5 + 4x^3 + x^2 + 4 = (x+1)(x^2-x+1)(x^2+4) - double coefficients[] = { 4.0, 0.0, 1.0, 4.0, 0.0, 1.0 }; - LaguerreSolver solver = new LaguerreSolver(); - result = solver.solveAll(coefficients, initial); + // // p(x) = x^5 + 4x^3 + x^2 + 4 = (x+1)(x^2-x+1)(x^2+4) + // double coefficients[] = { 4.0, 0.0, 1.0, 4.0, 0.0, 1.0 }; + // LaguerreSolver solver = new LaguerreSolver(); + // result = solver.solveAll(coefficients, initial); - expected = new Complex(0.0, -2.0); - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); - TestUtils.assertContains(result, expected, tolerance); + // expected = new Complex(0.0, -2.0); + // tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + // FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); + // TestUtils.assertContains(result, expected, tolerance); - expected = new Complex(0.0, 2.0); - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); - TestUtils.assertContains(result, expected, tolerance); + // expected = new Complex(0.0, 2.0); + // tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + // FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); + // TestUtils.assertContains(result, expected, tolerance); - expected = new Complex(0.5, 0.5 * FastMath.sqrt(3.0)); - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); - TestUtils.assertContains(result, expected, tolerance); + // expected = new Complex(0.5, 0.5 * FastMath.sqrt(3.0)); + // tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + // FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); + // TestUtils.assertContains(result, expected, tolerance); - expected = new Complex(-1.0, 0.0); - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); - TestUtils.assertContains(result, expected, tolerance); + // expected = new Complex(-1.0, 0.0); + // tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + // FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); + // TestUtils.assertContains(result, expected, tolerance); - expected = new Complex(0.5, -0.5 * FastMath.sqrt(3.0)); - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); - TestUtils.assertContains(result, expected, tolerance); - } + // expected = new Complex(0.5, -0.5 * FastMath.sqrt(3.0)); + // tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + // FastMath.abs(expected.abs() * solver.getRelativeAccuracy())); + // TestUtils.assertContains(result, expected, tolerance); + // } /** * Test of parameters for the solver. */ - public void testParameters() throws Exception { + @Test + public void testParameters() { double coefficients[] = { -3.0, 5.0, 2.0 }; PolynomialFunction f = new PolynomialFunction(coefficients); - UnivariateRealSolver solver = new LaguerreSolver(); + LaguerreSolver solver = new LaguerreSolver(); + solver.setMaxEvaluations(10); try { // bad interval solver.solve(f, 1, -1); - fail("Expecting IllegalArgumentException - bad interval"); + Assert.fail("Expecting IllegalArgumentException - bad interval"); } catch (IllegalArgumentException ex) { // expected } try { // no bracketing solver.solve(f, 2, 3); - fail("Expecting IllegalArgumentException - no bracketing"); - } catch (IllegalArgumentException ex) { - // expected - } - try { - // bad function - solver.solve(new SinFunction(), -1, 1); - fail("Expecting IllegalArgumentException - bad function"); + Assert.fail("Expecting IllegalArgumentException - no bracketing"); } catch (IllegalArgumentException ex) { // expected } diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/MullerSolver2Test.java b/src/test/java/org/apache/commons/math/analysis/solvers/MullerSolver2Test.java new file mode 100644 index 000000000..56cb297e3 --- /dev/null +++ b/src/test/java/org/apache/commons/math/analysis/solvers/MullerSolver2Test.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.Expm1Function; +import org.apache.commons.math.analysis.QuinticFunction; +import org.apache.commons.math.analysis.SinFunction; +import org.apache.commons.math.analysis.UnivariateRealFunction; +import org.apache.commons.math.exception.NumberIsTooLargeException; +import org.apache.commons.math.exception.NoBracketingException; +import org.apache.commons.math.util.FastMath; +import org.junit.Assert; +import org.junit.Test; + +/** + * Testcase for {@link MullerSolver2 Muller} solver. + *

    + * Muller's method converges almost quadratically near roots, but it can + * be very slow in regions far away from zeros. Test runs show that for + * reasonably good initial values, for a default absolute accuracy of 1E-6, + * it generally takes 5 to 10 iterations for the solver to converge. + *

    + * Tests for the exponential function illustrate the situations where + * Muller solver performs poorly. + * + * @version $Revision: 1034896 $ $Date: 2010-11-13 23:27:34 +0100 (Sat, 13 Nov 2010) $ + */ +public final class MullerSolver2Test { + /** + * Test of solver for the sine function. + */ + @Test + public void testSinFunction() { + UnivariateRealFunction f = new SinFunction(); + UnivariateRealSolver solver = new MullerSolver2(); + solver.setMaxEvaluations(10); + double min, max, expected, result, tolerance; + + min = 3.0; max = 4.0; expected = FastMath.PI; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + + min = -1.0; max = 1.5; expected = 0.0; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + } + + /** + * Test of solver for the quintic function. + */ + @Test + public void testQuinticFunction() { + UnivariateRealFunction f = new QuinticFunction(); + UnivariateRealSolver solver = new MullerSolver2(); + solver.setMaxEvaluations(10); + double min, max, expected, result, tolerance; + + min = -0.4; max = 0.2; expected = 0.0; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + + min = 0.75; max = 1.5; expected = 1.0; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + + min = -0.9; max = -0.2; expected = -0.5; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + } + + /** + * Test of solver for the exponential function. + *

    + * It takes 25 to 50 iterations for the last two tests to converge. + */ + @Test + public void testExpm1Function() { + UnivariateRealFunction f = new Expm1Function(); + UnivariateRealSolver solver = new MullerSolver2(); + solver.setMaxEvaluations(55); + double min, max, expected, result, tolerance; + + min = -1.0; max = 2.0; expected = 0.0; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + + min = -20.0; max = 10.0; expected = 0.0; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + + min = -50.0; max = 100.0; expected = 0.0; + tolerance = FastMath.max(solver.getAbsoluteAccuracy(), + FastMath.abs(expected * solver.getRelativeAccuracy())); + result = solver.solve(f, min, max); + Assert.assertEquals(expected, result, tolerance); + } + + /** + * Test of parameters for the solver. + */ + @Test + public void testParameters() throws Exception { + UnivariateRealFunction f = new SinFunction(); + UnivariateRealSolver solver = new MullerSolver2(); + solver.setMaxEvaluations(10); + + try { + // bad interval + solver.solve(f, 1, -1); + Assert.fail("Expecting IllegalArgumentException - bad interval"); + } catch (NumberIsTooLargeException ex) { + // expected + } + try { + // no bracketing + solver.solve(f, 2, 3); + Assert.fail("Expecting IllegalArgumentException - no bracketing"); + } catch (NoBracketingException ex) { + // expected + } + } +} diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/MullerSolverTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/MullerSolverTest.java index 40ae2014d..46e14e7ef 100644 --- a/src/test/java/org/apache/commons/math/analysis/solvers/MullerSolverTest.java +++ b/src/test/java/org/apache/commons/math/analysis/solvers/MullerSolverTest.java @@ -16,17 +16,18 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.MathException; import org.apache.commons.math.analysis.Expm1Function; import org.apache.commons.math.analysis.QuinticFunction; import org.apache.commons.math.analysis.SinFunction; import org.apache.commons.math.analysis.UnivariateRealFunction; +import org.apache.commons.math.exception.NumberIsTooLargeException; +import org.apache.commons.math.exception.NoBracketingException; import org.apache.commons.math.util.FastMath; - -import junit.framework.TestCase; +import org.junit.Assert; +import org.junit.Test; /** - * Testcase for Muller solver. + * Testcase for {@link MullerSolver Muller} solver. *

    * Muller's method converges almost quadratically near roots, but it can * be very slow in regions far away from zeros. Test runs show that for @@ -38,102 +39,57 @@ import junit.framework.TestCase; * * @version $Revision$ $Date$ */ -public final class MullerSolverTest extends TestCase { - +public final class MullerSolverTest { /** * Test of solver for the sine function. */ - public void testSinFunction() throws MathException { + @Test + public void testSinFunction() { UnivariateRealFunction f = new SinFunction(); UnivariateRealSolver solver = new MullerSolver(); + solver.setMaxEvaluations(10); double min, max, expected, result, tolerance; min = 3.0; max = 4.0; expected = FastMath.PI; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -1.0; max = 1.5; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); - } - - /** - * Test of solver for the sine function using solve2(). - */ - public void testSinFunction2() throws MathException { - UnivariateRealFunction f = new SinFunction(); - MullerSolver solver = new MullerSolver(); - double min, max, expected, result, tolerance; - - min = 3.0; max = 4.0; expected = FastMath.PI; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); - - min = -1.0; max = 1.5; expected = 0.0; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of solver for the quintic function. */ - public void testQuinticFunction() throws MathException { + @Test + public void testQuinticFunction() { UnivariateRealFunction f = new QuinticFunction(); UnivariateRealSolver solver = new MullerSolver(); + solver.setMaxEvaluations(15); double min, max, expected, result, tolerance; min = -0.4; max = 0.2; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = 0.75; max = 1.5; expected = 1.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -0.9; max = -0.2; expected = -0.5; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); - } - - /** - * Test of solver for the quintic function using solve2(). - */ - public void testQuinticFunction2() throws MathException { - UnivariateRealFunction f = new QuinticFunction(); - MullerSolver solver = new MullerSolver(); - double min, max, expected, result, tolerance; - - min = -0.4; max = 0.2; expected = 0.0; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); - - min = 0.75; max = 1.5; expected = 1.0; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); - - min = -0.9; max = -0.2; expected = -0.5; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** @@ -143,78 +99,54 @@ public final class MullerSolverTest extends TestCase { * In fact, if not for the bisection alternative, the solver would * exceed the default maximal iteration of 100. */ - public void testExpm1Function() throws MathException { + @Test + public void testExpm1Function() { UnivariateRealFunction f = new Expm1Function(); UnivariateRealSolver solver = new MullerSolver(); + solver.setMaxEvaluations(25); double min, max, expected, result, tolerance; min = -1.0; max = 2.0; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -20.0; max = 10.0; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -50.0; max = 100.0; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); - } - - /** - * Test of solver for the exponential function using solve2(). - *

    - * It takes 25 to 50 iterations for the last two tests to converge. - */ - public void testExpm1Function2() throws MathException { - UnivariateRealFunction f = new Expm1Function(); - MullerSolver solver = new MullerSolver(); - double min, max, expected, result, tolerance; - - min = -1.0; max = 2.0; expected = 0.0; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); - - min = -20.0; max = 10.0; expected = 0.0; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); - - min = -50.0; max = 100.0; expected = 0.0; - tolerance = FastMath.max(solver.getAbsoluteAccuracy(), - FastMath.abs(expected * solver.getRelativeAccuracy())); - result = solver.solve2(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of parameters for the solver. */ + @Test public void testParameters() throws Exception { UnivariateRealFunction f = new SinFunction(); UnivariateRealSolver solver = new MullerSolver(); + solver.setMaxEvaluations(10); try { // bad interval - solver.solve(f, 1, -1); - fail("Expecting IllegalArgumentException - bad interval"); - } catch (IllegalArgumentException ex) { + double root = solver.solve(f, 1, -1); + System.out.println("root=" + root); + Assert.fail("Expecting IllegalArgumentException - bad interval"); + } catch (NumberIsTooLargeException ex) { // expected } try { // no bracketing solver.solve(f, 2, 3); - fail("Expecting IllegalArgumentException - no bracketing"); - } catch (IllegalArgumentException ex) { + Assert.fail("Expecting IllegalArgumentException - no bracketing"); + } catch (NoBracketingException ex) { // expected } } diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/NewtonSolverTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/NewtonSolverTest.java index 20aa15481..f59bf127f 100644 --- a/src/test/java/org/apache/commons/math/analysis/solvers/NewtonSolverTest.java +++ b/src/test/java/org/apache/commons/math/analysis/solvers/NewtonSolverTest.java @@ -16,78 +16,78 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.MathException; import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction; import org.apache.commons.math.analysis.QuinticFunction; import org.apache.commons.math.analysis.SinFunction; import org.apache.commons.math.util.FastMath; +import org.junit.Assert; +import org.junit.Test; -import junit.framework.TestCase; - /** * @version $Revision$ $Date$ */ -public final class NewtonSolverTest extends TestCase { - +public final class NewtonSolverTest { /** * */ - public void testSinZero() throws MathException { + @Test + public void testSinZero() { DifferentiableUnivariateRealFunction f = new SinFunction(); double result; - UnivariateRealSolver solver = new NewtonSolver(); + NewtonSolver solver = new NewtonSolver(); + solver.setMaxEvaluations(10); result = solver.solve(f, 3, 4); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); result = solver.solve(f, 1, 4); - assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); - assertEquals(result, solver.getResult(), 0); - assertTrue(solver.getIterationCount() > 0); + Assert.assertTrue(solver.getEvaluations() > 0); } /** * */ - public void testQuinticZero() throws MathException { + @Test + public void testQuinticZero() { DifferentiableUnivariateRealFunction f = new QuinticFunction(); double result; - UnivariateRealSolver solver = new NewtonSolver(); + NewtonSolver solver = new NewtonSolver(); + solver.setMaxEvaluations(30); result = solver.solve(f, -0.2, 0.2); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); result = solver.solve(f, -0.1, 0.3); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); result = solver.solve(f, -0.3, 0.45); - assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.3, 0.7); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.2, 0.6); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.05, 0.95); - assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.85, 1.25); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.8, 1.2); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.85, 1.75); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.55, 1.45); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); result = solver.solve(f, 0.85, 5); - assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); } - } diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/RiddersSolverTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/RiddersSolverTest.java index 8c591ec53..8f4a5adc5 100644 --- a/src/test/java/org/apache/commons/math/analysis/solvers/RiddersSolverTest.java +++ b/src/test/java/org/apache/commons/math/analysis/solvers/RiddersSolverTest.java @@ -16,17 +16,18 @@ */ package org.apache.commons.math.analysis.solvers; -import org.apache.commons.math.MathException; import org.apache.commons.math.analysis.Expm1Function; import org.apache.commons.math.analysis.QuinticFunction; import org.apache.commons.math.analysis.SinFunction; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.util.FastMath; - -import junit.framework.TestCase; +import org.apache.commons.math.exception.NoBracketingException; +import org.apache.commons.math.exception.NumberIsTooLargeException; +import org.junit.Assert; +import org.junit.Test; /** - * Testcase for Ridders solver. + * Testcase for {@link RiddersSolver Ridders} solver. *

    * Ridders' method converges superlinearly, more specific, its rate of * convergence is sqrt(2). Test runs show that for a default absolute @@ -36,102 +37,109 @@ import junit.framework.TestCase; * * @version $Revision$ $Date$ */ -public final class RiddersSolverTest extends TestCase { - +public final class RiddersSolverTest { /** * Test of solver for the sine function. */ - public void testSinFunction() throws MathException { + @Test + public void testSinFunction() { UnivariateRealFunction f = new SinFunction(); UnivariateRealSolver solver = new RiddersSolver(); + solver.setMaxEvaluations(10); double min, max, expected, result, tolerance; min = 3.0; max = 4.0; expected = FastMath.PI; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -1.0; max = 1.5; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of solver for the quintic function. */ - public void testQuinticFunction() throws MathException { + @Test + public void testQuinticFunction() { UnivariateRealFunction f = new QuinticFunction(); UnivariateRealSolver solver = new RiddersSolver(); + solver.setMaxEvaluations(15); double min, max, expected, result, tolerance; min = -0.4; max = 0.2; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = 0.75; max = 1.5; expected = 1.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -0.9; max = -0.2; expected = -0.5; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of solver for the exponential function. */ - public void testExpm1Function() throws MathException { + @Test + public void testExpm1Function() { UnivariateRealFunction f = new Expm1Function(); UnivariateRealSolver solver = new RiddersSolver(); + solver.setMaxEvaluations(20); double min, max, expected, result, tolerance; min = -1.0; max = 2.0; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -20.0; max = 10.0; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); min = -50.0; max = 100.0; expected = 0.0; tolerance = FastMath.max(solver.getAbsoluteAccuracy(), FastMath.abs(expected * solver.getRelativeAccuracy())); result = solver.solve(f, min, max); - assertEquals(expected, result, tolerance); + Assert.assertEquals(expected, result, tolerance); } /** * Test of parameters for the solver. */ - public void testParameters() throws Exception { + @Test + public void testParameters() { UnivariateRealFunction f = new SinFunction(); UnivariateRealSolver solver = new RiddersSolver(); + solver.setMaxEvaluations(10); try { // bad interval solver.solve(f, 1, -1); - fail("Expecting IllegalArgumentException - bad interval"); - } catch (IllegalArgumentException ex) { + Assert.fail("Expecting IllegalArgumentException - bad interval"); + } catch (NumberIsTooLargeException ex) { // expected } try { // no bracketing solver.solve(f, 2, 3); - fail("Expecting IllegalArgumentException - no bracketing"); - } catch (IllegalArgumentException ex) { + Assert.fail("Expecting IllegalArgumentException - no bracketing"); + } catch (NoBracketingException ex) { // expected } } diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/SecantSolverTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/SecantSolverTest.java new file mode 100644 index 000000000..12b71e7af --- /dev/null +++ b/src/test/java/org/apache/commons/math/analysis/solvers/SecantSolverTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.analysis.solvers; + +import org.apache.commons.math.analysis.MonitoredFunction; +import org.apache.commons.math.analysis.QuinticFunction; +import org.apache.commons.math.analysis.SinFunction; +import org.apache.commons.math.analysis.UnivariateRealFunction; +import org.apache.commons.math.util.FastMath; +import org.junit.Assert; +import org.junit.Test; + +/** + * Testcase for {@link SecantSolver}. + * + * @version $Revision:670469 $ $Date:2008-06-23 10:01:38 +0200 (lun., 23 juin 2008) $ + */ +public final class SecantSolverTest { + @Test + public void testSinZero() { + // The sinus function is behaved well around the root at pi. The second + // order derivative is zero, which means linar approximating methods will + // still converge quadratically. + UnivariateRealFunction f = new SinFunction(); + double result; + UnivariateRealSolver solver = new SecantSolver(); + solver.setMaxEvaluations(10); + + result = solver.solve(f, 3, 4); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 6); + result = solver.solve(f, 1, 4); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 7); + } + + @Test + public void testQuinticZero() { + // The quintic function has zeros at 0, +-0.5 and +-1. + // Around the root of 0 the function is well behaved, with a second derivative + // of zero a 0. + // The other roots are less well to find, in particular the root at 1, because + // the function grows fast for x>1. + // The function has extrema (first derivative is zero) at 0.27195613 and 0.82221643, + // intervals containing these values are harder for the solvers. + UnivariateRealFunction f = new QuinticFunction(); + double result; + // Brent-Dekker solver. + UnivariateRealSolver solver = new SecantSolver(); + solver.setMaxEvaluations(20); + result = solver.solve(f, -0.2, 0.2); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 3); + result = solver.solve(f, -0.1, 0.3); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 7); + result = solver.solve(f, -0.3, 0.45); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 8); + result = solver.solve(f, 0.3, 0.7); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 9); + result = solver.solve(f, 0.2, 0.6); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 8); + result = solver.solve(f, 0.05, 0.95); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 0.5, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 10); + result = solver.solve(f, 0.85, 1.25); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 12); + result = solver.solve(f, 0.8, 1.2); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 10); + result = solver.solve(f, 0.85, 1.75); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 16); + // The followig is especially slow because the solver first has to reduce + // the bracket to exclude the extremum. After that, convergence is rapide. + result = solver.solve(f, 0.55, 1.45); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 9); + result = solver.solve(f, 0.85, 5); + //System.out.println( + // "Root: " + result + " Evaluations: " + solver.getEvaluations()); + Assert.assertEquals(result, 1.0, solver.getAbsoluteAccuracy()); + Assert.assertTrue(solver.getEvaluations() <= 16); + } + + @Test + public void testRootEndpoints() { + UnivariateRealFunction f = new SinFunction(); + SecantSolver solver = new SecantSolver(); + solver.setMaxEvaluations(10); + + // endpoint is root + double result = solver.solve(f, FastMath.PI, 4); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + + result = solver.solve(f, 3, FastMath.PI); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + + result = solver.solve(f, FastMath.PI, 4, 3.5); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + + result = solver.solve(f, 3, FastMath.PI, 3.07); + Assert.assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy()); + + } + + @Test + public void testBadEndpoints() { + UnivariateRealFunction f = new SinFunction(); + SecantSolver solver = new SecantSolver(); + solver.setMaxEvaluations(10); + try { // bad interval + solver.solve(f, 1, -1); + Assert.fail("Expecting IllegalArgumentException - bad interval"); + } catch (IllegalArgumentException ex) { + // expected + } + try { // no bracket + solver.solve(f, 1, 1.5); + Assert.fail("Expecting IllegalArgumentException - non-bracketing"); + } catch (IllegalArgumentException ex) { + // expected + } + try { // no bracket + solver.solve(f, 1, 1.5, 1.2); + Assert.fail("Expecting IllegalArgumentException - non-bracketing"); + } catch (IllegalArgumentException ex) { + // expected + } + } +} diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactoryImplTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactoryImplTest.java deleted file mode 100644 index 25d1e1139..000000000 --- a/src/test/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverFactoryImplTest.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.commons.math.analysis.solvers; - -import junit.framework.TestCase; - -/** - * @version $Revision$ $Date$ - */ -public class UnivariateRealSolverFactoryImplTest extends TestCase { - - /** solver factory */ - private UnivariateRealSolverFactory factory; - - /** - * @throws java.lang.Exception - * @see junit.framework.TestCase#tearDown() - */ - @Override - protected void setUp() throws Exception { - super.setUp(); - factory = new UnivariateRealSolverFactoryImpl(); - } - - /** - * @throws java.lang.Exception - * @see junit.framework.TestCase#tearDown() - */ - @Override - protected void tearDown() throws Exception { - factory = null; - super.tearDown(); - } - - public void testNewBisectionSolverValid() { - UnivariateRealSolver solver = factory.newBisectionSolver(); - assertNotNull(solver); - assertTrue(solver instanceof BisectionSolver); - } - - public void testNewNewtonSolverValid() { - UnivariateRealSolver solver = factory.newNewtonSolver(); - assertNotNull(solver); - assertTrue(solver instanceof NewtonSolver); - } - - public void testNewBrentSolverValid() { - UnivariateRealSolver solver = factory.newBrentSolver(); - assertNotNull(solver); - assertTrue(solver instanceof BrentSolver); - } - - public void testNewSecantSolverValid() { - UnivariateRealSolver solver = factory.newSecantSolver(); - assertNotNull(solver); - assertTrue(solver instanceof SecantSolver); - } - -} diff --git a/src/test/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtilsTest.java b/src/test/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtilsTest.java index bfa016859..ce6002d7d 100644 --- a/src/test/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtilsTest.java +++ b/src/test/java/org/apache/commons/math/analysis/solvers/UnivariateRealSolverUtilsTest.java @@ -17,125 +17,168 @@ package org.apache.commons.math.analysis.solvers; -import junit.framework.TestCase; - import org.apache.commons.math.MathException; import org.apache.commons.math.analysis.SinFunction; +import org.apache.commons.math.analysis.QuinticFunction; import org.apache.commons.math.analysis.UnivariateRealFunction; import org.apache.commons.math.util.FastMath; +import org.junit.Assert; +import org.junit.Test; /** * @version $Revision$ $Date$ */ -public class UnivariateRealSolverUtilsTest extends TestCase { +public class UnivariateRealSolverUtilsTest { protected UnivariateRealFunction sin = new SinFunction(); - public void testSolveNull() throws MathException { + @Test + public void testSolveNull() { try { UnivariateRealSolverUtils.solve(null, 0.0, 4.0); - fail(); + Assert.fail(); } catch(IllegalArgumentException ex){ // success } } - public void testSolveBadEndpoints() throws MathException { + @Test + public void testSolveBadEndpoints() { try { // bad endpoints - UnivariateRealSolverUtils.solve(sin, -0.1, 4.0, 4.0); - fail("Expecting IllegalArgumentException"); + double root = UnivariateRealSolverUtils.solve(sin, 4.0, -0.1, 1e-6); + System.out.println("root=" + root); + Assert.fail("Expecting IllegalArgumentException"); } catch (IllegalArgumentException ex) { // expected } } - public void testSolveBadAccuracy() throws MathException { + @Test + public void testSolveBadAccuracy() { try { // bad accuracy UnivariateRealSolverUtils.solve(sin, 0.0, 4.0, 0.0); -// fail("Expecting IllegalArgumentException"); // TODO needs rework since convergence behaviour was changed +// Assert.fail("Expecting IllegalArgumentException"); // TODO needs rework since convergence behaviour was changed } catch (IllegalArgumentException ex) { // expected } } - public void testSolveSin() throws MathException { + @Test + public void testSolveSin() { double x = UnivariateRealSolverUtils.solve(sin, 1.0, 4.0); - assertEquals(FastMath.PI, x, 1.0e-4); + Assert.assertEquals(FastMath.PI, x, 1.0e-4); } - public void testSolveAccuracyNull() throws MathException { + @Test + public void testSolveAccuracyNull() { try { double accuracy = 1.0e-6; UnivariateRealSolverUtils.solve(null, 0.0, 4.0, accuracy); - fail(); + Assert.fail(); } catch(IllegalArgumentException ex){ // success } } - public void testSolveAccuracySin() throws MathException { + @Test + public void testSolveAccuracySin() { double accuracy = 1.0e-6; double x = UnivariateRealSolverUtils.solve(sin, 1.0, 4.0, accuracy); - assertEquals(FastMath.PI, x, accuracy); + Assert.assertEquals(FastMath.PI, x, accuracy); } - public void testSolveNoRoot() throws MathException { + @Test + public void testSolveNoRoot() { try { UnivariateRealSolverUtils.solve(sin, 1.0, 1.5); - fail("Expecting IllegalArgumentException "); + Assert.fail("Expecting IllegalArgumentException "); } catch (IllegalArgumentException ex) { // expected } } - public void testBracketSin() throws MathException { + @Test + public void testBracketSin() { double[] result = UnivariateRealSolverUtils.bracket(sin, 0.0, -2.0, 2.0); - assertTrue(sin.value(result[0]) < 0); - assertTrue(sin.value(result[1]) > 0); + Assert.assertTrue(sin.value(result[0]) < 0); + Assert.assertTrue(sin.value(result[1]) > 0); } - public void testBracketEndpointRoot() throws MathException { + @Test + public void testBracketEndpointRoot() { double[] result = UnivariateRealSolverUtils.bracket(sin, 1.5, 0, 2.0); - assertEquals(0.0, sin.value(result[0]), 1.0e-15); - assertTrue(sin.value(result[1]) > 0); + Assert.assertEquals(0.0, sin.value(result[0]), 1.0e-15); + Assert.assertTrue(sin.value(result[1]) > 0); } - public void testNullFunction() throws MathException { + @Test + public void testNullFunction() { try { // null function UnivariateRealSolverUtils.bracket(null, 1.5, 0, 2.0); - fail("Expecting IllegalArgumentException"); + Assert.fail("Expecting IllegalArgumentException"); } catch (IllegalArgumentException ex) { // expected } } - public void testBadInitial() throws MathException { + @Test + public void testBadInitial() { try { // initial not between endpoints UnivariateRealSolverUtils.bracket(sin, 2.5, 0, 2.0); - fail("Expecting IllegalArgumentException"); + Assert.fail("Expecting IllegalArgumentException"); } catch (IllegalArgumentException ex) { // expected } } - public void testBadEndpoints() throws MathException { + @Test + public void testBadEndpoints() { try { // endpoints not valid UnivariateRealSolverUtils.bracket(sin, 1.5, 2.0, 1.0); - fail("Expecting IllegalArgumentException"); + Assert.fail("Expecting IllegalArgumentException"); } catch (IllegalArgumentException ex) { // expected } } - public void testBadMaximumIterations() throws MathException { + @Test + public void testBadMaximumIterations() { try { // bad maximum iterations UnivariateRealSolverUtils.bracket(sin, 1.5, 0, 2.0, 0); - fail("Expecting IllegalArgumentException"); + Assert.fail("Expecting IllegalArgumentException"); } catch (IllegalArgumentException ex) { // expected } } + @Test + public void testMisc() { + UnivariateRealFunction f = new QuinticFunction(); + double result; + // Static solve method + result = UnivariateRealSolverUtils.solve(f, -0.2, 0.2); + Assert.assertEquals(result, 0, 1E-8); + result = UnivariateRealSolverUtils.solve(f, -0.1, 0.3); + Assert.assertEquals(result, 0, 1E-8); + result = UnivariateRealSolverUtils.solve(f, -0.3, 0.45); + Assert.assertEquals(result, 0, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.3, 0.7); + Assert.assertEquals(result, 0.5, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.2, 0.6); + Assert.assertEquals(result, 0.5, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.05, 0.95); + Assert.assertEquals(result, 0.5, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.85, 1.25); + Assert.assertEquals(result, 1.0, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.8, 1.2); + Assert.assertEquals(result, 1.0, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.85, 1.75); + Assert.assertEquals(result, 1.0, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.55, 1.45); + Assert.assertEquals(result, 1.0, 1E-6); + result = UnivariateRealSolverUtils.solve(f, 0.85, 5); + Assert.assertEquals(result, 1.0, 1E-6); + } } diff --git a/src/test/java/org/apache/commons/math/ode/nonstiff/DormandPrince853IntegratorTest.java b/src/test/java/org/apache/commons/math/ode/nonstiff/DormandPrince853IntegratorTest.java index d7a52dc48..d558adc54 100644 --- a/src/test/java/org/apache/commons/math/ode/nonstiff/DormandPrince853IntegratorTest.java +++ b/src/test/java/org/apache/commons/math/ode/nonstiff/DormandPrince853IntegratorTest.java @@ -235,7 +235,7 @@ public class DormandPrince853IntegratorTest pb.getInitialTime(), pb.getInitialState(), pb.getFinalTime(), new double[pb.getDimension()]); - assertTrue(handler.getMaximalValueError() < 5.0e-8); + assertEquals(0, handler.getMaximalValueError(), 1.1e-7); assertEquals(0, handler.getMaximalTimeError(), 1.0e-12); assertEquals(12.0, handler.getLastTime(), 1.0e-8 * maxStep); integ.clearEventHandlers(); diff --git a/src/test/java/org/apache/commons/math/ode/nonstiff/HighamHall54IntegratorTest.java b/src/test/java/org/apache/commons/math/ode/nonstiff/HighamHall54IntegratorTest.java index 454f23c43..b51bce89d 100644 --- a/src/test/java/org/apache/commons/math/ode/nonstiff/HighamHall54IntegratorTest.java +++ b/src/test/java/org/apache/commons/math/ode/nonstiff/HighamHall54IntegratorTest.java @@ -20,8 +20,9 @@ package org.apache.commons.math.ode.nonstiff; import junit.framework.TestCase; import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.exception.util.LocalizedFormats; +import org.apache.commons.math.exception.TooManyEvaluationsException; +import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.ode.FirstOrderDifferentialEquations; import org.apache.commons.math.ode.FirstOrderIntegrator; import org.apache.commons.math.ode.IntegratorException; @@ -268,9 +269,8 @@ public class HighamHall54IntegratorTest pb.getInitialTime(), pb.getInitialState(), pb.getFinalTime(), new double[pb.getDimension()]); fail("an exception should have been thrown"); - } catch (IntegratorException ie) { - assertTrue(ie.getCause() != null); - assertTrue(ie.getCause() instanceof ConvergenceException); + } catch (TooManyEvaluationsException tmee) { + // Expected. } } diff --git a/src/test/java/org/apache/commons/math/optimization/general/CircleScalar.java b/src/test/java/org/apache/commons/math/optimization/general/CircleScalar.java new file mode 100644 index 000000000..deb0fa456 --- /dev/null +++ b/src/test/java/org/apache/commons/math/optimization/general/CircleScalar.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.general; + +import java.awt.geom.Point2D; +import java.util.ArrayList; +import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction; +import org.apache.commons.math.analysis.MultivariateRealFunction; +import org.apache.commons.math.analysis.MultivariateVectorialFunction; + +/** + * Class used in the tests. + */ +class CircleScalar implements DifferentiableMultivariateRealFunction { + private ArrayList points; + + public CircleScalar() { + points = new ArrayList(); + } + + public void addPoint(double px, double py) { + points.add(new Point2D.Double(px, py)); + } + + public double getRadius(Point2D.Double center) { + double r = 0; + for (Point2D.Double point : points) { + r += point.distance(center); + } + return r / points.size(); + } + + private double[] gradient(double[] point) { + // optimal radius + Point2D.Double center = new Point2D.Double(point[0], point[1]); + double radius = getRadius(center); + + // gradient of the sum of squared residuals + double dJdX = 0; + double dJdY = 0; + for (Point2D.Double pk : points) { + double dk = pk.distance(center); + dJdX += (center.x - pk.x) * (dk - radius) / dk; + dJdY += (center.y - pk.y) * (dk - radius) / dk; + } + dJdX *= 2; + dJdY *= 2; + + return new double[] { dJdX, dJdY }; + } + + public double value(double[] variables) { + Point2D.Double center = new Point2D.Double(variables[0], variables[1]); + double radius = getRadius(center); + + double sum = 0; + for (Point2D.Double point : points) { + double di = point.distance(center) - radius; + sum += di * di; + } + + return sum; + } + + public MultivariateVectorialFunction gradient() { + return new MultivariateVectorialFunction() { + private static final long serialVersionUID = 3174909643301201710L; + public double[] value(double[] point) { + return gradient(point); + } + }; + } + + public MultivariateRealFunction partialDerivative(final int k) { + return new MultivariateRealFunction() { + private static final long serialVersionUID = 3073956364104833888L; + public double value(double[] point) { + return gradient(point)[k]; + } + }; + } +} diff --git a/src/test/java/org/apache/commons/math/optimization/general/CircleVectorial.java b/src/test/java/org/apache/commons/math/optimization/general/CircleVectorial.java new file mode 100644 index 000000000..992ff2c23 --- /dev/null +++ b/src/test/java/org/apache/commons/math/optimization/general/CircleVectorial.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.optimization.general; + +import java.awt.geom.Point2D; +import java.util.ArrayList; +import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction; +import org.apache.commons.math.analysis.MultivariateMatrixFunction; + +/** + * Class used in the tests. + */ +class CircleVectorial implements DifferentiableMultivariateVectorialFunction { + private ArrayList points; + + public CircleVectorial() { + points = new ArrayList(); + } + + public void addPoint(double px, double py) { + points.add(new Point2D.Double(px, py)); + } + + public int getN() { + return points.size(); + } + + public double getRadius(Point2D.Double center) { + double r = 0; + for (Point2D.Double point : points) { + r += point.distance(center); + } + return r / points.size(); + } + + private double[][] jacobian(double[] point) { + int n = points.size(); + Point2D.Double center = new Point2D.Double(point[0], point[1]); + + // gradient of the optimal radius + double dRdX = 0; + double dRdY = 0; + for (Point2D.Double pk : points) { + double dk = pk.distance(center); + dRdX += (center.x - pk.x) / dk; + dRdY += (center.y - pk.y) / dk; + } + dRdX /= n; + dRdY /= n; + + // jacobian of the radius residuals + double[][] jacobian = new double[n][2]; + for (int i = 0; i < n; ++i) { + Point2D.Double pi = points.get(i); + double di = pi.distance(center); + jacobian[i][0] = (center.x - pi.x) / di - dRdX; + jacobian[i][1] = (center.y - pi.y) / di - dRdY; + } + + return jacobian; + } + + public double[] value(double[] variables) { + Point2D.Double center = new Point2D.Double(variables[0], variables[1]); + double radius = getRadius(center); + + double[] residuals = new double[points.size()]; + for (int i = 0; i < residuals.length; ++i) { + residuals[i] = points.get(i).distance(center) - radius; + } + + return residuals; + } + + public MultivariateMatrixFunction jacobian() { + return new MultivariateMatrixFunction() { + private static final long serialVersionUID = -4340046230875165095L; + public double[][] value(double[] point) { + return jacobian(point); + } + }; + } +} diff --git a/src/test/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java index a511ef4e0..fcd9eb6fe 100644 --- a/src/test/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java +++ b/src/test/java/org/apache/commons/math/optimization/general/GaussNewtonOptimizerTest.java @@ -368,7 +368,7 @@ extends TestCase { } public void testMaxEvaluations() throws Exception { - Circle circle = new Circle(); + CircleVectorial circle = new CircleVectorial(); circle.addPoint( 30.0, 68.0); circle.addPoint( 50.0, -6.0); circle.addPoint(110.0, -20.0); @@ -388,7 +388,7 @@ extends TestCase { } public void testCircleFitting() throws MathUserException { - Circle circle = new Circle(); + CircleVectorial circle = new CircleVectorial(); circle.addPoint( 30.0, 68.0); circle.addPoint( 50.0, -6.0); circle.addPoint(110.0, -20.0); @@ -409,7 +409,7 @@ extends TestCase { } public void testCircleFittingBadInit() throws MathUserException { - Circle circle = new Circle(); + CircleVectorial circle = new CircleVectorial(); double[][] points = new double[][] { {-0.312967, 0.072366}, {-0.339248, 0.132965}, {-0.379780, 0.202724}, {-0.390426, 0.260487}, {-0.361212, 0.328325}, {-0.346039, 0.392619}, @@ -488,86 +488,5 @@ extends TestCase { } }; } - } - - private static class Circle implements DifferentiableMultivariateVectorialFunction, Serializable { - - private static final long serialVersionUID = -7165774454925027042L; - private ArrayList points; - - public Circle() { - points = new ArrayList(); - } - - public void addPoint(double px, double py) { - points.add(new Point2D.Double(px, py)); - } - - public int getN() { - return points.size(); - } - - public double getRadius(Point2D.Double center) { - double r = 0; - for (Point2D.Double point : points) { - r += point.distance(center); - } - return r / points.size(); - } - - private double[][] jacobian(double[] variables) { - - int n = points.size(); - Point2D.Double center = new Point2D.Double(variables[0], variables[1]); - - // gradient of the optimal radius - double dRdX = 0; - double dRdY = 0; - for (Point2D.Double pk : points) { - double dk = pk.distance(center); - dRdX += (center.x - pk.x) / dk; - dRdY += (center.y - pk.y) / dk; - } - dRdX /= n; - dRdY /= n; - - // jacobian of the radius residuals - double[][] jacobian = new double[n][2]; - for (int i = 0; i < n; ++i) { - Point2D.Double pi = points.get(i); - double di = pi.distance(center); - jacobian[i][0] = (center.x - pi.x) / di - dRdX; - jacobian[i][1] = (center.y - pi.y) / di - dRdY; - } - - return jacobian; - - } - - public double[] value(double[] variables) { - - Point2D.Double center = new Point2D.Double(variables[0], variables[1]); - double radius = getRadius(center); - - double[] residuals = new double[points.size()]; - for (int i = 0; i < residuals.length; ++i) { - residuals[i] = points.get(i).distance(center) - radius; - } - - return residuals; - - } - - public MultivariateMatrixFunction jacobian() { - return new MultivariateMatrixFunction() { - private static final long serialVersionUID = -4340046230875165095L; - public double[][] value(double[] point) { - return jacobian(point); - } - }; - } - - } - } diff --git a/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java index 0e4d4b8e5..8490a85f4 100644 --- a/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java +++ b/src/test/java/org/apache/commons/math/optimization/general/LevenbergMarquardtOptimizerTest.java @@ -25,20 +25,19 @@ import java.util.List; import junit.framework.TestCase; -import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction; -import org.apache.commons.math.analysis.MultivariateMatrixFunction; +import org.apache.commons.math.exception.SingularMatrixException; import org.apache.commons.math.exception.ConvergenceException; import org.apache.commons.math.exception.DimensionMismatchException; -import org.apache.commons.math.exception.MathUserException; -import org.apache.commons.math.exception.NumberIsTooSmallException; -import org.apache.commons.math.exception.SingularMatrixException; import org.apache.commons.math.exception.TooManyEvaluationsException; +import org.apache.commons.math.exception.NumberIsTooSmallException; +import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction; +import org.apache.commons.math.analysis.MultivariateMatrixFunction; import org.apache.commons.math.linear.BlockRealMatrix; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.optimization.SimpleVectorialValueChecker; import org.apache.commons.math.optimization.VectorialPointValuePair; -import org.apache.commons.math.util.FastMath; import org.apache.commons.math.util.MathUtils; +import org.apache.commons.math.util.FastMath; /** *

    Some of the unit tests are re-implementations of the MINPACK 0.1); } - public void testInconsistentSizes() throws MathUserException { + public void testInconsistentSizes() { LinearProblem problem = new LinearProblem(new double[][] { { 1, 0 }, { 0, 1 } }, new double[] { -1, 1 }); LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(); @@ -346,8 +345,8 @@ public class LevenbergMarquardtOptimizerTest } } - public void testControlParameters() throws MathUserException { - Circle circle = new Circle(); + public void testControlParameters() { + CircleVectorial circle = new CircleVectorial(); circle.addPoint( 30.0, 68.0); circle.addPoint( 50.0, -6.0); circle.addPoint(110.0, -20.0); @@ -363,7 +362,7 @@ public class LevenbergMarquardtOptimizerTest private void checkEstimate(DifferentiableMultivariateVectorialFunction problem, double initialStepBoundFactor, int maxCostEval, double costRelativeTolerance, double parRelativeTolerance, - double orthoTolerance, boolean shouldFail) throws MathUserException { + double orthoTolerance, boolean shouldFail) { try { LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(initialStepBoundFactor, @@ -382,8 +381,8 @@ public class LevenbergMarquardtOptimizerTest } } - public void testCircleFitting() throws MathUserException { - Circle circle = new Circle(); + public void testCircleFitting() { + CircleVectorial circle = new CircleVectorial(); circle.addPoint( 30.0, 68.0); circle.addPoint( 50.0, -6.0); circle.addPoint(110.0, -20.0); @@ -430,8 +429,8 @@ public class LevenbergMarquardtOptimizerTest assertEquals(0.004, errors[1], 0.001); } - public void testCircleFittingBadInit() throws MathUserException { - Circle circle = new Circle(); + public void testCircleFittingBadInit() { + CircleVectorial circle = new CircleVectorial(); double[][] points = new double[][] { {-0.312967, 0.072366}, {-0.339248, 0.132965}, {-0.379780, 0.202724}, {-0.390426, 0.260487}, {-0.361212, 0.328325}, {-0.346039, 0.392619}, @@ -483,7 +482,7 @@ public class LevenbergMarquardtOptimizerTest assertEquals( 0.2075001, center.y, 1.0e-6); } - public void testMath199() throws MathUserException { + public void testMath199() { try { QuadraticProblem problem = new QuadraticProblem(); problem.addPoint (0, -3.182591015485607); @@ -527,83 +526,6 @@ public class LevenbergMarquardtOptimizerTest } } - private static class Circle implements DifferentiableMultivariateVectorialFunction, Serializable { - - private static final long serialVersionUID = -4711170319243817874L; - - private ArrayList points; - - public Circle() { - points = new ArrayList(); - } - - public void addPoint(double px, double py) { - points.add(new Point2D.Double(px, py)); - } - - public int getN() { - return points.size(); - } - - public double getRadius(Point2D.Double center) { - double r = 0; - for (Point2D.Double point : points) { - r += point.distance(center); - } - return r / points.size(); - } - - private double[][] jacobian(double[] point) { - - int n = points.size(); - Point2D.Double center = new Point2D.Double(point[0], point[1]); - - // gradient of the optimal radius - double dRdX = 0; - double dRdY = 0; - for (Point2D.Double pk : points) { - double dk = pk.distance(center); - dRdX += (center.x - pk.x) / dk; - dRdY += (center.y - pk.y) / dk; - } - dRdX /= n; - dRdY /= n; - - // jacobian of the radius residuals - double[][] jacobian = new double[n][2]; - for (int i = 0; i < n; ++i) { - Point2D.Double pi = points.get(i); - double di = pi.distance(center); - jacobian[i][0] = (center.x - pi.x) / di - dRdX; - jacobian[i][1] = (center.y - pi.y) / di - dRdY; - } - - return jacobian; - } - - public double[] value(double[] variables) { - - Point2D.Double center = new Point2D.Double(variables[0], variables[1]); - double radius = getRadius(center); - - double[] residuals = new double[points.size()]; - for (int i = 0; i < residuals.length; ++i) { - residuals[i] = points.get(i).distance(center) - radius; - } - - return residuals; - } - - public MultivariateMatrixFunction jacobian() { - return new MultivariateMatrixFunction() { - private static final long serialVersionUID = -4340046230875165095L; - public double[][] value(double[] point) { - return jacobian(point); - } - }; - } - } - private static class QuadraticProblem implements DifferentiableMultivariateVectorialFunction, Serializable { private static final long serialVersionUID = 7072187082052755854L; diff --git a/src/test/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizerTest.java b/src/test/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizerTest.java index c13659313..b8b84bb6e 100644 --- a/src/test/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizerTest.java +++ b/src/test/java/org/apache/commons/math/optimization/general/NonLinearConjugateGradientOptimizerTest.java @@ -21,18 +21,18 @@ import java.awt.geom.Point2D; import java.io.Serializable; import java.util.ArrayList; -import junit.framework.TestCase; - import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction; import org.apache.commons.math.analysis.MultivariateRealFunction; import org.apache.commons.math.analysis.MultivariateVectorialFunction; +import org.apache.commons.math.analysis.solvers.UnivariateRealSolver; import org.apache.commons.math.analysis.solvers.BrentSolver; -import org.apache.commons.math.exception.MathUserException; import org.apache.commons.math.linear.BlockRealMatrix; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.optimization.GoalType; import org.apache.commons.math.optimization.RealPointValuePair; import org.apache.commons.math.optimization.SimpleScalarValueChecker; +import org.junit.Assert; +import org.junit.Test; /** *

    Some of the unit tests are re-implementations of the MINPACK 0.5); + Assert.assertTrue(optimum.getValue() > 0.5); } - public void testIllConditioned() throws MathUserException { + @Test + public void testIllConditioned() { LinearProblem problem1 = new LinearProblem(new double[][] { { 10.0, 7.0, 8.0, 7.0 }, { 7.0, 5.0, 6.0, 5.0 }, @@ -239,16 +237,14 @@ extends TestCase { new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE); optimizer.setMaxEvaluations(100); optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-13, 1.0e-13)); - BrentSolver solver = new BrentSolver(); - solver.setAbsoluteAccuracy(1.0e-15); - solver.setRelativeAccuracy(1.0e-15); + BrentSolver solver = new BrentSolver(1e-15, 1e-15); optimizer.setLineSearchSolver(solver); RealPointValuePair optimum1 = optimizer.optimize(problem1, GoalType.MINIMIZE, new double[] { 0, 1, 2, 3 }); - assertEquals(1.0, optimum1.getPoint()[0], 1.0e-5); - assertEquals(1.0, optimum1.getPoint()[1], 1.0e-5); - assertEquals(1.0, optimum1.getPoint()[2], 1.0e-5); - assertEquals(1.0, optimum1.getPoint()[3], 1.0e-5); + Assert.assertEquals(1.0, optimum1.getPoint()[0], 1.0e-4); + Assert.assertEquals(1.0, optimum1.getPoint()[1], 1.0e-4); + Assert.assertEquals(1.0, optimum1.getPoint()[2], 1.0e-4); + Assert.assertEquals(1.0, optimum1.getPoint()[3], 1.0e-4); LinearProblem problem2 = new LinearProblem(new double[][] { { 10.00, 7.00, 8.10, 7.20 }, @@ -258,16 +254,15 @@ extends TestCase { }, new double[] { 32, 23, 33, 31 }); RealPointValuePair optimum2 = optimizer.optimize(problem2, GoalType.MINIMIZE, new double[] { 0, 1, 2, 3 }); - assertEquals(-81.0, optimum2.getPoint()[0], 1.0e-1); - assertEquals(137.0, optimum2.getPoint()[1], 1.0e-1); - assertEquals(-34.0, optimum2.getPoint()[2], 1.0e-1); - assertEquals( 22.0, optimum2.getPoint()[3], 1.0e-1); + Assert.assertEquals(-81.0, optimum2.getPoint()[0], 1.0e-1); + Assert.assertEquals(137.0, optimum2.getPoint()[1], 1.0e-1); + Assert.assertEquals(-34.0, optimum2.getPoint()[2], 1.0e-1); + Assert.assertEquals( 22.0, optimum2.getPoint()[3], 1.0e-1); } - public void testMoreEstimatedParametersSimple() - throws MathUserException { - + @Test + public void testMoreEstimatedParametersSimple() { LinearProblem problem = new LinearProblem(new double[][] { { 3.0, 2.0, 0.0, 0.0 }, { 0.0, 1.0, -1.0, 1.0 }, @@ -280,12 +275,12 @@ extends TestCase { optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6)); RealPointValuePair optimum = optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 7, 6, 5, 4 }); - assertEquals(0, optimum.getValue(), 1.0e-10); + Assert.assertEquals(0, optimum.getValue(), 1.0e-10); } - public void testMoreEstimatedParametersUnsorted() - throws MathUserException { + @Test + public void testMoreEstimatedParametersUnsorted() { LinearProblem problem = new LinearProblem(new double[][] { { 1.0, 1.0, 0.0, 0.0, 0.0, 0.0 }, { 0.0, 0.0, 1.0, 1.0, 1.0, 0.0 }, @@ -299,10 +294,11 @@ extends TestCase { optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6)); RealPointValuePair optimum = optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 2, 2, 2, 2, 2, 2 }); - assertEquals(0, optimum.getValue(), 1.0e-10); + Assert.assertEquals(0, optimum.getValue(), 1.0e-10); } - public void testRedundantEquations() throws MathUserException { + @Test + public void testRedundantEquations() { LinearProblem problem = new LinearProblem(new double[][] { { 1.0, 1.0 }, { 1.0, -1.0 }, @@ -315,12 +311,13 @@ extends TestCase { optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6)); RealPointValuePair optimum = optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 1, 1 }); - assertEquals(2.0, optimum.getPoint()[0], 1.0e-8); - assertEquals(1.0, optimum.getPoint()[1], 1.0e-8); + Assert.assertEquals(2.0, optimum.getPoint()[0], 1.0e-8); + Assert.assertEquals(1.0, optimum.getPoint()[1], 1.0e-8); } - public void testInconsistentEquations() throws MathUserException { + @Test + public void testInconsistentEquations() { LinearProblem problem = new LinearProblem(new double[][] { { 1.0, 1.0 }, { 1.0, -1.0 }, @@ -333,12 +330,13 @@ extends TestCase { optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-6, 1.0e-6)); RealPointValuePair optimum = optimizer.optimize(problem, GoalType.MINIMIZE, new double[] { 1, 1 }); - assertTrue(optimum.getValue() > 0.1); + Assert.assertTrue(optimum.getValue() > 0.1); } - public void testCircleFitting() throws MathUserException { - Circle circle = new Circle(); + @Test + public void testCircleFitting() { + CircleScalar circle = new CircleScalar(); circle.addPoint( 30.0, 68.0); circle.addPoint( 50.0, -6.0); circle.addPoint(110.0, -20.0); @@ -348,16 +346,14 @@ extends TestCase { new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE); optimizer.setMaxEvaluations(100); optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-30, 1.0e-30)); - BrentSolver solver = new BrentSolver(); - solver.setAbsoluteAccuracy(1.0e-13); - solver.setRelativeAccuracy(1.0e-15); + UnivariateRealSolver solver = new BrentSolver(1e-15, 1e-13); optimizer.setLineSearchSolver(solver); RealPointValuePair optimum = optimizer.optimize(circle, GoalType.MINIMIZE, new double[] { 98.680, 47.345 }); Point2D.Double center = new Point2D.Double(optimum.getPointRef()[0], optimum.getPointRef()[1]); - assertEquals(69.960161753, circle.getRadius(center), 1.0e-8); - assertEquals(96.075902096, center.x, 1.0e-8); - assertEquals(48.135167894, center.y, 1.0e-8); + Assert.assertEquals(69.960161753, circle.getRadius(center), 1.0e-8); + Assert.assertEquals(96.075902096, center.x, 1.0e-8); + Assert.assertEquals(48.135167894, center.y, 1.0e-8); } private static class LinearProblem implements DifferentiableMultivariateRealFunction, Serializable { @@ -382,7 +378,7 @@ extends TestCase { return p; } - public double value(double[] variables) throws MathUserException { + public double value(double[] variables) { double[] y = factors.operate(variables); double sum = 0; for (int i = 0; i < y.length; ++i) { @@ -409,86 +405,5 @@ extends TestCase { } }; } - } - - private static class Circle implements DifferentiableMultivariateRealFunction, Serializable { - - private static final long serialVersionUID = -4711170319243817874L; - - private ArrayList points; - - public Circle() { - points = new ArrayList(); - } - - public void addPoint(double px, double py) { - points.add(new Point2D.Double(px, py)); - } - - public double getRadius(Point2D.Double center) { - double r = 0; - for (Point2D.Double point : points) { - r += point.distance(center); - } - return r / points.size(); - } - - private double[] gradient(double[] point) { - - // optimal radius - Point2D.Double center = new Point2D.Double(point[0], point[1]); - double radius = getRadius(center); - - // gradient of the sum of squared residuals - double dJdX = 0; - double dJdY = 0; - for (Point2D.Double pk : points) { - double dk = pk.distance(center); - dJdX += (center.x - pk.x) * (dk - radius) / dk; - dJdY += (center.y - pk.y) * (dk - radius) / dk; - } - dJdX *= 2; - dJdY *= 2; - - return new double[] { dJdX, dJdY }; - - } - - public double value(double[] variables) - throws IllegalArgumentException, MathUserException { - - Point2D.Double center = new Point2D.Double(variables[0], variables[1]); - double radius = getRadius(center); - - double sum = 0; - for (Point2D.Double point : points) { - double di = point.distance(center) - radius; - sum += di * di; - } - - return sum; - - } - - public MultivariateVectorialFunction gradient() { - return new MultivariateVectorialFunction() { - private static final long serialVersionUID = 3174909643301201710L; - public double[] value(double[] point) { - return gradient(point); - } - }; - } - - public MultivariateRealFunction partialDerivative(final int k) { - return new MultivariateRealFunction() { - private static final long serialVersionUID = 3073956364104833888L; - public double value(double[] point) { - return gradient(point)[k]; - } - }; - } - - } - }