Added a constructor in the custom checkers that enables normal termination
of an optimization algorithm (i.e. returning the curent best point after a
selected number of iterations have been performed).


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1411807 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2012-11-20 19:29:16 +00:00
parent a20e951321
commit 8da1dede24
4 changed files with 191 additions and 8 deletions

View File

@ -19,6 +19,7 @@ package org.apache.commons.math3.optimization;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair; import org.apache.commons.math3.util.Pair;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/** /**
* Simple implementation of the {@link ConvergenceChecker} interface using * Simple implementation of the {@link ConvergenceChecker} interface using
@ -28,6 +29,10 @@ import org.apache.commons.math3.util.Pair;
* difference between each point coordinate are smaller than a threshold * difference between each point coordinate are smaller than a threshold
* or if either the absolute difference between the point coordinates are * or if either the absolute difference between the point coordinates are
* smaller than another threshold. * smaller than another threshold.
* <br/>
* The {@link #converged(int,Pair,Pair) converged} method will also return
* {@code true} if the number of iterations has been set (see
* {@link #SimplePointChecker(double,double,int) this constructor}).
* *
* @param <PAIR> Type of the (point, value) pair. * @param <PAIR> Type of the (point, value) pair.
* The type of the "value" part of the pair (not used by this class). * The type of the "value" part of the pair (not used by this class).
@ -37,12 +42,27 @@ import org.apache.commons.math3.util.Pair;
*/ */
public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>> public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
extends AbstractConvergenceChecker<PAIR> { extends AbstractConvergenceChecker<PAIR> {
/**
* If {@link #maxIterationCount} is set to this value, the number of
* iterations will never cause {@link #converged(int,Pair,Pair>)}
* to return {@code true}.
*/
private static final int ITERATION_CHECK_DISABLED = -1;
/**
* Number of iterations after which the
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)} method
* will return true (unless the check is disabled).
*/
private final int maxIterationCount;
/** /**
* Build an instance with default threshold. * Build an instance with default threshold.
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()} * @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
*/ */
@Deprecated @Deprecated
public SimplePointChecker() {} public SimplePointChecker() {
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/** /**
* Build an instance with specified thresholds. * Build an instance with specified thresholds.
@ -56,12 +76,38 @@ public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
public SimplePointChecker(final double relativeThreshold, public SimplePointChecker(final double relativeThreshold,
final double absoluteThreshold) { final double absoluteThreshold) {
super(relativeThreshold, absoluteThreshold); super(relativeThreshold, absoluteThreshold);
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/**
* Builds an instance with specified thresholds.
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold Relative tolerance threshold.
* @param absoluteThreshold Absolute tolerance threshold.
* @param maxIter Maximum iteration count. Setting it to a negative
* value will disable this stopping criterion.
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
*
* @since 3.1
*/
public SimplePointChecker(final double relativeThreshold,
final double absoluteThreshold,
final int maxIter) {
super(relativeThreshold, absoluteThreshold);
if (maxIter <= 0) {
throw new NotStrictlyPositiveException(maxIter);
}
maxIterationCount = maxIter;
} }
/** /**
* Check if the optimization algorithm has converged considering the * Check if the optimization algorithm has converged considering the
* last two points. * last two points.
* This method may be called several time from the same algorithm * This method may be called several times from the same algorithm
* iteration with different points. This can be detected by checking the * iteration with different points. This can be detected by checking the
* iteration number at each call if needed. Each time this method is * iteration number at each call if needed. Each time this method is
* called, the previous and current point correspond to points with the * called, the previous and current point correspond to points with the
@ -72,12 +118,18 @@ public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
* @param iteration Index of current iteration * @param iteration Index of current iteration
* @param previous Best point in the previous iteration. * @param previous Best point in the previous iteration.
* @param current Best point in the current iteration. * @param current Best point in the current iteration.
* @return {@code true} if the algorithm has converged. * @return {@code true} if the arguments satify the convergence criterion.
*/ */
@Override @Override
public boolean converged(final int iteration, public boolean converged(final int iteration,
final PAIR previous, final PAIR previous,
final PAIR current) { final PAIR current) {
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
if (iteration >= maxIterationCount) {
return true;
}
}
final double[] p = previous.getKey(); final double[] p = previous.getKey();
final double[] c = current.getKey(); final double[] c = current.getKey();
for (int i = 0; i < p.length; ++i) { for (int i = 0; i < p.length; ++i) {

View File

@ -18,6 +18,7 @@
package org.apache.commons.math3.optimization; package org.apache.commons.math3.optimization;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/** /**
* Simple implementation of the {@link ConvergenceChecker} interface using * Simple implementation of the {@link ConvergenceChecker} interface using
@ -27,18 +28,38 @@ import org.apache.commons.math3.util.FastMath;
* difference between the objective function values is smaller than a * difference between the objective function values is smaller than a
* threshold or if either the absolute difference between the objective * threshold or if either the absolute difference between the objective
* function values is smaller than another threshold. * function values is smaller than another threshold.
* <br/>
* The {@link #converged(int,PointValuePair,PointValuePair) converged}
* method will also return {@code true} if the number of iterations has been set
* (see {@link #SimpleValueChecker(double,double,int) this constructor}).
* *
* @version $Id$ * @version $Id$
* @since 3.0 * @since 3.0
*/ */
public class SimpleValueChecker public class SimpleValueChecker
extends AbstractConvergenceChecker<PointValuePair> { extends AbstractConvergenceChecker<PointValuePair> {
/**
* If {@link #maxIterationCount} is set to this value, the number of
* iterations will never cause
* {@link #converged(int,PointValuePair,PointValuePair)}
* to return {@code true}.
*/
private static final int ITERATION_CHECK_DISABLED = -1;
/**
* Number of iterations after which the
* {@link #converged(int,PointValuePair,PointValuePair)} method
* will return true (unless the check is disabled).
*/
private final int maxIterationCount;
/** /**
* Build an instance with default thresholds. * Build an instance with default thresholds.
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()} * @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
*/ */
@Deprecated @Deprecated
public SimpleValueChecker() {} public SimpleValueChecker() {
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/** Build an instance with specified thresholds. /** Build an instance with specified thresholds.
* *
@ -52,6 +73,33 @@ public class SimpleValueChecker
public SimpleValueChecker(final double relativeThreshold, public SimpleValueChecker(final double relativeThreshold,
final double absoluteThreshold) { final double absoluteThreshold) {
super(relativeThreshold, absoluteThreshold); super(relativeThreshold, absoluteThreshold);
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/**
* Builds an instance with specified thresholds.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold relative tolerance threshold
* @param absoluteThreshold absolute tolerance threshold
* @param maxIter Maximum iteration count. Setting it to a negative
* value will disable this stopping criterion.
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
*
* @since 3.1
*/
public SimpleValueChecker(final double relativeThreshold,
final double absoluteThreshold,
final int maxIter) {
super(relativeThreshold, absoluteThreshold);
if (maxIter <= 0) {
throw new NotStrictlyPositiveException(maxIter);
}
maxIterationCount = maxIter;
} }
/** /**
@ -74,6 +122,12 @@ public class SimpleValueChecker
public boolean converged(final int iteration, public boolean converged(final int iteration,
final PointValuePair previous, final PointValuePair previous,
final PointValuePair current) { final PointValuePair current) {
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
if (iteration >= maxIterationCount) {
return true;
}
}
final double p = previous.getValue(); final double p = previous.getValue();
final double c = current.getValue(); final double c = current.getValue();
final double difference = FastMath.abs(p - c); final double difference = FastMath.abs(p - c);

View File

@ -18,6 +18,7 @@
package org.apache.commons.math3.optimization; package org.apache.commons.math3.optimization;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
/** /**
* Simple implementation of the {@link ConvergenceChecker} interface using * Simple implementation of the {@link ConvergenceChecker} interface using
@ -27,18 +28,38 @@ import org.apache.commons.math3.util.FastMath;
* difference between the objective function values is smaller than a * difference between the objective function values is smaller than a
* threshold or if either the absolute difference between the objective * threshold or if either the absolute difference between the objective
* function values is smaller than another threshold for all vectors elements. * function values is smaller than another threshold for all vectors elements.
* <br/>
* The {@link #converged(int,PointVectorValuePair,PointVectorValuePair) converged}
* method will also return {@code true} if the number of iterations has been set
* (see {@link #SimpleVectorValueChecker(double,double,int) this constructor}).
* *
* @version $Id$ * @version $Id$
* @since 3.0 * @since 3.0
*/ */
public class SimpleVectorValueChecker public class SimpleVectorValueChecker
extends AbstractConvergenceChecker<PointVectorValuePair> { extends AbstractConvergenceChecker<PointVectorValuePair> {
/**
* If {@link #maxIterationCount} is set to this value, the number of
* iterations will never cause
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)}
* to return {@code true}.
*/
private static final int ITERATION_CHECK_DISABLED = -1;
/**
* Number of iterations after which the
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)} method
* will return true (unless the check is disabled).
*/
private final int maxIterationCount;
/** /**
* Build an instance with default thresholds. * Build an instance with default thresholds.
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()} * @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
*/ */
@Deprecated @Deprecated
public SimpleVectorValueChecker() {} public SimpleVectorValueChecker() {
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/** /**
* Build an instance with specified thresholds. * Build an instance with specified thresholds.
@ -53,12 +74,40 @@ public class SimpleVectorValueChecker
public SimpleVectorValueChecker(final double relativeThreshold, public SimpleVectorValueChecker(final double relativeThreshold,
final double absoluteThreshold) { final double absoluteThreshold) {
super(relativeThreshold, absoluteThreshold); super(relativeThreshold, absoluteThreshold);
maxIterationCount = ITERATION_CHECK_DISABLED;
}
/**
* Builds an instance with specified tolerance thresholds and
* iteration count.
*
* In order to perform only relative checks, the absolute tolerance
* must be set to a negative value. In order to perform only absolute
* checks, the relative tolerance must be set to a negative value.
*
* @param relativeThreshold Relative tolerance threshold.
* @param absoluteThreshold Absolute tolerance threshold.
* @param maxIter Maximum iteration count. Setting it to a negative
* value will disable this stopping criterion.
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
*
* @since 3.1
*/
public SimpleVectorValueChecker(final double relativeThreshold,
final double absoluteThreshold,
final int maxIter) {
super(relativeThreshold, absoluteThreshold);
if (maxIter <= 0) {
throw new NotStrictlyPositiveException(maxIter);
}
maxIterationCount = maxIter;
} }
/** /**
* Check if the optimization algorithm has converged considering the * Check if the optimization algorithm has converged considering the
* last two points. * last two points.
* This method may be called several time from the same algorithm * This method may be called several times from the same algorithm
* iteration with different points. This can be detected by checking the * iteration with different points. This can be detected by checking the
* iteration number at each call if needed. Each time this method is * iteration number at each call if needed. Each time this method is
* called, the previous and current point correspond to points with the * called, the previous and current point correspond to points with the
@ -69,12 +118,18 @@ public class SimpleVectorValueChecker
* @param iteration Index of current iteration * @param iteration Index of current iteration
* @param previous Best point in the previous iteration. * @param previous Best point in the previous iteration.
* @param current Best point in the current iteration. * @param current Best point in the current iteration.
* @return {@code true} if the algorithm has converged. * @return {@code true} if the arguments satify the convergence criterion.
*/ */
@Override @Override
public boolean converged(final int iteration, public boolean converged(final int iteration,
final PointVectorValuePair previous, final PointVectorValuePair previous,
final PointVectorValuePair current) { final PointVectorValuePair current) {
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
if (iteration >= maxIterationCount) {
return true;
}
}
final double[] p = previous.getValueRef(); final double[] p = previous.getValueRef();
final double[] c = current.getValueRef(); final double[] c = current.getValueRef();
for (int i = 0; i < p.length; ++i) { for (int i = 0; i < p.length; ++i) {

View File

@ -146,6 +146,28 @@ public class PolynomialFitterTest {
} }
} }
/**
* This test shows that the user can set the maximum number of iterations
* to avoid running for too long.
* Even if the real problem is that the tolerance is way too stringent, it
* is possible to get the best solution so far, i.e. a checker will return
* the point when the maximum iteration count has been reached.
*/
@Test
public void testMath798WithToleranceTooLowButNoException() {
final double tol = 1e-100;
final double[] init = new double[] { 0, 0 };
final int maxEval = 10000; // Trying hard to fit.
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol, maxEval);
final double[] lm = doMath798(new LevenbergMarquardtOptimizer(checker), maxEval, init);
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
for (int i = 0; i <= 1; i++) {
Assert.assertEquals(lm[i], gn[i], 1e-15);
}
}
/** /**
* @param optimizer Optimizer. * @param optimizer Optimizer.
* @param maxEval Maximum number of function evaluations. * @param maxEval Maximum number of function evaluations.