MATH-1558: Fix MidPointIntegrator incremental implementation

This commit is contained in:
Sam Ritchie 2020-10-20 16:24:29 -06:00
parent 0b3629b4b0
commit e0b2efc2ac
2 changed files with 53 additions and 19 deletions

View File

@ -25,7 +25,7 @@ import org.apache.commons.math4.exception.TooManyEvaluationsException;
import org.apache.commons.math4.util.FastMath;
/**
* Implements the <a href="http://en.wikipedia.org/wiki/Midpoint_method">
* Implements the <a href="https://en.wikipedia.org/wiki/Riemann_sum#Midpoint_rule">
* Midpoint Rule</a> for integration of real univariate functions. For
* reference, see <b>Numerical Mathematics</b>, ISBN 0387989595,
* chapter 9.2.
@ -36,8 +36,10 @@ import org.apache.commons.math4.util.FastMath;
*/
public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {
/** Maximum number of iterations for midpoint. */
private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 63;
/** Maximum number of iterations for midpoint. 39 = floor(log_3(2^63)), the
* maximum number of triplings allowed before exceeding 64-bit bounds.
*/
private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 39;
/**
* Build a midpoint integrator with given accuracies and iterations counts.
@ -50,7 +52,7 @@ public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {
* @exception NumberIsTooSmallException if maximal number of iterations
* is lesser than or equal to the minimal number of iterations
* @exception NumberIsTooLargeException if maximal number of iterations
* is greater than 63.
* is greater than 39.
*/
public MidPointIntegrator(final double relativeAccuracy,
final double absoluteAccuracy,
@ -73,7 +75,7 @@ public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {
* @exception NumberIsTooSmallException if maximal number of iterations
* is lesser than or equal to the minimal number of iterations
* @exception NumberIsTooLargeException if maximal number of iterations
* is greater than 63.
* is greater than 39.
*/
public MidPointIntegrator(final int minimalIterationCount,
final int maximalIterationCount)
@ -98,11 +100,11 @@ public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {
* This function should only be called by API <code>integrate()</code> in the package.
* To save time it does not verify arguments - caller does.
* <p>
* The interval is divided equally into 2^n sections rather than an
* The interval is divided equally into 3^n sections rather than an
* arbitrary m sections because this configuration can best utilize the
* already computed values.</p>
*
* @param n the stage of 1/2 refinement. Must be larger than 0.
* @param n the stage of 1/3 refinement. Must be larger than 0.
* @param previousStageResult Result from the previous call to the
* {@code stage} method.
* @param min Lower bound of the integration interval.
@ -118,21 +120,29 @@ public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {
double diffMaxMin)
throws TooManyEvaluationsException {
// number of new points in this stage
final long np = 1L << (n - 1);
// number of points in the previous stage. This stage will contribute
// 2*3^{n-1} more points.
final long np = (long) FastMath.pow(3, n - 1);
double sum = 0;
// spacing between adjacent new points
final double spacing = diffMaxMin / np;
final double leftOffset = spacing / 6;
final double rightOffset = 5 * leftOffset;
// the first new point
double x = min + 0.5 * spacing;
double x = min;
for (long i = 0; i < np; i++) {
sum += computeObjectiveValue(x);
// The first and second new points are located at the new midpoints
// generated when each previous integration slice is split into 3.
//
// |--------x--------|
// |--x--|--x--|--x-|
sum += computeObjectiveValue(x + leftOffset);
sum += computeObjectiveValue(x + rightOffset);
x += spacing;
}
// add the new sum to previously calculated result
return 0.5 * (previousStageResult + sum * spacing);
return (previousStageResult + sum * spacing) / 3.0;
}

View File

@ -35,6 +35,25 @@ import org.junit.Test;
public final class MidPointIntegratorTest {
private static final int NUM_ITER = 30;
/**
* The initial iteration contributes 1 evaluation. Each successive iteration
* contributes 2 points to each previous slice.
*
* The total evaluation count == 1 + 2*3^0 + 2*3^1 + ... 2*3^n
*
* the series 3^0 + 3^1 + ... + 3^n sums to 3^(n-1) / (3-1), so the total
* expected evaluations == 1 + 2*(3^(n-1) - 1)/2 == 3^(n-1).
*
* The n in the series above is offset by 1 from the MidPointIntegrator
* iteration count so the actual result == 3^n.
*
* Without the incremental implementation, the same result would require
* (3^(n + 1) - 1) / 2 evaluations; just under 50% more.
*/
private long expectedEvaluations(int iterations) {
return (long) FastMath.pow(3, iterations);
}
/**
* Test of integrator for the sine function.
*/
@ -48,8 +67,9 @@ public final class MidPointIntegratorTest {
double expected = -3697001.0 / 48.0;
double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);
}
@ -67,8 +87,9 @@ public final class MidPointIntegratorTest {
double expected = 2;
double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);
min = -FastMath.PI/3;
@ -76,8 +97,9 @@ public final class MidPointIntegratorTest {
expected = -0.5;
tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);
}
@ -95,8 +117,9 @@ public final class MidPointIntegratorTest {
double expected = -1.0 / 48;
double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);
min = 0;
@ -104,7 +127,7 @@ public final class MidPointIntegratorTest {
expected = 11.0 / 768;
tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expected, result, tolerance);
@ -113,8 +136,9 @@ public final class MidPointIntegratorTest {
expected = 2048 / 3.0 - 78 + 1.0 / 48;
tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);
}