From 4be687e19dfddd3de721003e0514b07386549908 Mon Sep 17 00:00:00 2001 From: Thomas Neidhart Date: Tue, 21 May 2013 17:34:50 +0000 Subject: [PATCH] Added first sample code to be included in the userguide. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1484885 13f79535-47bb-0310-9956-ffa450edef68 --- .../math3/userguide/ClusteringExamples.java | 314 ++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 src/test/java/org/apache/commons/math3/userguide/ClusteringExamples.java diff --git a/src/test/java/org/apache/commons/math3/userguide/ClusteringExamples.java b/src/test/java/org/apache/commons/math3/userguide/ClusteringExamples.java new file mode 100644 index 000000000..409d25e8d --- /dev/null +++ b/src/test/java/org/apache/commons/math3/userguide/ClusteringExamples.java @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.userguide; + +import java.awt.Color; +import java.awt.Dimension; +import java.awt.Graphics; +import java.awt.Graphics2D; +import java.awt.GridBagConstraints; +import java.awt.GridBagLayout; +import java.awt.Insets; +import java.awt.RenderingHints; +import java.awt.Shape; +import java.awt.geom.Ellipse2D; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import javax.swing.JComponent; +import javax.swing.JFrame; +import javax.swing.JTextArea; +import javax.swing.SwingUtilities; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.geometry.euclidean.twod.Vector2D; +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.commons.math3.ml.clustering.Clusterer; +import org.apache.commons.math3.ml.clustering.DBSCANClusterer; +import org.apache.commons.math3.ml.clustering.DoublePoint; +import org.apache.commons.math3.ml.clustering.FuzzyKMeansClusterer; +import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer; +import org.apache.commons.math3.random.RandomAdaptor; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.SobolSequenceGenerator; +import org.apache.commons.math3.random.Well19937c; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Pair; + +/** + * Plots clustering results for various algorithms and datasets. + * Based on + * scikit learn. + */ +public class ClusteringExamples { + + public static List makeCircles(int samples, boolean shuffle, double noise, double factor, final RandomGenerator random) { + if (factor < 0 || factor > 1) { + throw new IllegalArgumentException(); + } + + NormalDistribution dist = new NormalDistribution(random, 0.0, noise, 1e-9); + + List points = new ArrayList(); + double range = 2.0 * FastMath.PI; + double step = range / (samples / 2.0 + 1); + for (double angle = 0; angle < range; angle += step) { + Vector2D outerCircle = new Vector2D(FastMath.cos(angle), FastMath.sin(angle)); + Vector2D innerCircle = outerCircle.scalarMultiply(factor); + + points.add(outerCircle.add(generateNoiseVector(dist))); + points.add(innerCircle.add(generateNoiseVector(dist))); + } + + if (shuffle) { + Collections.shuffle(points, new RandomAdaptor(random)); + } + + return points; + } + + public static List makeMoons(int samples, boolean shuffle, double noise, RandomGenerator random) { + NormalDistribution dist = new NormalDistribution(random, 0.0, noise, 1e-9); + + int nSamplesOut = samples / 2; + int nSamplesIn = samples - nSamplesOut; + + List points = new ArrayList(); + double range = FastMath.PI; + double step = range / (nSamplesOut / 2.0); + for (double angle = 0; angle < range; angle += step) { + Vector2D outerCircle = new Vector2D(FastMath.cos(angle), FastMath.sin(angle)); + points.add(outerCircle.add(generateNoiseVector(dist))); + } + + step = range / (nSamplesIn / 2.0); + for (double angle = 0; angle < range; angle += step) { + Vector2D innerCircle = new Vector2D(1 - FastMath.cos(angle), 1 - FastMath.sin(angle) - 0.5); + points.add(innerCircle.add(generateNoiseVector(dist))); + } + + if (shuffle) { + Collections.shuffle(points, new RandomAdaptor(random)); + } + + return points; + } + + public static List makeBlobs(int samples, int centers, double clusterStd, + double min, double max, boolean shuffle, RandomGenerator random) { + + NormalDistribution dist = new NormalDistribution(random, 0.0, clusterStd, 1e-9); + + double range = max - min; + Vector2D[] centerPoints = new Vector2D[centers]; + for (int i = 0; i < centers; i++) { + double x = random.nextDouble() * range + min; + double y = random.nextDouble() * range + min; + centerPoints[i] = new Vector2D(x, y); + } + + int[] nSamplesPerCenter = new int[centers]; + int count = (samples / centers) * centers; + Arrays.fill(nSamplesPerCenter, count); + + for (int i = 0; i < samples % centers; i++) { + nSamplesPerCenter[i]++; + } + + List points = new ArrayList(); + for (int i = 0; i < centers; i++) { + for (int j = 0; j < nSamplesPerCenter[i]; j++) { + Vector2D point = new Vector2D(dist.sample(), dist.sample()); + points.add(point.add(centerPoints[i])); + } + } + + if (shuffle) { + Collections.shuffle(points, new RandomAdaptor(random)); + } + + return points; + } + + public static List makeRandom(int samples) { + SobolSequenceGenerator generator = new SobolSequenceGenerator(2); + generator.skipTo(999999); + List points = new ArrayList(); + for (double i = 0; i < samples; i++) { + double[] vector = generator.nextVector(); + vector[0] = vector[0] * 2 - 1; + vector[1] = vector[1] * 2 - 1; + Vector2D point = new Vector2D(vector); + points.add(point); + } + + return points; + } + + public static Vector2D generateNoiseVector(NormalDistribution distribution) { + return new Vector2D(distribution.sample(), distribution.sample()); + } + + public static List normalize(final List input, double minX, double maxX, double minY, double maxY) { + double rangeX = maxX - minX; + double rangeY = maxY - minY; + List points = new ArrayList(); + for (Vector2D p : input) { + double[] arr = p.toArray(); + arr[0] = (arr[0] - minX) / rangeX * 2 - 1; + arr[1] = (arr[1] - minY) / rangeY * 2 - 1; + points.add(new DoublePoint(arr)); + } + return points; + } + + public static class Display extends JFrame { + + private static final long serialVersionUID = -8846964550416589808L; + + public Display() { + setTitle("Clustering examples"); + setSize(800, 800); + setLocationRelativeTo(null); + setDefaultCloseOperation(EXIT_ON_CLOSE); + + setLayout(new GridBagLayout()); + + int nSamples = 1500; + + RandomGenerator rng = new Well19937c(0); + List> datasets = new ArrayList>(); + + datasets.add(normalize(makeCircles(nSamples, true, 0.04, 0.5, rng), -1, 1, -1, 1)); + datasets.add(normalize(makeMoons(nSamples, true, 0.04, rng), -1, 2, -1, 1)); + datasets.add(normalize(makeBlobs(nSamples, 3, 1.0, -10, 10, true, rng), -12, 12, -12, 12)); + datasets.add(normalize(makeRandom(nSamples), -1, 1, -1, 1)); + + List>> algorithms = new ArrayList>>(); + + algorithms.add(new Pair>("KMeans\n(k=2)", new KMeansPlusPlusClusterer(2))); + algorithms.add(new Pair>("KMeans\n(k=3)", new KMeansPlusPlusClusterer(3))); + algorithms.add(new Pair>("FuzzyKMeans\n(k=3, fuzzy=2)", new FuzzyKMeansClusterer(3, 2))); + algorithms.add(new Pair>("FuzzyKMeans\n(k=3, fuzzy=10)", new FuzzyKMeansClusterer(3, 10))); + algorithms.add(new Pair>("DBSCAN\n(eps=.1, min=3)", new DBSCANClusterer(0.1, 3))); + + GridBagConstraints c = new GridBagConstraints(); + c.fill = GridBagConstraints.VERTICAL; + c.gridx = 0; + c.gridy = 0; + c.insets = new Insets(2, 2, 2, 2); + + for (Pair> pair : algorithms) { + JTextArea text = new JTextArea(pair.getFirst()); + text.setEditable(false); + text.setOpaque(false); + add(text, c); + c.gridx++; + } + c.gridy++; + + for (List dataset : datasets) { + c.gridx = 0; + for (Pair> pair : algorithms) { + long start = System.currentTimeMillis(); + List> clusters = pair.getSecond().cluster(dataset); + long end = System.currentTimeMillis(); + add(new ClusterPlot(clusters, end - start), c); + c.gridx++; + } + c.gridy++; + } + } + } + + public static class ClusterPlot extends JComponent { + + private static final long serialVersionUID = 4546352048750419587L; + + private static double PAD = 10; + + private List> clusters; + private long duration; + + public ClusterPlot(final List> clusters, long duration) { + this.clusters = clusters; + this.duration = duration; + } + + protected void paintComponent(Graphics g) { + super.paintComponent(g); + Graphics2D g2 = (Graphics2D)g; + g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, + RenderingHints.VALUE_ANTIALIAS_ON); + + int w = getWidth(); + int h = getHeight(); + + g2.clearRect(0, 0, w, h); + + g2.setPaint(Color.black); + g2.drawRect(0, 0, w - 1, h - 1); + + int index = 0; + Color[] colors = new Color[] { Color.red, Color.blue, Color.green.darker() }; + for (Cluster cluster : clusters) { + g2.setPaint(colors[index++]); + for (DoublePoint point : cluster.getPoints()) { + Clusterable p = transform(point, w, h); + double[] arr = p.getPoint(); + g2.fill(new Ellipse2D.Double(arr[0] - 2, arr[1] - 2, 4, 4)); + } + + if (cluster instanceof CentroidCluster) { + Clusterable p = transform(((CentroidCluster) cluster).getCenter(), w, h); + double[] arr = p.getPoint(); + Shape s = new Ellipse2D.Double(arr[0] - 4, arr[1] - 4, 8, 8); + g2.fill(s); + g2.setPaint(Color.black); + g2.draw(s); + } + } + + g2.setPaint(Color.black); + g2.drawString(String.format("%.2f s", duration / 1e3), w - 30, h - 5); + } + + @Override + public Dimension getPreferredSize() { + return new Dimension(150, 150); + } + + private Clusterable transform(Clusterable point, int width, int height) { + double[] arr = point.getPoint(); + return new DoublePoint(new double[] { PAD + (arr[0] + 1) / 2.0 * (width - 2 * PAD), + height - PAD - (arr[1] + 1) / 2.0 * (height - 2 * PAD) }); + } + } + + public static void main(String[] args) { + SwingUtilities.invokeLater(new Runnable() { + public void run() { + Display d = new Display(); + d.setVisible(true); + } + }); + } +}