diff --git a/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java b/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java index 1e89c4309..2909cc87d 100644 --- a/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java @@ -22,10 +22,8 @@ import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.rng.UniformRandomProvider; /** - * Implementation of the uniform integer distribution. - * - * @see Uniform distribution (discrete), at Wikipedia + * Implementation of the + * uniform integer distribution. * * @since 3.0 */ @@ -36,6 +34,10 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution { private final int lower; /** Upper bound (inclusive) of this distribution. */ private final int upper; + /** "upper" + "lower" (to avoid overflow). */ + private final double upperPlusLower; + /** "upper" - "lower" (to avoid overflow). */ + private final double upperMinusLower; /** * Creates a new uniform integer distribution using the given lower and @@ -55,6 +57,8 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution { } this.lower = lower; this.upper = upper; + upperPlusLower = (double) upper + (double) lower; + upperMinusLower = (double) upper - (double) lower; } /** {@inheritDoc} */ @@ -63,7 +67,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution { if (x < lower || x > upper) { return 0; } - return 1.0 / (upper - lower + 1); + return 1.0 / (upperMinusLower + 1); } /** {@inheritDoc} */ @@ -75,7 +79,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution { if (x > upper) { return 1; } - return (x - lower + 1.0) / (upper - lower + 1.0); + return (x - lower + 1.0) / (upperMinusLower + 1.0); } /** @@ -86,7 +90,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution { */ @Override public double getNumericalMean() { - return 0.5 * (lower + upper); + return 0.5 * upperPlusLower; } /** @@ -97,7 +101,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution { */ @Override public double getNumericalVariance() { - double n = upper - lower + 1; + double n = upperMinusLower + 1; return (n * n - 1) / 12.0; } diff --git a/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java index 3ae15dda0..d1fc3161e 100644 --- a/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java +++ b/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java @@ -22,6 +22,7 @@ import org.junit.Test; import org.apache.commons.math4.distribution.IntegerDistribution; import org.apache.commons.math4.distribution.UniformIntegerDistribution; import org.apache.commons.math4.exception.NumberIsTooLargeException; +import org.apache.commons.math4.util.Precision; /** * Test cases for UniformIntegerDistribution. See class javadoc for @@ -112,4 +113,30 @@ public class UniformIntegerDistributionTest extends IntegerDistributionAbstractT // Degenerate case is allowed. new UniformIntegerDistribution(0, 0); } + + // MATH-1396 + @Test + public void testLargeRangeSubtractionOverflow() { + final int hi = Integer.MAX_VALUE / 2 + 10; + UniformIntegerDistribution dist = new UniformIntegerDistribution(-hi, hi - 1); + + final double tol = Math.ulp(1d); + Assert.assertEquals(0.5 / hi, dist.probability(123456), tol); + Assert.assertEquals(0.5, dist.cumulativeProbability(-1), tol); + + Assert.assertTrue(Precision.equals((Math.pow(2d * hi, 2) - 1) / 12, dist.getNumericalVariance(), 1)); + } + + // MATH-1396 + @Test + public void testLargeRangeAdditionOverflow() { + final int hi = Integer.MAX_VALUE / 2 + 10; + UniformIntegerDistribution dist = new UniformIntegerDistribution(hi - 1, hi + 1); + + final double tol = Math.ulp(1d); + Assert.assertEquals(1d / 3d, dist.probability(hi), tol); + Assert.assertEquals(2d / 3d, dist.cumulativeProbability(hi), tol); + + Assert.assertTrue(Precision.equals(hi, dist.getNumericalMean(), 1)); + } }