diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/filter/KalmanFilterTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/filter/KalmanFilterTest.java index 0dbba2cf5..ae3c592fb 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/filter/KalmanFilterTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/filter/KalmanFilterTest.java @@ -14,17 +14,18 @@ package org.apache.commons.math4.legacy.filter; -import org.apache.commons.statistics.distribution.ContinuousDistribution; -import org.apache.commons.statistics.distribution.NormalDistribution; +import org.apache.commons.rng.simple.RandomSource; +import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler; +import org.apache.commons.rng.sampling.distribution.GaussianSampler; +import org.apache.commons.rng.sampling.distribution.ContinuousSampler; +import org.apache.commons.numbers.core.Precision; import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix; import org.apache.commons.math4.legacy.linear.ArrayRealVector; import org.apache.commons.math4.legacy.linear.MatrixDimensionMismatchException; import org.apache.commons.math4.legacy.linear.MatrixUtils; import org.apache.commons.math4.legacy.linear.RealMatrix; import org.apache.commons.math4.legacy.linear.RealVector; -import org.apache.commons.rng.simple.RandomSource; import org.apache.commons.math4.legacy.core.jdkmath.AccurateMath; -import org.apache.commons.numbers.core.Precision; import org.junit.Assert; import org.junit.Test; @@ -121,7 +122,7 @@ public class KalmanFilterTest { RealVector pNoise = new ArrayRealVector(1); RealVector mNoise = new ArrayRealVector(1); - final ContinuousDistribution.Sampler rand = new NormalDistribution(0, 1).createSampler(RandomSource.create(RandomSource.WELL_19937_C)); + final ContinuousSampler rand = createGaussianSampler(0, 1); // iterate 60 steps for (int i = 0; i < 60; i++) { @@ -210,7 +211,7 @@ public class KalmanFilterTest { double[] expectedInitialState = new double[] { 0.0, 0.0 }; assertVectorEquals(expectedInitialState, filter.getStateEstimation()); - final ContinuousDistribution.Sampler rand = new NormalDistribution(0, 1).createSampler(RandomSource.create(RandomSource.WELL_19937_C)); + final ContinuousSampler rand = createGaussianSampler(0, 1); RealVector tmpPNoise = new ArrayRealVector( new double[] { AccurateMath.pow(dt, 2d) / 2d, dt }); @@ -387,7 +388,7 @@ public class KalmanFilterTest { final MeasurementModel mm = new DefaultMeasurementModel(H, R); final KalmanFilter filter = new KalmanFilter(pm, mm); - final ContinuousDistribution.Sampler dist = new NormalDistribution(0, measurementNoise).createSampler(RandomSource.create(RandomSource.WELL_19937_C)); + final ContinuousSampler rand = createGaussianSampler(0, measurementNoise); for (int i = 0; i < iterations; i++) { // get the "real" cannonball position @@ -395,8 +396,8 @@ public class KalmanFilterTest { double y = cannonball.getY(); // apply measurement noise to current cannonball position - double nx = x + dist.sample(); - double ny = y + dist.sample(); + double nx = x + rand.sample(); + double ny = y + rand.sample(); cannonball.step(); @@ -439,4 +440,15 @@ public class KalmanFilterTest { } } } + + /** + * @param mu Mean + * @param sigma Standard deviation. + * @return a sampler that follows the N(mu,sigma) distribution. + */ + private ContinuousSampler createGaussianSampler(double mu, + double sigma) { + return GaussianSampler.of(ZigguratNormalizedGaussianSampler.of(RandomSource.create(RandomSource.JSF_64)), + mu, sigma); + } }