diff --git a/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java b/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java
index 1e89c4309..2909cc87d 100644
--- a/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java
+++ b/src/main/java/org/apache/commons/math4/distribution/UniformIntegerDistribution.java
@@ -22,10 +22,8 @@ import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.rng.UniformRandomProvider;
/**
- * Implementation of the uniform integer distribution.
- *
- * @see Uniform distribution (discrete), at Wikipedia
+ * Implementation of the
+ * uniform integer distribution.
*
* @since 3.0
*/
@@ -36,6 +34,10 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution {
private final int lower;
/** Upper bound (inclusive) of this distribution. */
private final int upper;
+ /** "upper" + "lower" (to avoid overflow). */
+ private final double upperPlusLower;
+ /** "upper" - "lower" (to avoid overflow). */
+ private final double upperMinusLower;
/**
* Creates a new uniform integer distribution using the given lower and
@@ -55,6 +57,8 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution {
}
this.lower = lower;
this.upper = upper;
+ upperPlusLower = (double) upper + (double) lower;
+ upperMinusLower = (double) upper - (double) lower;
}
/** {@inheritDoc} */
@@ -63,7 +67,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution {
if (x < lower || x > upper) {
return 0;
}
- return 1.0 / (upper - lower + 1);
+ return 1.0 / (upperMinusLower + 1);
}
/** {@inheritDoc} */
@@ -75,7 +79,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution {
if (x > upper) {
return 1;
}
- return (x - lower + 1.0) / (upper - lower + 1.0);
+ return (x - lower + 1.0) / (upperMinusLower + 1.0);
}
/**
@@ -86,7 +90,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution {
*/
@Override
public double getNumericalMean() {
- return 0.5 * (lower + upper);
+ return 0.5 * upperPlusLower;
}
/**
@@ -97,7 +101,7 @@ public class UniformIntegerDistribution extends AbstractIntegerDistribution {
*/
@Override
public double getNumericalVariance() {
- double n = upper - lower + 1;
+ double n = upperMinusLower + 1;
return (n * n - 1) / 12.0;
}
diff --git a/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java
index 3ae15dda0..d1fc3161e 100644
--- a/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java
+++ b/src/test/java/org/apache/commons/math4/distribution/UniformIntegerDistributionTest.java
@@ -22,6 +22,7 @@ import org.junit.Test;
import org.apache.commons.math4.distribution.IntegerDistribution;
import org.apache.commons.math4.distribution.UniformIntegerDistribution;
import org.apache.commons.math4.exception.NumberIsTooLargeException;
+import org.apache.commons.math4.util.Precision;
/**
* Test cases for UniformIntegerDistribution. See class javadoc for
@@ -112,4 +113,30 @@ public class UniformIntegerDistributionTest extends IntegerDistributionAbstractT
// Degenerate case is allowed.
new UniformIntegerDistribution(0, 0);
}
+
+ // MATH-1396
+ @Test
+ public void testLargeRangeSubtractionOverflow() {
+ final int hi = Integer.MAX_VALUE / 2 + 10;
+ UniformIntegerDistribution dist = new UniformIntegerDistribution(-hi, hi - 1);
+
+ final double tol = Math.ulp(1d);
+ Assert.assertEquals(0.5 / hi, dist.probability(123456), tol);
+ Assert.assertEquals(0.5, dist.cumulativeProbability(-1), tol);
+
+ Assert.assertTrue(Precision.equals((Math.pow(2d * hi, 2) - 1) / 12, dist.getNumericalVariance(), 1));
+ }
+
+ // MATH-1396
+ @Test
+ public void testLargeRangeAdditionOverflow() {
+ final int hi = Integer.MAX_VALUE / 2 + 10;
+ UniformIntegerDistribution dist = new UniformIntegerDistribution(hi - 1, hi + 1);
+
+ final double tol = Math.ulp(1d);
+ Assert.assertEquals(1d / 3d, dist.probability(hi), tol);
+ Assert.assertEquals(2d / 3d, dist.cumulativeProbability(hi), tol);
+
+ Assert.assertTrue(Precision.equals(hi, dist.getNumericalMean(), 1));
+ }
}