MATH-1158

Adapt "examples" code to the new sampler API.
This commit is contained in:
Gilles 2016-03-26 13:00:08 +01:00
parent b577805347
commit e366894658
3 changed files with 66 additions and 43 deletions

View File

@ -34,6 +34,8 @@ import java.util.List;
import javax.swing.JComponent;
import javax.swing.JLabel;
import org.apache.commons.math4.distribution.RealDistribution;
import org.apache.commons.math4.distribution.UniformRealDistribution;
import org.apache.commons.math4.distribution.NormalDistribution;
import org.apache.commons.math4.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math4.ml.clustering.CentroidCluster;
@ -45,9 +47,9 @@ import org.apache.commons.math4.ml.clustering.DoublePoint;
import org.apache.commons.math4.ml.clustering.FuzzyKMeansClusterer;
import org.apache.commons.math4.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math4.random.RandomAdaptor;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.random.SobolSequenceGenerator;
import org.apache.commons.math4.random.Well19937c;
import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.rng.RandomSource;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.Pair;
import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame;
@ -59,12 +61,16 @@ import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame;
*/
public class ClusterAlgorithmComparison {
public static List<Vector2D> makeCircles(int samples, boolean shuffle, double noise, double factor, final RandomGenerator random) {
public static List<Vector2D> makeCircles(int samples,
boolean shuffle,
double noise,
double factor,
UniformRandomProvider rng) {
if (factor < 0 || factor > 1) {
throw new IllegalArgumentException();
}
NormalDistribution dist = new NormalDistribution(random, 0.0, noise, 1e-9);
RealDistribution.Sampler dist = new NormalDistribution(0.0, noise).createSampler(rng);
List<Vector2D> points = new ArrayList<Vector2D>();
double range = 2.0 * FastMath.PI;
@ -78,14 +84,18 @@ public class ClusterAlgorithmComparison {
}
if (shuffle) {
Collections.shuffle(points, new RandomAdaptor(random));
// Collections.shuffle(points, new RandomAdaptor(rng)); // XXX TODO
Collections.shuffle(points); // XXX temporary workaround
}
return points;
}
public static List<Vector2D> makeMoons(int samples, boolean shuffle, double noise, RandomGenerator random) {
NormalDistribution dist = new NormalDistribution(random, 0.0, noise, 1e-9);
public static List<Vector2D> makeMoons(int samples,
boolean shuffle,
double noise,
UniformRandomProvider rng) {
RealDistribution.Sampler dist = new NormalDistribution(0.0, noise).createSampler(rng);
int nSamplesOut = samples / 2;
int nSamplesIn = samples - nSamplesOut;
@ -105,23 +115,26 @@ public class ClusterAlgorithmComparison {
}
if (shuffle) {
Collections.shuffle(points, new RandomAdaptor(random));
// Collections.shuffle(points, new RandomAdaptor(rng)); // XXX TODO
Collections.shuffle(points); // XXX temporary workaround
}
return points;
}
public static List<Vector2D> makeBlobs(int samples, int centers, double clusterStd,
double min, double max, boolean shuffle, RandomGenerator random) {
public static List<Vector2D> makeBlobs(int samples,
int centers,
double clusterStd,
double min,
double max,
boolean shuffle,
UniformRandomProvider rng) {
RealDistribution.Sampler uniform = new UniformRealDistribution(min, max).createSampler(rng);
RealDistribution.Sampler gauss = new NormalDistribution(0.0, clusterStd).createSampler(rng);
NormalDistribution dist = new NormalDistribution(random, 0.0, clusterStd, 1e-9);
double range = max - min;
Vector2D[] centerPoints = new Vector2D[centers];
for (int i = 0; i < centers; i++) {
double x = random.nextDouble() * range + min;
double y = random.nextDouble() * range + min;
centerPoints[i] = new Vector2D(x, y);
centerPoints[i] = new Vector2D(uniform.sample(), uniform.sample());
}
int[] nSamplesPerCenter = new int[centers];
@ -135,13 +148,13 @@ public class ClusterAlgorithmComparison {
List<Vector2D> points = new ArrayList<Vector2D>();
for (int i = 0; i < centers; i++) {
for (int j = 0; j < nSamplesPerCenter[i]; j++) {
Vector2D point = new Vector2D(dist.sample(), dist.sample());
points.add(point.add(centerPoints[i]));
points.add(centerPoints[i].add(generateNoiseVector(gauss)));
}
}
if (shuffle) {
Collections.shuffle(points, new RandomAdaptor(random));
// Collections.shuffle(points, new RandomAdaptor(rng)); // XXX TODO
Collections.shuffle(points); // XXX temporary workaround
}
return points;
@ -162,11 +175,15 @@ public class ClusterAlgorithmComparison {
return points;
}
/**
 * Creates a 2D vector whose two coordinates are obtained from two
 * successive calls to {@code sample()} on the given argument.
 *
 * @param distribution source of the random values; it is sampled twice,
 * first for the first coordinate and then for the second one.
 * @return a new {@code Vector2D} built from the two sampled values.
 */
public static Vector2D generateNoiseVector(NormalDistribution distribution) {
public static Vector2D generateNoiseVector(RealDistribution.Sampler distribution) {
return new Vector2D(distribution.sample(), distribution.sample());
}
public static List<DoublePoint> normalize(final List<Vector2D> input, double minX, double maxX, double minY, double maxY) {
public static List<DoublePoint> normalize(final List<Vector2D> input,
double minX,
double maxX,
double minY,
double maxY) {
double rangeX = maxX - minX;
double rangeY = maxY - minY;
List<DoublePoint> points = new ArrayList<DoublePoint>();
@ -190,7 +207,7 @@ public class ClusterAlgorithmComparison {
int nSamples = 1500;
RandomGenerator rng = new Well19937c(0);
UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_19937_C, 0);
List<List<DoublePoint>> datasets = new ArrayList<List<DoublePoint>>();
datasets.add(normalize(makeCircles(nSamples, true, 0.04, 0.5, rng), -1, 1, -1, 1));
@ -200,11 +217,16 @@ public class ClusterAlgorithmComparison {
List<Pair<String, Clusterer<DoublePoint>>> algorithms = new ArrayList<Pair<String, Clusterer<DoublePoint>>>();
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=2)", new KMeansPlusPlusClusterer<DoublePoint>(2)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=3)", new KMeansPlusPlusClusterer<DoublePoint>(3)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=2)", new FuzzyKMeansClusterer<DoublePoint>(3, 2)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=10)", new FuzzyKMeansClusterer<DoublePoint>(3, 10)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("DBSCAN\n(eps=.1, min=3)", new DBSCANClusterer<DoublePoint>(0.1, 3)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=2)",
new KMeansPlusPlusClusterer<DoublePoint>(2)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=3)",
new KMeansPlusPlusClusterer<DoublePoint>(3)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=2)",
new FuzzyKMeansClusterer<DoublePoint>(3, 2)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=10)",
new FuzzyKMeansClusterer<DoublePoint>(3, 10)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("DBSCAN\n(eps=.1, min=3)",
new DBSCANClusterer<DoublePoint>(0.1, 3)));
GridBagConstraints c = new GridBagConstraints();
c.fill = GridBagConstraints.VERTICAL;

View File

@ -46,8 +46,6 @@ import org.apache.commons.math4.distribution.ParetoDistribution;
import org.apache.commons.math4.distribution.RealDistribution;
import org.apache.commons.math4.distribution.TDistribution;
import org.apache.commons.math4.distribution.WeibullDistribution;
import org.apache.commons.math4.random.MersenneTwister;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame;
@ -242,14 +240,13 @@ public class RealDistributionComparison {
container.add(comp, c);
c.gridx++;
RandomGenerator rng = new MersenneTwister(0);
comp = createComponent("Levy", 0, 3,
new String[] { "c=0.5", "c=1", "c=2", "c=4", "c=8" },
new LevyDistribution(rng, 0, 0.5),
new LevyDistribution(rng, 0, 1),
new LevyDistribution(rng, 0, 2),
new LevyDistribution(rng, 0, 4),
new LevyDistribution(rng, 0, 8));
new LevyDistribution(0, 0.5),
new LevyDistribution(0, 1),
new LevyDistribution(0, 2),
new LevyDistribution(0, 4),
new LevyDistribution(0, 8));
container.add(comp, c);
c.gridy++;

View File

@ -20,6 +20,8 @@ package org.apache.commons.math4.userguide.sofm;
import org.apache.commons.math4.geometry.euclidean.threed.Vector3D;
import org.apache.commons.math4.geometry.euclidean.threed.Rotation;
import org.apache.commons.math4.random.UnitSphereRandomVectorGenerator;
import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.rng.RandomSource;
import org.apache.commons.math4.distribution.RealDistribution;
import org.apache.commons.math4.distribution.UniformRealDistribution;
@ -57,11 +59,13 @@ public class ChineseRings {
final UnitSphereRandomVectorGenerator unit
= new UnitSphereRandomVectorGenerator(2);
final RealDistribution radius1
final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_19937_C);
final RealDistribution.Sampler radius1
= new UniformRealDistribution(radiusRing1 - halfWidthRing1,
radiusRing1 + halfWidthRing1);
final RealDistribution widthRing1
= new UniformRealDistribution(-halfWidthRing1, halfWidthRing1);
radiusRing1 + halfWidthRing1).createSampler(rng);
final RealDistribution.Sampler widthRing1
= new UniformRealDistribution(-halfWidthRing1, halfWidthRing1).createSampler(rng);
for (int i = 0; i < numPointsRing1; i++) {
final double[] v = unit.nextVector();
@ -72,11 +76,11 @@ public class ChineseRings {
widthRing1.sample());
}
final RealDistribution radius2
final RealDistribution.Sampler radius2
= new UniformRealDistribution(radiusRing2 - halfWidthRing2,
radiusRing2 + halfWidthRing2);
final RealDistribution widthRing2
= new UniformRealDistribution(-halfWidthRing2, halfWidthRing2);
radiusRing2 + halfWidthRing2).createSampler(rng);
final RealDistribution.Sampler widthRing2
= new UniformRealDistribution(-halfWidthRing2, halfWidthRing2).createSampler(rng);
for (int i = 0; i < numPointsRing2; i++) {
final double[] v = unit.nextVector();