MATH-1158

Adapt "examples" code to the new sampler API.
This commit is contained in:
Gilles 2016-03-26 13:00:08 +01:00
parent b577805347
commit e366894658
3 changed files with 66 additions and 43 deletions

View File

@ -34,6 +34,8 @@ import java.util.List;
import javax.swing.JComponent; import javax.swing.JComponent;
import javax.swing.JLabel; import javax.swing.JLabel;
import org.apache.commons.math4.distribution.RealDistribution;
import org.apache.commons.math4.distribution.UniformRealDistribution;
import org.apache.commons.math4.distribution.NormalDistribution; import org.apache.commons.math4.distribution.NormalDistribution;
import org.apache.commons.math4.geometry.euclidean.twod.Vector2D; import org.apache.commons.math4.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math4.ml.clustering.CentroidCluster; import org.apache.commons.math4.ml.clustering.CentroidCluster;
@ -45,9 +47,9 @@ import org.apache.commons.math4.ml.clustering.DoublePoint;
import org.apache.commons.math4.ml.clustering.FuzzyKMeansClusterer; import org.apache.commons.math4.ml.clustering.FuzzyKMeansClusterer;
import org.apache.commons.math4.ml.clustering.KMeansPlusPlusClusterer; import org.apache.commons.math4.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math4.random.RandomAdaptor; import org.apache.commons.math4.random.RandomAdaptor;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.random.SobolSequenceGenerator; import org.apache.commons.math4.random.SobolSequenceGenerator;
import org.apache.commons.math4.random.Well19937c; import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.rng.RandomSource;
import org.apache.commons.math4.util.FastMath; import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.Pair; import org.apache.commons.math4.util.Pair;
import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame; import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame;
@ -59,12 +61,16 @@ import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame;
*/ */
public class ClusterAlgorithmComparison { public class ClusterAlgorithmComparison {
public static List<Vector2D> makeCircles(int samples, boolean shuffle, double noise, double factor, final RandomGenerator random) { public static List<Vector2D> makeCircles(int samples,
boolean shuffle,
double noise,
double factor,
UniformRandomProvider rng) {
if (factor < 0 || factor > 1) { if (factor < 0 || factor > 1) {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
} }
NormalDistribution dist = new NormalDistribution(random, 0.0, noise, 1e-9); RealDistribution.Sampler dist = new NormalDistribution(0.0, noise).createSampler(rng);
List<Vector2D> points = new ArrayList<Vector2D>(); List<Vector2D> points = new ArrayList<Vector2D>();
double range = 2.0 * FastMath.PI; double range = 2.0 * FastMath.PI;
@ -78,14 +84,18 @@ public class ClusterAlgorithmComparison {
} }
if (shuffle) { if (shuffle) {
Collections.shuffle(points, new RandomAdaptor(random)); // Collections.shuffle(points, new RandomAdaptor(rng)); // XXX TODO
Collections.shuffle(points); // XXX temporary workaround
} }
return points; return points;
} }
public static List<Vector2D> makeMoons(int samples, boolean shuffle, double noise, RandomGenerator random) { public static List<Vector2D> makeMoons(int samples,
NormalDistribution dist = new NormalDistribution(random, 0.0, noise, 1e-9); boolean shuffle,
double noise,
UniformRandomProvider rng) {
RealDistribution.Sampler dist = new NormalDistribution(0.0, noise).createSampler(rng);
int nSamplesOut = samples / 2; int nSamplesOut = samples / 2;
int nSamplesIn = samples - nSamplesOut; int nSamplesIn = samples - nSamplesOut;
@ -105,23 +115,26 @@ public class ClusterAlgorithmComparison {
} }
if (shuffle) { if (shuffle) {
Collections.shuffle(points, new RandomAdaptor(random)); // Collections.shuffle(points, new RandomAdaptor(rng)); // XXX TODO
Collections.shuffle(points); // XXX temporary workaround
} }
return points; return points;
} }
public static List<Vector2D> makeBlobs(int samples, int centers, double clusterStd, public static List<Vector2D> makeBlobs(int samples,
double min, double max, boolean shuffle, RandomGenerator random) { int centers,
double clusterStd,
double min,
double max,
boolean shuffle,
UniformRandomProvider rng) {
RealDistribution.Sampler uniform = new UniformRealDistribution(min, max).createSampler(rng);
RealDistribution.Sampler gauss = new NormalDistribution(0.0, clusterStd).createSampler(rng);
NormalDistribution dist = new NormalDistribution(random, 0.0, clusterStd, 1e-9);
double range = max - min;
Vector2D[] centerPoints = new Vector2D[centers]; Vector2D[] centerPoints = new Vector2D[centers];
for (int i = 0; i < centers; i++) { for (int i = 0; i < centers; i++) {
double x = random.nextDouble() * range + min; centerPoints[i] = new Vector2D(uniform.sample(), uniform.sample());
double y = random.nextDouble() * range + min;
centerPoints[i] = new Vector2D(x, y);
} }
int[] nSamplesPerCenter = new int[centers]; int[] nSamplesPerCenter = new int[centers];
@ -135,13 +148,13 @@ public class ClusterAlgorithmComparison {
List<Vector2D> points = new ArrayList<Vector2D>(); List<Vector2D> points = new ArrayList<Vector2D>();
for (int i = 0; i < centers; i++) { for (int i = 0; i < centers; i++) {
for (int j = 0; j < nSamplesPerCenter[i]; j++) { for (int j = 0; j < nSamplesPerCenter[i]; j++) {
Vector2D point = new Vector2D(dist.sample(), dist.sample()); points.add(centerPoints[i].add(generateNoiseVector(gauss)));
points.add(point.add(centerPoints[i]));
} }
} }
if (shuffle) { if (shuffle) {
Collections.shuffle(points, new RandomAdaptor(random)); // Collections.shuffle(points, new RandomAdaptor(rng)); // XXX TODO
Collections.shuffle(points); // XXX temporary workaround
} }
return points; return points;
@ -162,11 +175,15 @@ public class ClusterAlgorithmComparison {
return points; return points;
} }
/**
 * Creates a 2D vector whose coordinates are two successive draws from the
 * given sample source: the first draw is passed as the first constructor
 * argument of {@code Vector2D}, the second draw as the second.
 *
 * @param distribution source of the samples; {@code sample()} is called on it twice
 * @return a new {@code Vector2D} built from the two sampled values
 */
public static Vector2D generateNoiseVector(NormalDistribution distribution) { public static Vector2D generateNoiseVector(RealDistribution.Sampler distribution) {
return new Vector2D(distribution.sample(), distribution.sample()); return new Vector2D(distribution.sample(), distribution.sample());
} }
public static List<DoublePoint> normalize(final List<Vector2D> input, double minX, double maxX, double minY, double maxY) { public static List<DoublePoint> normalize(final List<Vector2D> input,
double minX,
double maxX,
double minY,
double maxY) {
double rangeX = maxX - minX; double rangeX = maxX - minX;
double rangeY = maxY - minY; double rangeY = maxY - minY;
List<DoublePoint> points = new ArrayList<DoublePoint>(); List<DoublePoint> points = new ArrayList<DoublePoint>();
@ -190,7 +207,7 @@ public class ClusterAlgorithmComparison {
int nSamples = 1500; int nSamples = 1500;
RandomGenerator rng = new Well19937c(0); UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_19937_C, 0);
List<List<DoublePoint>> datasets = new ArrayList<List<DoublePoint>>(); List<List<DoublePoint>> datasets = new ArrayList<List<DoublePoint>>();
datasets.add(normalize(makeCircles(nSamples, true, 0.04, 0.5, rng), -1, 1, -1, 1)); datasets.add(normalize(makeCircles(nSamples, true, 0.04, 0.5, rng), -1, 1, -1, 1));
@ -200,11 +217,16 @@ public class ClusterAlgorithmComparison {
List<Pair<String, Clusterer<DoublePoint>>> algorithms = new ArrayList<Pair<String, Clusterer<DoublePoint>>>(); List<Pair<String, Clusterer<DoublePoint>>> algorithms = new ArrayList<Pair<String, Clusterer<DoublePoint>>>();
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=2)", new KMeansPlusPlusClusterer<DoublePoint>(2))); algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=2)",
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=3)", new KMeansPlusPlusClusterer<DoublePoint>(3))); new KMeansPlusPlusClusterer<DoublePoint>(2)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=2)", new FuzzyKMeansClusterer<DoublePoint>(3, 2))); algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=3)",
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=10)", new FuzzyKMeansClusterer<DoublePoint>(3, 10))); new KMeansPlusPlusClusterer<DoublePoint>(3)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("DBSCAN\n(eps=.1, min=3)", new DBSCANClusterer<DoublePoint>(0.1, 3))); algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=2)",
new FuzzyKMeansClusterer<DoublePoint>(3, 2)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=10)",
new FuzzyKMeansClusterer<DoublePoint>(3, 10)));
algorithms.add(new Pair<String, Clusterer<DoublePoint>>("DBSCAN\n(eps=.1, min=3)",
new DBSCANClusterer<DoublePoint>(0.1, 3)));
GridBagConstraints c = new GridBagConstraints(); GridBagConstraints c = new GridBagConstraints();
c.fill = GridBagConstraints.VERTICAL; c.fill = GridBagConstraints.VERTICAL;

View File

@ -46,8 +46,6 @@ import org.apache.commons.math4.distribution.ParetoDistribution;
import org.apache.commons.math4.distribution.RealDistribution; import org.apache.commons.math4.distribution.RealDistribution;
import org.apache.commons.math4.distribution.TDistribution; import org.apache.commons.math4.distribution.TDistribution;
import org.apache.commons.math4.distribution.WeibullDistribution; import org.apache.commons.math4.distribution.WeibullDistribution;
import org.apache.commons.math4.random.MersenneTwister;
import org.apache.commons.math4.random.RandomGenerator;
import org.apache.commons.math4.util.FastMath; import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame; import org.apache.commons.math4.userguide.ExampleUtils.ExampleFrame;
@ -242,14 +240,13 @@ public class RealDistributionComparison {
container.add(comp, c); container.add(comp, c);
c.gridx++; c.gridx++;
RandomGenerator rng = new MersenneTwister(0);
comp = createComponent("Levy", 0, 3, comp = createComponent("Levy", 0, 3,
new String[] { "c=0.5", "c=1", "c=2", "c=4", "c=8" }, new String[] { "c=0.5", "c=1", "c=2", "c=4", "c=8" },
new LevyDistribution(rng, 0, 0.5), new LevyDistribution(0, 0.5),
new LevyDistribution(rng, 0, 1), new LevyDistribution(0, 1),
new LevyDistribution(rng, 0, 2), new LevyDistribution(0, 2),
new LevyDistribution(rng, 0, 4), new LevyDistribution(0, 4),
new LevyDistribution(rng, 0, 8)); new LevyDistribution(0, 8));
container.add(comp, c); container.add(comp, c);
c.gridy++; c.gridy++;

View File

@ -20,6 +20,8 @@ package org.apache.commons.math4.userguide.sofm;
import org.apache.commons.math4.geometry.euclidean.threed.Vector3D; import org.apache.commons.math4.geometry.euclidean.threed.Vector3D;
import org.apache.commons.math4.geometry.euclidean.threed.Rotation; import org.apache.commons.math4.geometry.euclidean.threed.Rotation;
import org.apache.commons.math4.random.UnitSphereRandomVectorGenerator; import org.apache.commons.math4.random.UnitSphereRandomVectorGenerator;
import org.apache.commons.math4.rng.UniformRandomProvider;
import org.apache.commons.math4.rng.RandomSource;
import org.apache.commons.math4.distribution.RealDistribution; import org.apache.commons.math4.distribution.RealDistribution;
import org.apache.commons.math4.distribution.UniformRealDistribution; import org.apache.commons.math4.distribution.UniformRealDistribution;
@ -57,11 +59,13 @@ public class ChineseRings {
final UnitSphereRandomVectorGenerator unit final UnitSphereRandomVectorGenerator unit
= new UnitSphereRandomVectorGenerator(2); = new UnitSphereRandomVectorGenerator(2);
final RealDistribution radius1 final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_19937_C);
final RealDistribution.Sampler radius1
= new UniformRealDistribution(radiusRing1 - halfWidthRing1, = new UniformRealDistribution(radiusRing1 - halfWidthRing1,
radiusRing1 + halfWidthRing1); radiusRing1 + halfWidthRing1).createSampler(rng);
final RealDistribution widthRing1 final RealDistribution.Sampler widthRing1
= new UniformRealDistribution(-halfWidthRing1, halfWidthRing1); = new UniformRealDistribution(-halfWidthRing1, halfWidthRing1).createSampler(rng);
for (int i = 0; i < numPointsRing1; i++) { for (int i = 0; i < numPointsRing1; i++) {
final double[] v = unit.nextVector(); final double[] v = unit.nextVector();
@ -72,11 +76,11 @@ public class ChineseRings {
widthRing1.sample()); widthRing1.sample());
} }
final RealDistribution radius2 final RealDistribution.Sampler radius2
= new UniformRealDistribution(radiusRing2 - halfWidthRing2, = new UniformRealDistribution(radiusRing2 - halfWidthRing2,
radiusRing2 + halfWidthRing2); radiusRing2 + halfWidthRing2).createSampler(rng);
final RealDistribution widthRing2 final RealDistribution.Sampler widthRing2
= new UniformRealDistribution(-halfWidthRing2, halfWidthRing2); = new UniformRealDistribution(-halfWidthRing2, halfWidthRing2).createSampler(rng);
for (int i = 0; i < numPointsRing2; i++) { for (int i = 0; i < numPointsRing2; i++) {
final double[] v = unit.nextVector(); final double[] v = unit.nextVector();