MATH-1616: Refactor "EmpiricalDistribution".

* No default bin count (cf. MATH-1462).
* No data loading from external sources (file, URL).
* No data abstraction layer.
* Return defensive copies of the internal state.
* Make class immutable.
* Allow user-defined within-bin kernel.
Gilles Sadowski 2021-07-20 11:45:45 +02:00
parent 8968416790
commit 9dbceb0ed1
2 changed files with 277 additions and 663 deletions
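
For orientation, here is a minimal usage sketch of the refactored API as it appears in the diff below. The data values, bin count and RNG seed are illustrative; the "from" factory, "createSampler" and the ContinuousDistribution query methods are taken from the changed sources, while the RandomSource import path is assumed from the usual commons-rng layout (the test's import block is not shown in this excerpt).

import org.apache.commons.math4.legacy.distribution.EmpiricalDistribution;
import org.apache.commons.rng.simple.RandomSource;
import org.apache.commons.statistics.distribution.ContinuousDistribution;

public final class EmpiricalDistributionUsage {
    public static void main(String[] args) {
        // Illustrative sample; in practice this is the data set to model.
        final double[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

        // Immutable instance created through the new factory (no separate "load" step).
        final EmpiricalDistribution dist = EmpiricalDistribution.from(5, data);

        // Sampling, following the pattern used in the updated tests.
        final ContinuousDistribution.Sampler sampler =
            dist.createSampler(RandomSource.WELL_19937_C.create(1000));
        System.out.println(sampler.sample());

        // Standard ContinuousDistribution queries.
        System.out.println(dist.density(5.5));
        System.out.println(dist.cumulativeProbability(5.5));
    }
}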

Changed file: EmpiricalDistribution.java (org.apache.commons.math4.legacy.distribution)

@@ -17,29 +17,16 @@
package org.apache.commons.math4.legacy.distribution;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.statistics.distribution.NormalDistribution;
import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.statistics.distribution.ConstantContinuousDistribution;
import org.apache.commons.numbers.core.Precision;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
import org.apache.commons.math4.legacy.exception.MathInternalError;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.stat.descriptive.StatisticalSummary;
import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
import org.apache.commons.math4.legacy.core.jdkmath.AccurateMath;
@@ -99,290 +86,138 @@ import org.apache.commons.math4.legacy.core.jdkmath.AccurateMath;
*
* <strong>USAGE NOTES:</strong>
* <ul>
* <li>The {@code binCount} is set by default to 1000. A good rule of thumb
* is to set the bin count to approximately the length of the input file divided
* by 10. </li>
* <li>The input file <i>must</i> be a plain text file containing one valid numeric
* entry per line.</li>
* <li>
* The {@code binCount} is set by default to 1000. A good rule of thumb
* is to set the bin count to approximately the length of the input file
* divided by 10. </li>
* <li>
* The input file <i>must</i> be a plain text file containing one valid
* numeric entry per line.</li>
* </ul>
*/
public class EmpiricalDistribution extends AbstractRealDistribution
public final class EmpiricalDistribution extends AbstractRealDistribution
implements ContinuousDistribution {
/** Default bin count. */
public static final int DEFAULT_BIN_COUNT = 1000;
/** Character set for file input. */
private static final String FILE_CHARSET = "US-ASCII";
/** Bins' characteristics. */
/** Bins characteristics. */
private final List<SummaryStatistics> binStats;
/** Sample statistics. */
private SummaryStatistics sampleStats;
private final SummaryStatistics sampleStats;
/** Max loaded value. */
private double max = Double.NEGATIVE_INFINITY;
private final double max;
/** Min loaded value. */
private double min = Double.POSITIVE_INFINITY;
private final double min;
/** Grid size. */
private double delta;
private final double delta;
/** Number of bins. */
private final int binCount;
/** Whether the distribution is loaded. */
private boolean loaded;
/** Upper bounds of subintervals in (0,1) belonging to the bins. */
private double[] upperBounds;
/** Upper bounds of subintervals in (0, 1) belonging to the bins. */
private final double[] upperBounds;
/** Kernel factory. */
private final Function<SummaryStatistics, ContinuousDistribution> kernelFactory;
/**
* Creates a new EmpiricalDistribution with the default bin count.
*/
public EmpiricalDistribution() {
this(DEFAULT_BIN_COUNT);
}
/**
* Creates a new EmpiricalDistribution with the specified bin count.
* Creates a new instance with the specified data.
*
* @param binCount number of bins. Must be strictly positive.
* @param binCount Number of bins. Must be strictly positive.
* @param input Input data. Cannot be {@code null}.
* @param kernelFactory Kernel factory.
* @throws NotStrictlyPositiveException if {@code binCount <= 0}.
*/
public EmpiricalDistribution(int binCount) {
private EmpiricalDistribution(int binCount,
double[] input,
Function<SummaryStatistics, ContinuousDistribution> kernelFactory) {
if (binCount <= 0) {
throw new NotStrictlyPositiveException(binCount);
}
this.binCount = binCount;
binStats = new ArrayList<>();
}
/**
* Computes the empirical distribution from the provided
* array of numbers.
*
* @param in the input data array
* @exception NullArgumentException if in is null
*/
public void load(double[] in) {
DataAdapter da = new ArrayDataAdapter(in);
try {
da.computeStats();
// new adapter for the second pass
fillBinStats(new ArrayDataAdapter(in));
} catch (IOException ex) {
// Can't happen
throw new MathInternalError();
}
loaded = true;
}
/**
* Computes the empirical distribution using data read from a URL.
*
* <p>The input file <i>must</i> be an ASCII text file containing one
* valid numeric entry per line.</p>
*
* @param url url of the input file
*
* @throws IOException if an IO error occurs
* @throws NullArgumentException if url is null
* @throws ZeroException if URL contains no data
*/
public void load(URL url) throws IOException {
NullArgumentException.check(url);
Charset charset = Charset.forName(FILE_CHARSET);
BufferedReader in =
new BufferedReader(new InputStreamReader(url.openStream(), charset));
try {
DataAdapter da = new StreamDataAdapter(in);
da.computeStats();
if (sampleStats.getN() == 0) {
throw new ZeroException(LocalizedFormats.URL_CONTAINS_NO_DATA, url);
}
// new adapter for the second pass
in = new BufferedReader(new InputStreamReader(url.openStream(), charset));
fillBinStats(new StreamDataAdapter(in));
loaded = true;
} finally {
try {
in.close();
} catch (IOException ex) { //NOPMD
// ignore
}
}
}
/**
* Computes the empirical distribution from the input file.
*
* <p>The input file <i>must</i> be an ASCII text file containing one
* valid numeric entry per line.</p>
*
* @param file the input file
* @throws IOException if an IO error occurs
* @throws NullArgumentException if file is null
*/
public void load(File file) throws IOException {
NullArgumentException.check(file);
Charset charset = Charset.forName(FILE_CHARSET);
InputStream is = new FileInputStream(file);
BufferedReader in = new BufferedReader(new InputStreamReader(is, charset));
try {
DataAdapter da = new StreamDataAdapter(in);
da.computeStats();
// new adapter for second pass
is = new FileInputStream(file);
in = new BufferedReader(new InputStreamReader(is, charset));
fillBinStats(new StreamDataAdapter(in));
loaded = true;
} finally {
try {
in.close();
} catch (IOException ex) { //NOPMD
// ignore
}
}
}
/**
* Provides methods for computing {@code sampleStats} and
* {@code beanStats} abstracting the source of data.
*/
private abstract class DataAdapter {
/**
* Compute bin stats.
*
* @throws IOException if an error occurs computing bin stats
*/
public abstract void computeBinStats() throws IOException;
/**
* Compute sample statistics.
*
* @throws IOException if an error occurs computing sample stats
*/
public abstract void computeStats() throws IOException;
}
/**
* {@code DataAdapter} for data provided through some input stream.
*/
private class StreamDataAdapter extends DataAdapter {
/** Input stream providing access to the data. */
private final BufferedReader inputStream;
/**
* Create a StreamDataAdapter from a BufferedReader.
*
* @param in BufferedReader input stream
*/
StreamDataAdapter(BufferedReader in){
inputStream = in;
// First pass through the data.
sampleStats = new SummaryStatistics();
for (int i = 0; i < input.length; i++) {
sampleStats.addValue(input[i]);
}
/** {@inheritDoc} */
@Override
public void computeBinStats() throws IOException {
String str = null;
double val = 0.0d;
while ((str = inputStream.readLine()) != null) {
val = Double.parseDouble(str);
SummaryStatistics stats = binStats.get(findBin(val));
stats.addValue(val);
}
inputStream.close();
}
/** {@inheritDoc} */
@Override
public void computeStats() throws IOException {
String str = null;
double val = 0.0;
sampleStats = new SummaryStatistics();
while ((str = inputStream.readLine()) != null) {
val = Double.parseDouble(str);
sampleStats.addValue(val);
}
inputStream.close();
}
}
/**
* {@code DataAdapter} for data provided as array of doubles.
*/
private class ArrayDataAdapter extends DataAdapter {
/** Array of input data values. */
private final double[] inputArray;
/**
* Construct an ArrayDataAdapter from a double[] array.
*
* @param in double[] array holding the data
* @throws NullArgumentException if in is null
*/
ArrayDataAdapter(double[] in) {
NullArgumentException.check(in);
inputArray = in;
}
/** {@inheritDoc} */
@Override
public void computeStats() throws IOException {
sampleStats = new SummaryStatistics();
for (int i = 0; i < inputArray.length; i++) {
sampleStats.addValue(inputArray[i]);
}
}
/** {@inheritDoc} */
@Override
public void computeBinStats() throws IOException {
for (int i = 0; i < inputArray.length; i++) {
SummaryStatistics stats = binStats.get(findBin(inputArray[i]));
stats.addValue(inputArray[i]);
}
}
}
/**
* Fills binStats array (second pass through data file).
*
* @param da object providing access to the data
* @throws IOException if an IO error occurs
*/
private void fillBinStats(final DataAdapter da) throws IOException {
// Set up grid
// Set up grid.
min = sampleStats.getMin();
max = sampleStats.getMax();
delta = (max - min) / binCount;
// Initialize binStats ArrayList
if (!binStats.isEmpty()) {
binStats.clear();
}
for (int i = 0; i < binCount; i++) {
SummaryStatistics stats = new SummaryStatistics();
binStats.add(i, stats);
}
// Second pass through the data.
binStats = createBinStats(input);
// Filling data in binStats Array
da.computeBinStats();
// Assign upperBounds based on bin counts
// Assign upper bounds based on bin counts.
upperBounds = new double[binCount];
upperBounds[0] = binStats.get(0).getN() / (double) sampleStats.getN();
final double n = (double) sampleStats.getN();
upperBounds[0] = binStats.get(0).getN() / n;
for (int i = 1; i < binCount - 1; i++) {
upperBounds[i] = upperBounds[i - 1] +
binStats.get(i).getN() / (double) sampleStats.getN();
upperBounds[i] = upperBounds[i - 1] + binStats.get(i).getN() / n;
}
upperBounds[binCount - 1] = 1d;
this.kernelFactory = kernelFactory;
}
/**
* Factory that creates a new instance from the specified data.
*
* @param binCount Number of bins. Must be strictly positive.
* @param input Input data. Cannot be {@code null}.
* @param kernelFactory Factory for creating within-bin kernels.
* @return a new instance.
* @throws NotStrictlyPositiveException if {@code binCount <= 0}.
*/
public static EmpiricalDistribution from(int binCount,
double[] input,
Function<SummaryStatistics, ContinuousDistribution> kernelFactory) {
return new EmpiricalDistribution(binCount,
input,
kernelFactory);
}
/**
* Factory that creates a new instance from the specified data.
*
* @param binCount Number of bins. Must be strictly positive.
* @param input Input data. Cannot be {@code null}.
* @return a new instance.
* @throws NotStrictlyPositiveException if {@code binCount <= 0}.
*/
public static EmpiricalDistribution from(int binCount,
double[] input) {
return from(binCount, input, defaultKernel());
}
/**
* Create statistics (second pass through the data).
*
* @param input Input data.
* @return bins statistics.
*/
private List<SummaryStatistics> createBinStats(double[] input) {
final List<SummaryStatistics> binStats = new ArrayList<>();
for (int i = 0; i < binCount; i++) {
binStats.add(i, new SummaryStatistics());
}
// Second pass through the data.
for (int i = 0; i < input.length; i++) {
final double v = input[i];
binStats.get(findBin(v)).addValue(v);
}
return binStats;
}
/**
* Returns the index of the bin to which the given value belongs.
*
* @param value the value whose bin we are trying to find
* @return the index of the bin containing the value
* @param value Value whose bin we are trying to find.
* @return the index of the bin containing the value.
*/
private int findBin(double value) {
return AccurateMath.min(AccurateMath.max((int) AccurateMath.ceil((value - min) / delta) - 1, 0),
binCount - 1);
return Math.min(Math.max((int) AccurateMath.ceil((value - min) / delta) - 1,
0),
binCount - 1);
}
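
For a concrete feel of the grid arithmetic, the sketch below reruns the formula on the data set used by the existing testGetBinUpperBounds test; the helper class is illustrative only and uses Math.ceil in place of AccurateMath.ceil.

/** Illustrative check of the binning arithmetic; not part of the commit. */
public final class BinningExample {
    public static void main(String[] args) {
        // Data and bin count from testGetBinUpperBounds: min = 0, max = 10, 5 bins.
        final double min = 0;
        final int binCount = 5;
        final double delta = (10 - min) / binCount; // = 2, so the bin upper bounds are 2, 4, 6, 8, 10.

        System.out.println(findBin(3.0, min, delta, binCount));  // 1 -> second bin (2, 4]
        System.out.println(findBin(0.0, min, delta, binCount));  // 0 -> clamped into the first bin
        System.out.println(findBin(10.0, min, delta, binCount)); // 4 -> last bin (8, 10]

        // With the 13 values {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10} the bin counts are
        // 4, 3, 2, 2, 2; the cumulative proportions 4/13, 7/13, 9/13, 11/13, 1 are exactly
        // the generator upper bounds expected by the test.
    }

    private static int findBin(double value, double min, double delta, int binCount) {
        return Math.min(Math.max((int) Math.ceil((value - min) / delta) - 1, 0),
                        binCount - 1);
    }
}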
/**
@@ -394,7 +229,7 @@ public class EmpiricalDistribution extends AbstractRealDistribution
* @throws IllegalStateException if the distribution has not been loaded
*/
public StatisticalSummary getSampleStats() {
return sampleStats;
return sampleStats.copy();
}
/**
@@ -407,27 +242,33 @@ public class EmpiricalDistribution extends AbstractRealDistribution
}
/**
* Returns a List of {@link SummaryStatistics} instances containing
* statistics describing the values in each of the bins. The list is
* indexed on the bin number.
* Returns a copy of the {@link SummaryStatistics} instances containing
* statistics describing the values in each of the bins.
* The list is indexed on the bin number.
*
* @return List of bin statistics.
* @return the bins statistics.
*/
public List<SummaryStatistics> getBinStats() {
return binStats;
final List<SummaryStatistics> copy = new ArrayList<>();
for (SummaryStatistics s : binStats) {
copy.add(s.copy());
}
return copy;
}
/**
* <p>Returns a fresh copy of the array of upper bounds for the bins.
* Bins are: <br>
* [min,upperBounds[0]],(upperBounds[0],upperBounds[1]],...,
* (upperBounds[binCount-2], upperBounds[binCount-1] = max].</p>
* Returns the upper bounds of the bins.
*
* <p>Note: In versions 1.0-2.0 of commons-math, this method
* incorrectly returned the array of probability generator upper
* bounds now returned by {@link #getGeneratorUpperBounds()}.</p>
* Assuming array {@code u} is returned by this method, the bins are:
* <ul>
* <li>{@code (min, u[0])},</li>
* <li>{@code (u[0], u[1])},</li>
* <li>... ,</li>
* <li>{@code (u[binCount - 2], u[binCount - 1] = max)},</li>
* </ul>
*
* @return the bins upper bounds.
*
* @return array of bin upper bounds
* @since 2.1
*/
public double[] getUpperBounds() {
@@ -440,20 +281,18 @@ public class EmpiricalDistribution extends AbstractRealDistribution
}
/**
* <p>Returns a fresh copy of the array of upper bounds of the subintervals
* of [0,1] used in generating data from the empirical distribution.
* Subintervals correspond to bins with lengths proportional to bin counts.</p>
* Returns the upper bounds of the subintervals of [0, 1] used in generating
* data from the empirical distribution.
* Subintervals correspond to bins with lengths proportional to bin counts.
*
* <strong>Preconditions:</strong><ul>
* <li>the distribution must be loaded before invoking this method</li></ul>
*
* <p>In versions 1.0-2.0 of commons-math, this array was (incorrectly) returned
* by {@link #getUpperBounds()}.</p>
*
* @since 2.1
* @return array of upper bounds of subintervals used in data generation
* @throws NullPointerException unless a {@code load} method has been
* called beforehand.
*
* @since 2.1
*/
public double[] getGeneratorUpperBounds() {
int len = upperBounds.length;
@@ -462,15 +301,6 @@ public class EmpiricalDistribution extends AbstractRealDistribution
return out;
}
/**
* Property indicating whether or not the distribution has been loaded.
*
* @return true if the distribution has been loaded
*/
public boolean isLoaded() {
return loaded;
}
// Distribution methods.
/**
@@ -485,15 +315,18 @@ public class EmpiricalDistribution extends AbstractRealDistribution
/**
* {@inheritDoc}
*
* <p>Returns the kernel density normalized so that its integral over each bin
* equals the bin mass.</p>
* Returns the kernel density normalized so that its integral over each bin
* equals the bin mass.
*
* Algorithm description:
* <ol>
* <li>Find the bin B that x belongs to.</li>
* <li>Compute K(B) = the mass of B with respect to the within-bin kernel (i.e., the
* integral of the kernel density over B).</li>
* <li>Return k(x) * P(B) / K(B), where k is the within-bin kernel density
* and P(B) is the mass of B.</li>
* </ol>
*
* <p>Algorithm description: <ol>
* <li>Find the bin B that x belongs to.</li>
* <li>Compute K(B) = the mass of B with respect to the within-bin kernel (i.e., the
* integral of the kernel density over B).</li>
* <li>Return k(x) * P(B) / K(B), where k is the within-bin kernel density
* and P(B) is the mass of B.</li></ol>
* @since 3.1
*/
@Override
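
Read as code, the three steps above amount to the sketch below; the parameters stand in for the class's private state (bin bounds, bin mass, within-bin kernel) and are illustrative assumptions, not the actual fields or method signature.

import org.apache.commons.statistics.distribution.ContinuousDistribution;

/** Sketch of the documented density algorithm; not the actual implementation. */
final class DensitySketch {
    /** Returns k(x) * P(B) / K(B) for the bin B assumed to contain x. */
    static double density(double x,
                          ContinuousDistribution kernel, // within-bin kernel of B (step 1 already done)
                          double binLo, double binHi,    // bounds of B
                          double pB) {                   // P(B), the mass of B
        // Step 2: K(B), the mass of B under the within-bin kernel.
        final double kB = kernel.cumulativeProbability(binHi)
                        - kernel.cumulativeProbability(binLo);
        // Step 3: k(x) * P(B) / K(B).
        return kernel.density(x) * pB / kB;
    }
}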
@@ -509,13 +342,15 @@ public class EmpiricalDistribution extends AbstractRealDistribution
/**
* {@inheritDoc}
*
* <p>Algorithm description:<ol>
* <li>Find the bin B that x belongs to.</li>
* <li>Compute P(B) = the mass of B and P(B-) = the combined mass of the bins below B.</li>
* <li>Compute K(B) = the probability mass of B with respect to the within-bin kernel
* and K(B-) = the kernel distribution evaluated at the lower endpoint of B</li>
* <li>Return P(B-) + P(B) * [K(x) - K(B-)] / K(B) where
* K(x) is the within-bin kernel distribution function evaluated at x.</li></ol>
* Algorithm description:
* <ol>
* <li>Find the bin B that x belongs to.</li>
* <li>Compute P(B) = the mass of B and P(B-) = the combined mass of the bins below B.</li>
* <li>Compute K(B) = the probability mass of B with respect to the within-bin kernel
* and K(B-) = the kernel distribution evaluated at the lower endpoint of B</li>
* <li>Return P(B-) + P(B) * [K(x) - K(B-)] / K(B) where
* K(x) is the within-bin kernel distribution function evaluated at x.</li>
* </ol>
* If K is a constant distribution, we return P(B-) + P(B) (counting the full
* mass of B).
*
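
A corresponding sketch of these steps, again with illustrative parameters standing in for the private state; the zero-mass guard is one way to honour the constant-kernel note above, not necessarily what the class does.

import org.apache.commons.statistics.distribution.ContinuousDistribution;

/** Sketch of the documented CDF algorithm; not the actual implementation. */
final class CdfSketch {
    /** Returns P(B-) + P(B) * [K(x) - K(B-)] / K(B) for the bin B assumed to contain x. */
    static double cumulativeProbability(double x,
                                        ContinuousDistribution kernel, // within-bin kernel of B
                                        double binLo, double binHi,    // bounds of B
                                        double pB, double pBelow) {    // P(B) and P(B-)
        final double kBelow = kernel.cumulativeProbability(binLo);      // K(B-)
        final double kB = kernel.cumulativeProbability(binHi) - kBelow; // K(B)
        if (kB == 0) {
            return pBelow + pB; // degenerate kernel: count the full mass of B
        }
        return pBelow + pB * (kernel.cumulativeProbability(x) - kBelow) / kB;
    }
}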
@@ -550,20 +385,24 @@ public class EmpiricalDistribution extends AbstractRealDistribution
/**
* {@inheritDoc}
*
* <p>Algorithm description:<ol>
* <li>Find the smallest i such that the sum of the masses of the bins
* through i is at least p.</li>
* <li>
* Let K be the within-bin kernel distribution for bin i.<br>
* Let K(B) be the mass of B under K. <br>
* Let K(B-) be K evaluated at the lower endpoint of B (the combined
* mass of the bins below B under K).<br>
* Let P(B) be the probability of bin i.<br>
* Let P(B-) be the sum of the bin masses below bin i. <br>
* Let pCrit = p - P(B-)<br>
* <li>Return the inverse of K evaluated at <br>
* Algorithm description:
* <ol>
* <li>Find the smallest i such that the sum of the masses of the bins
* through i is at least p.</li>
* <li>
* <ol>
* <li>Let K be the within-bin kernel distribution for bin i.</li>
* <li>Let K(B) be the mass of B under K.</li>
* <li>Let K(B-) be K evaluated at the lower endpoint of B (the combined
* mass of the bins below B under K).</li>
* <li>Let P(B) be the probability of bin i.</li>
* <li>Let P(B-) be the sum of the bin masses below bin i.</li>
* <li>Let pCrit = p - P(B-)</li>
* </ol>
* </li>
* <li>Return the inverse of K evaluated at
* K(B-) + pCrit * K(B) / P(B) </li>
* </ol>
* </ol>
*
* @since 3.1
*/
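
And the matching sketch for the inversion, with pCrit as named in the javadoc (parameters are again illustrative stand-ins for the private state).

import org.apache.commons.statistics.distribution.ContinuousDistribution;

/** Sketch of the documented quantile algorithm; not the actual implementation. */
final class InverseCdfSketch {
    /** Inverts the within-bin kernel K at K(B-) + pCrit * K(B) / P(B). */
    static double inverseCumulativeProbability(double p,
                                               ContinuousDistribution kernel, // kernel of bin i
                                               double binLo, double binHi,    // bounds of bin i
                                               double pB, double pBelow) {    // P(B) and P(B-)
        final double kBelow = kernel.cumulativeProbability(binLo);      // K(B-)
        final double kB = kernel.cumulativeProbability(binHi) - kBelow; // K(B)
        final double pCrit = p - pBelow;
        return kernel.inverseCumulativeProbability(kBelow + pCrit * kB / pB);
    }
}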
@@ -651,15 +490,6 @@ public class EmpiricalDistribution extends AbstractRealDistribution
return true;
}
/**{@inheritDoc} */
@Override
public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
if (!loaded) {
throw new MathIllegalStateException(LocalizedFormats.DISTRIBUTION_NOT_LOADED);
}
return super.createSampler(rng);
}
/**
* The probability of bin i.
*
@@ -717,19 +547,29 @@ public class EmpiricalDistribution extends AbstractRealDistribution
}
/**
* The within-bin smoothing kernel. Returns a Gaussian distribution
* parameterized by {@code bStats}, unless the bin contains only one
* observation, in which case a constant distribution is returned.
*
* @param bStats summary statistics for the bin
* @return within-bin kernel parameterized by bStats
* @param stats Bin statistics.
* @return the within-bin kernel.
*/
protected ContinuousDistribution getKernel(SummaryStatistics bStats) {
if (bStats.getN() <= 1 ||
bStats.getVariance() == 0) {
return new ConstantContinuousDistribution(bStats.getMean());
} else {
return new NormalDistribution(bStats.getMean(), bStats.getStandardDeviation());
}
private ContinuousDistribution getKernel(SummaryStatistics stats) {
return kernelFactory.apply(stats);
}
/**
* The within-bin smoothing kernel: A Gaussian distribution
* (unless the bin contains only one observation, in which case
* a constant distribution is returned).
*
* @return the within-bin kernel factory.
*/
private static Function<SummaryStatistics, ContinuousDistribution> defaultKernel() {
return stats -> {
if (stats.getN() <= 1 ||
stats.getVariance() == 0) {
return new ConstantContinuousDistribution(stats.getMean());
} else {
return new NormalDistribution(stats.getMean(),
stats.getStandardDeviation());
}
};
}
}
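
The kernelFactory hook above replaces the former protected getKernel override. Supplying a different within-bin kernel now looks like the sketch below; data and expected value are taken from the updated testKernelOverrideConstant, and the import paths follow those visible in this diff.

import java.util.function.Function;
import org.apache.commons.math4.legacy.distribution.EmpiricalDistribution;
import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
import org.apache.commons.statistics.distribution.ConstantContinuousDistribution;
import org.apache.commons.statistics.distribution.ContinuousDistribution;

public final class CustomKernelExample {
    public static void main(String[] args) {
        // Three values per bin, as in the test.
        final double[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};

        // Constant kernel located at each bin's mean, replacing the default Gaussian.
        final Function<SummaryStatistics, ContinuousDistribution> constantKernel =
            stats -> new ConstantContinuousDistribution(stats.getMean());

        final EmpiricalDistribution dist =
            EmpiricalDistribution.from(5, data, constantKernel);

        // The mass is now concentrated on the bin means 2, 5, 8, 11 and 14.
        System.out.println(dist.cumulativeProbability(2)); // 0.2, per the test expectations
    }
}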

Changed file: EmpiricalDistributionTest.java (org.apache.commons.math4.legacy.distribution)

@@ -17,7 +17,6 @@
package org.apache.commons.math4.legacy.distribution;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
@@ -35,8 +34,6 @@ import org.apache.commons.math4.legacy.TestUtils;
import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
import org.apache.commons.math4.legacy.analysis.integration.BaseAbstractUnivariateIntegrator;
import org.apache.commons.math4.legacy.analysis.integration.IterativeLegendreGaussIntegrator;
import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
import org.apache.commons.math4.legacy.core.jdkmath.AccurateMath;
@@ -48,32 +45,28 @@ import org.junit.Test;
* Test cases for the {@link EmpiricalDistribution} class.
*/
public final class EmpiricalDistributionTest extends RealDistributionAbstractTest {
protected EmpiricalDistribution empiricalDistribution = null;
protected EmpiricalDistribution empiricalDistribution2 = null;
protected File file = null;
protected URL url = null;
protected double[] dataArray = null;
protected final int n = 10000;
private EmpiricalDistribution empiricalDistribution = null;
private double[] dataArray = null;
private final int n = 10000;
/** Uniform bin mass = 10/10001 == mass of all but the first bin */
private final double binMass = 10d / (n + 1);
/** Mass of first bin = 11/10001 */
private final double firstBinMass = 11d / (n + 1);
@Override
@Before
public void setUp() {
super.setUp();
empiricalDistribution = new EmpiricalDistribution(100);
url = getClass().getResource("testData.txt");
final URL url = getClass().getResource("testData.txt");
final ArrayList<Double> list = new ArrayList<>();
try {
empiricalDistribution2 = new EmpiricalDistribution(100);
BufferedReader in =
new BufferedReader(new InputStreamReader(
url.openStream()));
final BufferedReader in = new BufferedReader(new InputStreamReader(url.openStream()));
String str = null;
while ((str = in.readLine()) != null) {
list.add(Double.valueOf(str));
}
in.close();
in = null;
} catch (IOException ex) {
Assert.fail("IOException " + ex);
}
@@ -84,193 +77,68 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
dataArray[i] = data.doubleValue();
i++;
}
empiricalDistribution = EmpiricalDistribution.from(100, dataArray);
}
// MATH-1279
@Test(expected=NotStrictlyPositiveException.class)
public void testPrecondition1() {
new EmpiricalDistribution(0);
EmpiricalDistribution.from(0, new double[] {1,2,3});
}
/**
* Test EmpiricalDistribution.load() using sample data file.<br>
* Check that the sampleCount, mu and sigma match data in
* the sample data file. Also verify that load is idempotent.
* Test using data taken from sample data file.
* Check that the sampleCount, mu and sigma match data in the sample data file.
*/
@Test
public void testLoad() throws Exception {
// Load from a URL
empiricalDistribution.load(url);
checkDistribution();
// Load again from a file (also verifies idempotency of load)
File file = new File(url.toURI());
empiricalDistribution.load(file);
checkDistribution();
}
private void checkDistribution() {
public void testDoubleLoad() {
// testData File has 10000 values, with mean ~ 5.0, std dev ~ 1
// Make sure that loaded distribution matches this
Assert.assertEquals(empiricalDistribution.getSampleStats().getN(),1000,10E-7);
Assert.assertEquals(empiricalDistribution.getSampleStats().getN(),
1000, 1e-7);
//TODO: replace with statistical tests
Assert.assertEquals(empiricalDistribution.getSampleStats().getMean(),
5.069831575018909,10E-7);
5.069831575018909, 1e-7);
Assert.assertEquals(empiricalDistribution.getSampleStats().getStandardDeviation(),
1.0173699343977738,10E-7);
}
1.0173699343977738, 1e-7);
/**
* Test EmpiricalDistribution.load(double[]) using data taken from
* sample data file.<br>
* Check that the sampleCount, mu and sigma match data in
* the sample data file.
*/
@Test
public void testDoubleLoad() throws Exception {
empiricalDistribution2.load(dataArray);
// testData File has 10000 values, with mean ~ 5.0, std dev ~ 1
// Make sure that loaded distribution matches this
Assert.assertEquals(empiricalDistribution2.getSampleStats().getN(),1000,10E-7);
//TODO: replace with statistical tests
Assert.assertEquals(empiricalDistribution2.getSampleStats().getMean(),
5.069831575018909,10E-7);
Assert.assertEquals(empiricalDistribution2.getSampleStats().getStandardDeviation(),
1.0173699343977738,10E-7);
double[] bounds = empiricalDistribution2.getGeneratorUpperBounds();
double[] bounds = empiricalDistribution.getGeneratorUpperBounds();
Assert.assertEquals(bounds.length, 100);
Assert.assertEquals(bounds[99], 1.0, 10e-12);
}
// MATH-1531
@Test
public void testMath1531() {
final EmpiricalDistribution inputDistribution = new EmpiricalDistribution(120);
inputDistribution.load(new double[] {
50.993456376721454,
49.455345691918055,
49.527276095295804,
50.017183448668845,
49.10508147470046,
49.813998274118696,
50.87195348756139,
50.419474110037,
50.63614906979689,
49.49694777179407,
50.71799078406067,
50.03192853759164,
49.915092423165994,
49.56895392597687,
51.034638001064934,
50.681227971275945,
50.43749845081759,
49.86513120270245,
50.21475262482965,
49.99202971042547,
50.02382189838519,
49.386888585302884,
49.45585010202781,
49.988009479855435,
49.8136712206123,
49.6715197127997,
50.1981278397565,
49.842297508010276,
49.62491227740015,
50.05101916097176,
48.834912763303926,
49.806787657848574,
49.478236106374695,
49.56648347371614,
49.95069238081982,
49.71845132077346,
50.6097468705947,
49.80724637775541,
49.90448813086025,
49.39641861662603,
50.434295712893714,
49.227176959566734,
49.541126466050905,
49.03416593170446,
49.11584328494423,
49.61387482435674,
49.92877857995328,
50.70638552955101,
50.60078208448842,
49.39326233277838,
49.21488424364095,
49.69503351015096,
50.13733214001718,
50.22084761458942,
51.09804435604931,
49.18559131120419,
49.52286371605357,
49.34804374996689,
49.6901827776375,
50.01316351359638,
48.7751460520373,
50.12961836291053,
49.9978419772511,
49.885658399408584,
49.673438879979834,
49.45565980965606,
50.429747484906564,
49.40129274804164,
50.13034614008073,
49.87685735146651,
50.12967905393557,
50.323560376181696,
49.83519233651367,
49.37333369733053,
49.70074301611427,
50.11626105774947,
50.28249500380083,
50.543354367136466,
50.05866241335002,
50.39516515672527,
49.4838561463057,
50.451757089234796,
50.31370674203726,
49.79063762614284,
50.19652349768548,
49.75881420748814,
49.98371855036422,
49.82171344472916,
48.810793204162415,
49.37040569084592,
50.050641186203976,
50.48360952263646,
50.86666450358076,
50.463268776129844,
50.137489751888666,
50.23823061444118,
49.881460479468004,
50.641174398764356,
49.09314136851421,
48.80877928574451,
50.46197084844826,
49.97691704141741,
49.99933997561926,
50.25692254481885,
49.52973451252715,
49.81229858420664,
48.996112655915994,
48.740531054814674,
50.026642633066416,
49.98696633604899,
49.61307159972952,
50.5115278979726,
50.75245152442404,
50.51807785445929,
49.60929671768147,
49.1079533564074,
49.65347196551866,
49.31684818724059,
50.4906368627049,
50.37483603684714});
inputDistribution.inverseCumulativeProbability(0.7166666666666669);
final double[] data = new double[] {
50.993456376721454, 49.455345691918055, 49.527276095295804, 50.017183448668845, 49.10508147470046,
49.813998274118696, 50.87195348756139, 50.419474110037, 50.63614906979689, 49.49694777179407,
50.71799078406067, 50.03192853759164, 49.915092423165994, 49.56895392597687, 51.034638001064934,
50.681227971275945, 50.43749845081759, 49.86513120270245, 50.21475262482965, 49.99202971042547,
50.02382189838519, 49.386888585302884, 49.45585010202781, 49.988009479855435, 49.8136712206123,
49.6715197127997, 50.1981278397565, 49.842297508010276, 49.62491227740015, 50.05101916097176,
48.834912763303926, 49.806787657848574, 49.478236106374695, 49.56648347371614, 49.95069238081982,
49.71845132077346, 50.6097468705947, 49.80724637775541, 49.90448813086025, 49.39641861662603,
50.434295712893714, 49.227176959566734, 49.541126466050905, 49.03416593170446, 49.11584328494423,
49.61387482435674, 49.92877857995328, 50.70638552955101, 50.60078208448842, 49.39326233277838,
49.21488424364095, 49.69503351015096, 50.13733214001718, 50.22084761458942, 51.09804435604931,
49.18559131120419, 49.52286371605357, 49.34804374996689, 49.6901827776375, 50.01316351359638,
48.7751460520373, 50.12961836291053, 49.9978419772511, 49.885658399408584, 49.673438879979834,
49.45565980965606, 50.429747484906564, 49.40129274804164, 50.13034614008073, 49.87685735146651,
50.12967905393557, 50.323560376181696, 49.83519233651367, 49.37333369733053, 49.70074301611427,
50.11626105774947, 50.28249500380083, 50.543354367136466, 50.05866241335002, 50.39516515672527,
49.4838561463057, 50.451757089234796, 50.31370674203726, 49.79063762614284, 50.19652349768548,
49.75881420748814, 49.98371855036422, 49.82171344472916, 48.810793204162415, 49.37040569084592,
50.050641186203976, 50.48360952263646, 50.86666450358076, 50.463268776129844, 50.137489751888666,
50.23823061444118, 49.881460479468004, 50.641174398764356, 49.09314136851421, 48.80877928574451,
50.46197084844826, 49.97691704141741, 49.99933997561926, 50.25692254481885, 49.52973451252715,
49.81229858420664, 48.996112655915994, 48.740531054814674, 50.026642633066416, 49.98696633604899,
49.61307159972952, 50.5115278979726, 50.75245152442404, 50.51807785445929, 49.60929671768147,
49.1079533564074, 49.65347196551866, 49.31684818724059, 50.4906368627049, 50.37483603684714
};
EmpiricalDistribution.from(120, data).inverseCumulativeProbability(0.7166666666666669);
}
/**
@@ -279,84 +147,42 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
* these tests will fail even if the code is working as designed.
*/
@Test
public void testNext() throws Exception {
tstGen(0.1);
tstDoubleGen(0.1);
}
/**
* Make sure exception thrown if sampling is attempted
* before loading empiricalDistribution.
*/
@Test
public void testNextFail1() {
try {
empiricalDistribution.createSampler(RandomSource.create(RandomSource.JDK)).sample();
Assert.fail("Expecting MathIllegalStateException");
} catch (MathIllegalStateException ex) {
// expected
}
}
/**
* Make sure exception thrown if sampling is attempted
* before loading empiricalDistribution.
*/
@Test
public void testNextFail2() {
try {
empiricalDistribution2.createSampler(RandomSource.create(RandomSource.JDK)).sample();
Assert.fail("Expecting MathIllegalStateException");
} catch (MathIllegalStateException ex) {
// expected
}
public void testNext() {
tstGen(empiricalDistribution,
0.1);
}
/**
* Make sure we can handle a grid size that is too fine
*/
@Test
public void testGridTooFine() throws Exception {
empiricalDistribution = new EmpiricalDistribution(1001);
tstGen(0.1);
empiricalDistribution2 = new EmpiricalDistribution(1001);
tstDoubleGen(0.1);
public void testGridTooFine() {
tstGen(EmpiricalDistribution.from(1001, dataArray),
0.1);
}
/**
* How about too fat?
*/
@Test
public void testGridTooFat() throws Exception {
empiricalDistribution = new EmpiricalDistribution(1);
tstGen(5); // ridiculous tolerance; but ridiculous grid size
public void testGridTooFat() {
tstGen(EmpiricalDistribution.from(1, dataArray),
5); // ridiculous tolerance; but ridiculous grid size
// really just checking to make sure we do not bomb
empiricalDistribution2 = new EmpiricalDistribution(1);
tstDoubleGen(5);
}
/**
* Test bin index overflow problem (BZ 36450)
*/
@Test
public void testBinIndexOverflow() throws Exception {
public void testBinIndexOverflow() {
double[] x = new double[] {9474.94326071674, 2080107.8865462579};
new EmpiricalDistribution().load(x);
EmpiricalDistribution.from(1000, x);
}
@Test(expected=NullArgumentException.class)
@Test(expected=NullPointerException.class)
public void testLoadNullDoubleArray() {
new EmpiricalDistribution().load((double[]) null);
}
@Test(expected=NullArgumentException.class)
public void testLoadNullURL() throws Exception {
new EmpiricalDistribution().load((URL) null);
}
@Test(expected=NullArgumentException.class)
public void testLoadNullFile() throws Exception {
new EmpiricalDistribution().load((File) null);
EmpiricalDistribution.from(1000, null);
}
/**
@@ -365,8 +191,7 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
@Test
public void testGetBinUpperBounds() {
double[] testData = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10};
EmpiricalDistribution dist = new EmpiricalDistribution(5);
dist.load(testData);
EmpiricalDistribution dist = EmpiricalDistribution.from(5, testData);
double[] expectedBinUpperBounds = {2, 4, 6, 8, 10};
double[] expectedGeneratorUpperBounds = {4d/13d, 7d/13d, 9d/13d, 11d/13d, 1};
double tol = 10E-12;
@@ -374,23 +199,22 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
TestUtils.assertEquals(expectedGeneratorUpperBounds, dist.getGeneratorUpperBounds(), tol);
}
private void verifySame(EmpiricalDistribution d1, EmpiricalDistribution d2) {
Assert.assertEquals(d1.isLoaded(), d2.isLoaded());
private void verifySame(EmpiricalDistribution d1,
EmpiricalDistribution d2) {
Assert.assertEquals(d1.getBinCount(), d2.getBinCount());
Assert.assertEquals(d1.getSampleStats(), d2.getSampleStats());
if (d1.isLoaded()) {
for (int i = 0; i < d1.getUpperBounds().length; i++) {
Assert.assertEquals(d1.getUpperBounds()[i], d2.getUpperBounds()[i], 0);
}
Assert.assertEquals(d1.getBinStats(), d2.getBinStats());
for (int i = 0; i < d1.getUpperBounds().length; i++) {
Assert.assertEquals(d1.getUpperBounds()[i], d2.getUpperBounds()[i], 0);
}
Assert.assertEquals(d1.getBinStats(), d2.getBinStats());
}
private void tstGen(double tolerance)throws Exception {
empiricalDistribution.load(url);
ContinuousDistribution.Sampler sampler
= empiricalDistribution.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 1000));
SummaryStatistics stats = new SummaryStatistics();
private void tstGen(EmpiricalDistribution dist,
double tolerance) {
final ContinuousDistribution.Sampler sampler
= dist.createSampler(RandomSource.WELL_19937_C.create(1000));
final SummaryStatistics stats = new SummaryStatistics();
for (int i = 1; i < 1000; i++) {
stats.addValue(sampler.sample());
}
@@ -398,38 +222,19 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
Assert.assertEquals("std dev", 1.0173699343977738, stats.getStandardDeviation(),tolerance);
}
private void tstDoubleGen(double tolerance)throws Exception {
empiricalDistribution2.load(dataArray);
ContinuousDistribution.Sampler sampler
= empiricalDistribution2.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 1000));
SummaryStatistics stats = new SummaryStatistics();
for (int i = 1; i < 1000; i++) {
stats.addValue(sampler.sample());
}
Assert.assertEquals("mean", 5.069831575018909, stats.getMean(), tolerance);
Assert.assertEquals("std dev", 1.0173699343977738, stats.getStandardDeviation(), tolerance);
}
// Setup for distribution tests
@Override
public ContinuousDistribution makeDistribution() {
// Create a uniform distribution on [0, 10,000]
// Create a uniform distribution on [0, 10,000].
final double[] sourceData = new double[n + 1];
for (int i = 0; i < n + 1; i++) {
sourceData[i] = i;
}
EmpiricalDistribution dist = new EmpiricalDistribution();
dist.load(sourceData);
EmpiricalDistribution dist = EmpiricalDistribution.from(1000, sourceData);
return dist;
}
/** Uniform bin mass = 10/10001 == mass of all but the first bin */
private final double binMass = 10d / (n + 1);
/** Mass of first bin = 11/10001 */
private final double firstBinMass = 11d / (n + 1);
@Override
public double[] makeCumulativeTestPoints() {
final double[] testPoints = new double[] {9, 10, 15, 1000, 5004, 9999};
@@ -529,10 +334,9 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
for (int i = 51; i < 100; i++) {
data[i] = 1 - 1 / (100 - (double) i + 2);
}
EmpiricalDistribution dist = new EmpiricalDistribution(10);
dist.load(data);
EmpiricalDistribution dist = EmpiricalDistribution.from(10, data);
ContinuousDistribution.Sampler sampler
= dist.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 1000));
= dist.createSampler(RandomSource.WELL_19937_C.create(1000));
for (int i = 0; i < 1000; i++) {
final double dev = sampler.sample();
Assert.assertTrue(dev < 1);
@@ -546,10 +350,9 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
@Test
public void testNoBinVariance() {
final double[] data = {0, 0, 1, 1};
EmpiricalDistribution dist = new EmpiricalDistribution(2);
dist.load(data);
EmpiricalDistribution dist = EmpiricalDistribution.from(2, data);
ContinuousDistribution.Sampler sampler
= dist.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 1000));
= dist.createSampler(RandomSource.WELL_19937_C.create(1000));
for (int i = 0; i < 1000; i++) {
final double dev = sampler.sample();
Assert.assertTrue(dev == 0 || dev == 1);
@@ -588,17 +391,18 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
@Test
public void testKernelOverrideConstant() {
final EmpiricalDistribution dist = new ConstantKernelEmpiricalDistribution(5);
final double[] data = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d};
dist.load(data);
ContinuousDistribution.Sampler sampler
= dist.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 1000));
final double[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
final EmpiricalDistribution dist =
EmpiricalDistribution.from(5, data,
s -> new ConstantContinuousDistribution(s.getMean()));
final ContinuousDistribution.Sampler sampler
= dist.createSampler(RandomSource.WELL_19937_C.create(1000));
// Bin masses concentrated on 2, 5, 8, 11, 14 <- effectively discrete uniform distribution over these
double[] values = {2d, 5d, 8d, 11d, 14d};
final double[] values = {2d, 5d, 8d, 11d, 14d};
for (int i = 0; i < 20; i++) {
Assert.assertTrue(Arrays.binarySearch(values, sampler.sample()) >= 0);
}
final double tol = 10E-12;
final double tol = 1e-12;
Assert.assertEquals(0.0, dist.cumulativeProbability(1), tol);
Assert.assertEquals(0.2, dist.cumulativeProbability(2), tol);
Assert.assertEquals(0.6, dist.cumulativeProbability(10), tol);
@@ -616,11 +420,12 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
@Test
public void testKernelOverrideUniform() {
final EmpiricalDistribution dist = new UniformKernelEmpiricalDistribution(5);
final double[] data = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d};
dist.load(data);
ContinuousDistribution.Sampler sampler
= dist.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 1000));
final double[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
final EmpiricalDistribution dist =
EmpiricalDistribution.from(5, data,
s -> new UniformContinuousDistribution(s.getMin(), s.getMax()));
final ContinuousDistribution.Sampler sampler
= dist.createSampler(RandomSource.WELL_19937_C.create(1000));
// Kernels are uniform distributions on [1,3], [4,6], [7,9], [10,12], [13,15]
final double bounds[] = {3d, 6d, 9d, 12d};
final double tol = 10E-12;
@@ -648,7 +453,7 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
@Test
public void testMath1431() {
final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_19937_C, 1000);
final UniformRandomProvider rng = RandomSource.WELL_19937_C.create(1000);
final ContinuousDistribution.Sampler exponentialDistributionSampler
= new ExponentialDistribution(0.05).createSampler(rng);
final double[] empiricalDataPoints = new double[3000];
@@ -656,8 +461,7 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
empiricalDataPoints[i] = exponentialDistributionSampler.sample();
}
final EmpiricalDistribution testDistribution = new EmpiricalDistribution(100);
testDistribution.load(empiricalDataPoints);
final EmpiricalDistribution testDistribution = EmpiricalDistribution.from(100, empiricalDataPoints);
for (int i = 0; i < 1000; i++) {
final double point = rng.nextDouble();
@@ -684,8 +488,7 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
6461.3944, 6384.1345
};
final EmpiricalDistribution ed = new EmpiricalDistribution(data.length / 10);
ed.load(data);
final EmpiricalDistribution ed = EmpiricalDistribution.from(data.length / 10, data);
double v;
double p;
@@ -702,33 +505,4 @@ public final class EmpiricalDistributionTest extends RealDistributionAbstractTes
v = ed.inverseCumulativeProbability(p);
Assert.assertTrue("p=" + p + " => v=" + v, v < 6350);
}
/**
* Empirical distribution using a constant smoothing kernel.
*/
private class ConstantKernelEmpiricalDistribution extends EmpiricalDistribution {
private static final long serialVersionUID = 1L;
ConstantKernelEmpiricalDistribution(int i) {
super(i);
}
// Use constant distribution equal to bin mean within bin
@Override
protected ContinuousDistribution getKernel(SummaryStatistics bStats) {
return new ConstantContinuousDistribution(bStats.getMean());
}
}
/**
* Empirical distribution using a uniform smoothing kernel.
*/
private class UniformKernelEmpiricalDistribution extends EmpiricalDistribution {
private static final long serialVersionUID = 2963149194515159653L;
UniformKernelEmpiricalDistribution(int i) {
super(i);
}
@Override
protected ContinuousDistribution getKernel(SummaryStatistics bStats) {
return new UniformContinuousDistribution(bStats.getMin(), bStats.getMax());
}
}
}