MATH-902
Added a constructor in the custom checkers that enables normal termination of an optimization algorithm (i.e. returning the curent best point after a selected number of iterations have been performed). git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1411807 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
a20e951321
commit
8da1dede24
|
@ -19,6 +19,7 @@ package org.apache.commons.math3.optimization;
|
||||||
|
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.apache.commons.math3.util.Pair;
|
import org.apache.commons.math3.util.Pair;
|
||||||
|
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simple implementation of the {@link ConvergenceChecker} interface using
|
* Simple implementation of the {@link ConvergenceChecker} interface using
|
||||||
|
@ -28,6 +29,10 @@ import org.apache.commons.math3.util.Pair;
|
||||||
* difference between each point coordinate are smaller than a threshold
|
* difference between each point coordinate are smaller than a threshold
|
||||||
* or if either the absolute difference between the point coordinates are
|
* or if either the absolute difference between the point coordinates are
|
||||||
* smaller than another threshold.
|
* smaller than another threshold.
|
||||||
|
* <br/>
|
||||||
|
* The {@link #converged(int,Pair,Pair) converged} method will also return
|
||||||
|
* {@code true} if the number of iterations has been set (see
|
||||||
|
* {@link #SimplePointChecker(double,double,int) this constructor}).
|
||||||
*
|
*
|
||||||
* @param <PAIR> Type of the (point, value) pair.
|
* @param <PAIR> Type of the (point, value) pair.
|
||||||
* The type of the "value" part of the pair (not used by this class).
|
* The type of the "value" part of the pair (not used by this class).
|
||||||
|
@ -37,12 +42,27 @@ import org.apache.commons.math3.util.Pair;
|
||||||
*/
|
*/
|
||||||
public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
|
public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
|
||||||
extends AbstractConvergenceChecker<PAIR> {
|
extends AbstractConvergenceChecker<PAIR> {
|
||||||
|
/**
|
||||||
|
* If {@link #maxIterationCount} is set to this value, the number of
|
||||||
|
* iterations will never cause {@link #converged(int,Pair,Pair>)}
|
||||||
|
* to return {@code true}.
|
||||||
|
*/
|
||||||
|
private static final int ITERATION_CHECK_DISABLED = -1;
|
||||||
|
/**
|
||||||
|
* Number of iterations after which the
|
||||||
|
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)} method
|
||||||
|
* will return true (unless the check is disabled).
|
||||||
|
*/
|
||||||
|
private final int maxIterationCount;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build an instance with default threshold.
|
* Build an instance with default threshold.
|
||||||
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
|
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public SimplePointChecker() {}
|
public SimplePointChecker() {
|
||||||
|
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build an instance with specified thresholds.
|
* Build an instance with specified thresholds.
|
||||||
|
@ -56,12 +76,38 @@ public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
|
||||||
public SimplePointChecker(final double relativeThreshold,
|
public SimplePointChecker(final double relativeThreshold,
|
||||||
final double absoluteThreshold) {
|
final double absoluteThreshold) {
|
||||||
super(relativeThreshold, absoluteThreshold);
|
super(relativeThreshold, absoluteThreshold);
|
||||||
|
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an instance with specified thresholds.
|
||||||
|
* In order to perform only relative checks, the absolute tolerance
|
||||||
|
* must be set to a negative value. In order to perform only absolute
|
||||||
|
* checks, the relative tolerance must be set to a negative value.
|
||||||
|
*
|
||||||
|
* @param relativeThreshold Relative tolerance threshold.
|
||||||
|
* @param absoluteThreshold Absolute tolerance threshold.
|
||||||
|
* @param maxIter Maximum iteration count. Setting it to a negative
|
||||||
|
* value will disable this stopping criterion.
|
||||||
|
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
|
||||||
|
*
|
||||||
|
* @since 3.1
|
||||||
|
*/
|
||||||
|
public SimplePointChecker(final double relativeThreshold,
|
||||||
|
final double absoluteThreshold,
|
||||||
|
final int maxIter) {
|
||||||
|
super(relativeThreshold, absoluteThreshold);
|
||||||
|
|
||||||
|
if (maxIter <= 0) {
|
||||||
|
throw new NotStrictlyPositiveException(maxIter);
|
||||||
|
}
|
||||||
|
maxIterationCount = maxIter;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if the optimization algorithm has converged considering the
|
* Check if the optimization algorithm has converged considering the
|
||||||
* last two points.
|
* last two points.
|
||||||
* This method may be called several time from the same algorithm
|
* This method may be called several times from the same algorithm
|
||||||
* iteration with different points. This can be detected by checking the
|
* iteration with different points. This can be detected by checking the
|
||||||
* iteration number at each call if needed. Each time this method is
|
* iteration number at each call if needed. Each time this method is
|
||||||
* called, the previous and current point correspond to points with the
|
* called, the previous and current point correspond to points with the
|
||||||
|
@ -72,12 +118,18 @@ public class SimplePointChecker<PAIR extends Pair<double[], ? extends Object>>
|
||||||
* @param iteration Index of current iteration
|
* @param iteration Index of current iteration
|
||||||
* @param previous Best point in the previous iteration.
|
* @param previous Best point in the previous iteration.
|
||||||
* @param current Best point in the current iteration.
|
* @param current Best point in the current iteration.
|
||||||
* @return {@code true} if the algorithm has converged.
|
* @return {@code true} if the arguments satify the convergence criterion.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public boolean converged(final int iteration,
|
public boolean converged(final int iteration,
|
||||||
final PAIR previous,
|
final PAIR previous,
|
||||||
final PAIR current) {
|
final PAIR current) {
|
||||||
|
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
|
||||||
|
if (iteration >= maxIterationCount) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
final double[] p = previous.getKey();
|
final double[] p = previous.getKey();
|
||||||
final double[] c = current.getKey();
|
final double[] c = current.getKey();
|
||||||
for (int i = 0; i < p.length; ++i) {
|
for (int i = 0; i < p.length; ++i) {
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.commons.math3.optimization;
|
package org.apache.commons.math3.optimization;
|
||||||
|
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
|
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simple implementation of the {@link ConvergenceChecker} interface using
|
* Simple implementation of the {@link ConvergenceChecker} interface using
|
||||||
|
@ -27,18 +28,38 @@ import org.apache.commons.math3.util.FastMath;
|
||||||
* difference between the objective function values is smaller than a
|
* difference between the objective function values is smaller than a
|
||||||
* threshold or if either the absolute difference between the objective
|
* threshold or if either the absolute difference between the objective
|
||||||
* function values is smaller than another threshold.
|
* function values is smaller than another threshold.
|
||||||
|
* <br/>
|
||||||
|
* The {@link #converged(int,PointValuePair,PointValuePair) converged}
|
||||||
|
* method will also return {@code true} if the number of iterations has been set
|
||||||
|
* (see {@link #SimpleValueChecker(double,double,int) this constructor}).
|
||||||
*
|
*
|
||||||
* @version $Id$
|
* @version $Id$
|
||||||
* @since 3.0
|
* @since 3.0
|
||||||
*/
|
*/
|
||||||
public class SimpleValueChecker
|
public class SimpleValueChecker
|
||||||
extends AbstractConvergenceChecker<PointValuePair> {
|
extends AbstractConvergenceChecker<PointValuePair> {
|
||||||
|
/**
|
||||||
|
* If {@link #maxIterationCount} is set to this value, the number of
|
||||||
|
* iterations will never cause
|
||||||
|
* {@link #converged(int,PointValuePair,PointValuePair)}
|
||||||
|
* to return {@code true}.
|
||||||
|
*/
|
||||||
|
private static final int ITERATION_CHECK_DISABLED = -1;
|
||||||
|
/**
|
||||||
|
* Number of iterations after which the
|
||||||
|
* {@link #converged(int,PointValuePair,PointValuePair)} method
|
||||||
|
* will return true (unless the check is disabled).
|
||||||
|
*/
|
||||||
|
private final int maxIterationCount;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build an instance with default thresholds.
|
* Build an instance with default thresholds.
|
||||||
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
|
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public SimpleValueChecker() {}
|
public SimpleValueChecker() {
|
||||||
|
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||||
|
}
|
||||||
|
|
||||||
/** Build an instance with specified thresholds.
|
/** Build an instance with specified thresholds.
|
||||||
*
|
*
|
||||||
|
@ -52,6 +73,33 @@ public class SimpleValueChecker
|
||||||
public SimpleValueChecker(final double relativeThreshold,
|
public SimpleValueChecker(final double relativeThreshold,
|
||||||
final double absoluteThreshold) {
|
final double absoluteThreshold) {
|
||||||
super(relativeThreshold, absoluteThreshold);
|
super(relativeThreshold, absoluteThreshold);
|
||||||
|
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an instance with specified thresholds.
|
||||||
|
*
|
||||||
|
* In order to perform only relative checks, the absolute tolerance
|
||||||
|
* must be set to a negative value. In order to perform only absolute
|
||||||
|
* checks, the relative tolerance must be set to a negative value.
|
||||||
|
*
|
||||||
|
* @param relativeThreshold relative tolerance threshold
|
||||||
|
* @param absoluteThreshold absolute tolerance threshold
|
||||||
|
* @param maxIter Maximum iteration count. Setting it to a negative
|
||||||
|
* value will disable this stopping criterion.
|
||||||
|
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
|
||||||
|
*
|
||||||
|
* @since 3.1
|
||||||
|
*/
|
||||||
|
public SimpleValueChecker(final double relativeThreshold,
|
||||||
|
final double absoluteThreshold,
|
||||||
|
final int maxIter) {
|
||||||
|
super(relativeThreshold, absoluteThreshold);
|
||||||
|
|
||||||
|
if (maxIter <= 0) {
|
||||||
|
throw new NotStrictlyPositiveException(maxIter);
|
||||||
|
}
|
||||||
|
maxIterationCount = maxIter;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -74,6 +122,12 @@ public class SimpleValueChecker
|
||||||
public boolean converged(final int iteration,
|
public boolean converged(final int iteration,
|
||||||
final PointValuePair previous,
|
final PointValuePair previous,
|
||||||
final PointValuePair current) {
|
final PointValuePair current) {
|
||||||
|
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
|
||||||
|
if (iteration >= maxIterationCount) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
final double p = previous.getValue();
|
final double p = previous.getValue();
|
||||||
final double c = current.getValue();
|
final double c = current.getValue();
|
||||||
final double difference = FastMath.abs(p - c);
|
final double difference = FastMath.abs(p - c);
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.commons.math3.optimization;
|
package org.apache.commons.math3.optimization;
|
||||||
|
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
|
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simple implementation of the {@link ConvergenceChecker} interface using
|
* Simple implementation of the {@link ConvergenceChecker} interface using
|
||||||
|
@ -27,18 +28,38 @@ import org.apache.commons.math3.util.FastMath;
|
||||||
* difference between the objective function values is smaller than a
|
* difference between the objective function values is smaller than a
|
||||||
* threshold or if either the absolute difference between the objective
|
* threshold or if either the absolute difference between the objective
|
||||||
* function values is smaller than another threshold for all vectors elements.
|
* function values is smaller than another threshold for all vectors elements.
|
||||||
|
* <br/>
|
||||||
|
* The {@link #converged(int,PointVectorValuePair,PointVectorValuePair) converged}
|
||||||
|
* method will also return {@code true} if the number of iterations has been set
|
||||||
|
* (see {@link #SimpleVectorValueChecker(double,double,int) this constructor}).
|
||||||
*
|
*
|
||||||
* @version $Id$
|
* @version $Id$
|
||||||
* @since 3.0
|
* @since 3.0
|
||||||
*/
|
*/
|
||||||
public class SimpleVectorValueChecker
|
public class SimpleVectorValueChecker
|
||||||
extends AbstractConvergenceChecker<PointVectorValuePair> {
|
extends AbstractConvergenceChecker<PointVectorValuePair> {
|
||||||
|
/**
|
||||||
|
* If {@link #maxIterationCount} is set to this value, the number of
|
||||||
|
* iterations will never cause
|
||||||
|
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)}
|
||||||
|
* to return {@code true}.
|
||||||
|
*/
|
||||||
|
private static final int ITERATION_CHECK_DISABLED = -1;
|
||||||
|
/**
|
||||||
|
* Number of iterations after which the
|
||||||
|
* {@link #converged(int,PointVectorValuePair,PointVectorValuePair)} method
|
||||||
|
* will return true (unless the check is disabled).
|
||||||
|
*/
|
||||||
|
private final int maxIterationCount;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build an instance with default thresholds.
|
* Build an instance with default thresholds.
|
||||||
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
|
* @deprecated See {@link AbstractConvergenceChecker#AbstractConvergenceChecker()}
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public SimpleVectorValueChecker() {}
|
public SimpleVectorValueChecker() {
|
||||||
|
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build an instance with specified thresholds.
|
* Build an instance with specified thresholds.
|
||||||
|
@ -51,14 +72,42 @@ public class SimpleVectorValueChecker
|
||||||
* @param absoluteThreshold absolute tolerance threshold
|
* @param absoluteThreshold absolute tolerance threshold
|
||||||
*/
|
*/
|
||||||
public SimpleVectorValueChecker(final double relativeThreshold,
|
public SimpleVectorValueChecker(final double relativeThreshold,
|
||||||
final double absoluteThreshold) {
|
final double absoluteThreshold) {
|
||||||
super(relativeThreshold, absoluteThreshold);
|
super(relativeThreshold, absoluteThreshold);
|
||||||
|
maxIterationCount = ITERATION_CHECK_DISABLED;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an instance with specified tolerance thresholds and
|
||||||
|
* iteration count.
|
||||||
|
*
|
||||||
|
* In order to perform only relative checks, the absolute tolerance
|
||||||
|
* must be set to a negative value. In order to perform only absolute
|
||||||
|
* checks, the relative tolerance must be set to a negative value.
|
||||||
|
*
|
||||||
|
* @param relativeThreshold Relative tolerance threshold.
|
||||||
|
* @param absoluteThreshold Absolute tolerance threshold.
|
||||||
|
* @param maxIter Maximum iteration count. Setting it to a negative
|
||||||
|
* value will disable this stopping criterion.
|
||||||
|
* @throws NotStrictlyPositiveException if {@code maxIter <= 0}.
|
||||||
|
*
|
||||||
|
* @since 3.1
|
||||||
|
*/
|
||||||
|
public SimpleVectorValueChecker(final double relativeThreshold,
|
||||||
|
final double absoluteThreshold,
|
||||||
|
final int maxIter) {
|
||||||
|
super(relativeThreshold, absoluteThreshold);
|
||||||
|
|
||||||
|
if (maxIter <= 0) {
|
||||||
|
throw new NotStrictlyPositiveException(maxIter);
|
||||||
|
}
|
||||||
|
maxIterationCount = maxIter;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if the optimization algorithm has converged considering the
|
* Check if the optimization algorithm has converged considering the
|
||||||
* last two points.
|
* last two points.
|
||||||
* This method may be called several time from the same algorithm
|
* This method may be called several times from the same algorithm
|
||||||
* iteration with different points. This can be detected by checking the
|
* iteration with different points. This can be detected by checking the
|
||||||
* iteration number at each call if needed. Each time this method is
|
* iteration number at each call if needed. Each time this method is
|
||||||
* called, the previous and current point correspond to points with the
|
* called, the previous and current point correspond to points with the
|
||||||
|
@ -69,12 +118,18 @@ public class SimpleVectorValueChecker
|
||||||
* @param iteration Index of current iteration
|
* @param iteration Index of current iteration
|
||||||
* @param previous Best point in the previous iteration.
|
* @param previous Best point in the previous iteration.
|
||||||
* @param current Best point in the current iteration.
|
* @param current Best point in the current iteration.
|
||||||
* @return {@code true} if the algorithm has converged.
|
* @return {@code true} if the arguments satify the convergence criterion.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public boolean converged(final int iteration,
|
public boolean converged(final int iteration,
|
||||||
final PointVectorValuePair previous,
|
final PointVectorValuePair previous,
|
||||||
final PointVectorValuePair current) {
|
final PointVectorValuePair current) {
|
||||||
|
if (maxIterationCount != ITERATION_CHECK_DISABLED) {
|
||||||
|
if (iteration >= maxIterationCount) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
final double[] p = previous.getValueRef();
|
final double[] p = previous.getValueRef();
|
||||||
final double[] c = current.getValueRef();
|
final double[] c = current.getValueRef();
|
||||||
for (int i = 0; i < p.length; ++i) {
|
for (int i = 0; i < p.length; ++i) {
|
||||||
|
|
|
@ -146,6 +146,28 @@ public class PolynomialFitterTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This test shows that the user can set the maximum number of iterations
|
||||||
|
* to avoid running for too long.
|
||||||
|
* Even if the real problem is that the tolerance is way too stringent, it
|
||||||
|
* is possible to get the best solution so far, i.e. a checker will return
|
||||||
|
* the point when the maximum iteration count has been reached.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testMath798WithToleranceTooLowButNoException() {
|
||||||
|
final double tol = 1e-100;
|
||||||
|
final double[] init = new double[] { 0, 0 };
|
||||||
|
final int maxEval = 10000; // Trying hard to fit.
|
||||||
|
final SimpleVectorValueChecker checker = new SimpleVectorValueChecker(tol, tol, maxEval);
|
||||||
|
|
||||||
|
final double[] lm = doMath798(new LevenbergMarquardtOptimizer(checker), maxEval, init);
|
||||||
|
final double[] gn = doMath798(new GaussNewtonOptimizer(checker), maxEval, init);
|
||||||
|
|
||||||
|
for (int i = 0; i <= 1; i++) {
|
||||||
|
Assert.assertEquals(lm[i], gn[i], 1e-15);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param optimizer Optimizer.
|
* @param optimizer Optimizer.
|
||||||
* @param maxEval Maximum number of function evaluations.
|
* @param maxEval Maximum number of function evaluations.
|
||||||
|
|
Loading…
Reference in New Issue