From ec46f40b0be95ff28b07204a5304de8e5262aef1 Mon Sep 17 00:00:00 2001 From: Gilles Sadowski Date: Sat, 24 Nov 2012 11:11:10 +0000 Subject: [PATCH] MATH-902 Allow stopping condition based on the number of iterations (for "univariate" optimizers). git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1413171 13f79535-47bb-0310-9956-ffa450edef68 --- .../SimpleUnivariateValueChecker.java | 56 ++++++++++++++++++- .../SimpleUnivariateValueCheckerTest.java | 52 +++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueCheckerTest.java diff --git a/src/main/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueChecker.java b/src/main/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueChecker.java index 7a7a4277d..1aa7b55e0 100644 --- a/src/main/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueChecker.java +++ b/src/main/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueChecker.java @@ -18,6 +18,7 @@ package org.apache.commons.math3.optimization.univariate; import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; import org.apache.commons.math3.optimization.AbstractConvergenceChecker; /** @@ -29,18 +30,39 @@ import org.apache.commons.math3.optimization.AbstractConvergenceChecker; * difference between the objective function values is smaller than a * threshold or if either the absolute difference between the objective * function values is smaller than another threshold. + *
+ * The {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair) + * converged} method will also return {@code true} if the number of iterations + * has been set (see {@link #SimpleUnivariateValueChecker(double,double,int) + * this constructor}). * * @version $Id$ * @since 3.1 */ public class SimpleUnivariateValueChecker extends AbstractConvergenceChecker { + /** + * If {@link #maxIterationCount} is set to this value, the number of + * iterations will never cause + * {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)} + * to return {@code true}. + */ + private static final int ITERATION_CHECK_DISABLED = -1; + /** + * Number of iterations after which the + * {@link #converged(int,UnivariatePointValuePair,UnivariatePointValuePair)} + * method will return true (unless the check is disabled). + */ + private final int maxIterationCount; + /** * Build an instance with default thresholds. * @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()} */ @Deprecated - public SimpleUnivariateValueChecker() {} + public SimpleUnivariateValueChecker() { + maxIterationCount = ITERATION_CHECK_DISABLED; + } /** Build an instance with specified thresholds. * @@ -54,6 +76,32 @@ public class SimpleUnivariateValueChecker public SimpleUnivariateValueChecker(final double relativeThreshold, final double absoluteThreshold) { super(relativeThreshold, absoluteThreshold); + maxIterationCount = ITERATION_CHECK_DISABLED; + } + + /** + * Builds an instance with specified thresholds. + * + * In order to perform only relative checks, the absolute tolerance + * must be set to a negative value. In order to perform only absolute + * checks, the relative tolerance must be set to a negative value. + * + * @param relativeThreshold relative tolerance threshold + * @param absoluteThreshold absolute tolerance threshold + * @param maxIter Maximum iteration count. + * @throws NotStrictlyPositiveException if {@code maxIter <= 0}. + * + * @since 3.1 + */ + public SimpleUnivariateValueChecker(final double relativeThreshold, + final double absoluteThreshold, + final int maxIter) { + super(relativeThreshold, absoluteThreshold); + + if (maxIter <= 0) { + throw new NotStrictlyPositiveException(maxIter); + } + maxIterationCount = maxIter; } /** @@ -76,6 +124,12 @@ public class SimpleUnivariateValueChecker public boolean converged(final int iteration, final UnivariatePointValuePair previous, final UnivariatePointValuePair current) { + if (maxIterationCount != ITERATION_CHECK_DISABLED) { + if (iteration >= maxIterationCount) { + return true; + } + } + final double p = previous.getValue(); final double c = current.getValue(); final double difference = FastMath.abs(p - c); diff --git a/src/test/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueCheckerTest.java b/src/test/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueCheckerTest.java new file mode 100644 index 000000000..c10a26f7b --- /dev/null +++ b/src/test/java/org/apache/commons/math3/optimization/univariate/SimpleUnivariateValueCheckerTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.optimization.univariate; + +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.junit.Test; +import org.junit.Assert; + +public class SimpleUnivariateValueCheckerTest { + @Test(expected=NotStrictlyPositiveException.class) + public void testIterationCheckPrecondition() { + new SimpleUnivariateValueChecker(1e-1, 1e-2, 0); + } + + @Test + public void testIterationCheck() { + final int max = 10; + final SimpleUnivariateValueChecker checker = new SimpleUnivariateValueChecker(1e-1, 1e-2, max); + Assert.assertTrue(checker.converged(max, null, null)); + Assert.assertTrue(checker.converged(max + 1, null, null)); + } + + @Test + public void testIterationCheckDisabled() { + final SimpleUnivariateValueChecker checker = new SimpleUnivariateValueChecker(1e-8, 1e-8); + + final UnivariatePointValuePair a = new UnivariatePointValuePair(1d, 1d); + final UnivariatePointValuePair b = new UnivariatePointValuePair(10d, 10d); + + Assert.assertFalse(checker.converged(-1, a, b)); + Assert.assertFalse(checker.converged(0, a, b)); + Assert.assertFalse(checker.converged(1000000, a, b)); + + Assert.assertTrue(checker.converged(-1, a, a)); + Assert.assertTrue(checker.converged(-1, b, b)); + } + +}