MATH-1443: Depend on "Commons Statistics".

This commit is contained in:
Gilles 2018-01-25 17:54:31 +01:00
parent c3ff46e303
commit b2d4b2ac3a
3 changed files with 27 additions and 19 deletions

View File

@ -28,6 +28,9 @@ import java.nio.charset.Charset;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.apache.commons.statistics.distribution.NormalDistribution;
import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.statistics.distribution.ConstantContinuousDistribution;
import org.apache.commons.math4.exception.MathIllegalStateException; import org.apache.commons.math4.exception.MathIllegalStateException;
import org.apache.commons.math4.exception.MathInternalError; import org.apache.commons.math4.exception.MathInternalError;
import org.apache.commons.math4.exception.NullArgumentException; import org.apache.commons.math4.exception.NullArgumentException;
@ -517,7 +520,7 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
return 0d; return 0d;
} }
final int binIndex = findBin(x); final int binIndex = findBin(x);
final RealDistribution kernel = getKernel(binStats.get(binIndex)); final ContinuousDistribution kernel = getKernel(binStats.get(binIndex));
return kernel.density(x) * pB(binIndex) / kB(binIndex); return kernel.density(x) * pB(binIndex) / kB(binIndex);
} }
@ -546,9 +549,9 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
final int binIndex = findBin(x); final int binIndex = findBin(x);
final double pBminus = pBminus(binIndex); final double pBminus = pBminus(binIndex);
final double pB = pB(binIndex); final double pB = pB(binIndex);
final RealDistribution kernel = k(x); final ContinuousDistribution kernel = k(x);
if (kernel instanceof ConstantRealDistribution) { if (kernel instanceof ConstantContinuousDistribution) {
if (x < kernel.getNumericalMean()) { if (x < kernel.getMean()) {
return pBminus; return pBminus;
} else { } else {
return pBminus + pB; return pBminus + pB;
@ -601,7 +604,7 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
i++; i++;
} }
final RealDistribution kernel = getKernel(binStats.get(i)); final ContinuousDistribution kernel = getKernel(binStats.get(i));
final double kB = kB(i); final double kB = kB(i);
final double[] binBounds = getUpperBounds(); final double[] binBounds = getUpperBounds();
final double lower = i == 0 ? min : binBounds[i - 1]; final double lower = i == 0 ? min : binBounds[i - 1];
@ -699,7 +702,7 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
*/ */
private double kB(int i) { private double kB(int i) {
final double[] binBounds = getUpperBounds(); final double[] binBounds = getUpperBounds();
final RealDistribution kernel = getKernel(binStats.get(i)); final ContinuousDistribution kernel = getKernel(binStats.get(i));
return i == 0 ? kernel.probability(min, binBounds[0]) : return i == 0 ? kernel.probability(min, binBounds[0]) :
kernel.probability(binBounds[i - 1], binBounds[i]); kernel.probability(binBounds[i - 1], binBounds[i]);
} }
@ -710,7 +713,7 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
* @param x the value to locate within a bin * @param x the value to locate within a bin
* @return the within-bin kernel of the bin containing x * @return the within-bin kernel of the bin containing x
*/ */
private RealDistribution k(double x) { private ContinuousDistribution k(double x) {
final int binIndex = findBin(x); final int binIndex = findBin(x);
return getKernel(binStats.get(binIndex)); return getKernel(binStats.get(binIndex));
} }
@ -733,12 +736,11 @@ public class EmpiricalDistribution extends AbstractRealDistribution {
* @param bStats summary statistics for the bin * @param bStats summary statistics for the bin
* @return within-bin kernel parameterized by bStats * @return within-bin kernel parameterized by bStats
*/ */
protected RealDistribution getKernel(SummaryStatistics bStats) { protected ContinuousDistribution getKernel(SummaryStatistics bStats) {
if (bStats.getN() == 1 || bStats.getVariance() == 0) { if (bStats.getN() == 1 || bStats.getVariance() == 0) {
return new ConstantRealDistribution(bStats.getMean()); return new ConstantContinuousDistribution(bStats.getMean());
} else { } else {
return new NormalDistribution(bStats.getMean(), bStats.getStandardDeviation(), return new NormalDistribution(bStats.getMean(), bStats.getStandardDeviation());
NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
} }
} }
} }

View File

@ -16,6 +16,8 @@
*/ */
package org.apache.commons.math4.distribution; package org.apache.commons.math4.distribution;
import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.statistics.distribution.NormalDistribution;
import org.apache.commons.math4.exception.DimensionMismatchException; import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.linear.Array2DRowRealMatrix; import org.apache.commons.math4.linear.Array2DRowRealMatrix;
import org.apache.commons.math4.linear.EigenDecomposition; import org.apache.commons.math4.linear.EigenDecomposition;
@ -179,7 +181,7 @@ public class MultivariateNormalDistribution
public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) { public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
return new MultivariateRealDistribution.Sampler() { return new MultivariateRealDistribution.Sampler() {
/** Normal distribution. */ /** Normal distribution. */
private final RealDistribution.Sampler gauss = new NormalDistribution().createSampler(rng); private final ContinuousDistribution.Sampler gauss = new NormalDistribution(0, 1).createSampler(rng);
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override

View File

@ -24,6 +24,10 @@ import java.net.URL;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.statistics.distribution.ConstantContinuousDistribution;
import org.apache.commons.statistics.distribution.UniformContinuousDistribution;
import org.apache.commons.statistics.distribution.NormalDistribution;
import org.apache.commons.math4.TestUtils; import org.apache.commons.math4.TestUtils;
import org.apache.commons.math4.analysis.UnivariateFunction; import org.apache.commons.math4.analysis.UnivariateFunction;
import org.apache.commons.math4.analysis.integration.BaseAbstractUnivariateIntegrator; import org.apache.commons.math4.analysis.integration.BaseAbstractUnivariateIntegrator;
@ -334,7 +338,7 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
// Compute bMinus = sum or mass of bins below the bin containing the point // Compute bMinus = sum or mass of bins below the bin containing the point
// First bin has mass 11 / 10000, the rest have mass 10 / 10000. // First bin has mass 11 / 10000, the rest have mass 10 / 10000.
final double bMinus = bin == 0 ? 0 : (bin - 1) * binMass + firstBinMass; final double bMinus = bin == 0 ? 0 : (bin - 1) * binMass + firstBinMass;
final RealDistribution kernel = findKernel(lower, upper); final ContinuousDistribution kernel = findKernel(lower, upper);
final double withinBinKernelMass = kernel.probability(lower, upper); final double withinBinKernelMass = kernel.probability(lower, upper);
final double kernelCum = kernel.probability(lower, testPoints[i]); final double kernelCum = kernel.probability(lower, testPoints[i]);
cumValues[i] = bMinus + (bin == 0 ? firstBinMass : binMass) * kernelCum/withinBinKernelMass; cumValues[i] = bMinus + (bin == 0 ? firstBinMass : binMass) * kernelCum/withinBinKernelMass;
@ -353,7 +357,7 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
final double lower = bin == 0 ? empiricalDistribution.getSupportLowerBound() : final double lower = bin == 0 ? empiricalDistribution.getSupportLowerBound() :
binBounds[bin - 1]; binBounds[bin - 1];
final double upper = binBounds[bin]; final double upper = binBounds[bin];
final RealDistribution kernel = findKernel(lower, upper); final ContinuousDistribution kernel = findKernel(lower, upper);
final double withinBinKernelMass = kernel.probability(lower, upper); final double withinBinKernelMass = kernel.probability(lower, upper);
final double density = kernel.density(testPoints[i]); final double density = kernel.density(testPoints[i]);
densityValues[i] = density * (bin == 0 ? firstBinMass : binMass) / withinBinKernelMass; densityValues[i] = density * (bin == 0 ? firstBinMass : binMass) / withinBinKernelMass;
@ -456,7 +460,7 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
* The first bin includes its lower bound, 0, so has different mean and * The first bin includes its lower bound, 0, so has different mean and
* standard deviation. * standard deviation.
*/ */
private RealDistribution findKernel(double lower, double upper) { private ContinuousDistribution findKernel(double lower, double upper) {
if (lower < 1) { if (lower < 1) {
return new NormalDistribution(5d, 3.3166247903554); return new NormalDistribution(5d, 3.3166247903554);
} else { } else {
@ -535,8 +539,8 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
} }
// Use constant distribution equal to bin mean within bin // Use constant distribution equal to bin mean within bin
@Override @Override
protected RealDistribution getKernel(SummaryStatistics bStats) { protected ContinuousDistribution getKernel(SummaryStatistics bStats) {
return new ConstantRealDistribution(bStats.getMean()); return new ConstantContinuousDistribution(bStats.getMean());
} }
} }
@ -549,8 +553,8 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
super(i); super(i);
} }
@Override @Override
protected RealDistribution getKernel(SummaryStatistics bStats) { protected ContinuousDistribution getKernel(SummaryStatistics bStats) {
return new UniformRealDistribution(bStats.getMin(), bStats.getMax()); return new UniformContinuousDistribution(bStats.getMin(), bStats.getMax());
} }
} }
} }