diff --git a/algorithms-miscellaneous-3/pom.xml b/algorithms-miscellaneous-3/pom.xml
index 802cf74320..5999d33c86 100644
--- a/algorithms-miscellaneous-3/pom.xml
+++ b/algorithms-miscellaneous-3/pom.xml
@@ -1,5 +1,5 @@
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4.0.0
algorithms-miscellaneous-3
0.0.1-SNAPSHOT
@@ -18,17 +18,28 @@
${org.assertj.core.version}
test
-
+
- org.apache.commons
- commons-collections4
- ${commons-collections4.version}
+ org.apache.commons
+ commons-collections4
+ ${commons-collections4.version}
-
+
- com.google.guava
- guava
- ${guava.version}
+ com.google.guava
+ guava
+ ${guava.version}
+
+
+
+ com.squareup.retrofit2
+ retrofit
+ ${retrofit.version}
+
+
+ com.squareup.retrofit2
+ converter-jackson
+ ${retrofit.version}
@@ -61,5 +72,6 @@
3.9.0
4.3
28.0-jre
+ 2.6.0
\ No newline at end of file
diff --git a/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Centroid.java b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Centroid.java
new file mode 100644
index 0000000000..523d5b56a5
--- /dev/null
+++ b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Centroid.java
@@ -0,0 +1,63 @@
+package com.baeldung.algorithms.kmeans;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Encapsulates all coordinates for a particular cluster centroid.
+ *
+ * <p>Equality and the hash code are derived from the current content of the coordinate
+ * map, so that map should not be modified while the centroid is used as a key in a
+ * hash-based collection.
+ */
+public class Centroid {
+
+    /**
+     * The centroid coordinates, keyed by attribute name.
+     */
+    private final Map<String, Double> coordinates;
+
+    /**
+     * Creates a centroid located at the given coordinates.
+     *
+     * @param coordinates The coordinate value for each attribute. The map is stored
+     * as is; no copy is made.
+     */
+    public Centroid(Map<String, Double> coordinates) {
+        this.coordinates = coordinates;
+    }
+
+    /**
+     * Returns the coordinates of this centroid.
+     *
+     * @return The same map instance that was given to the constructor.
+     */
+    public Map<String, Double> getCoordinates() {
+        return coordinates;
+    }
+
+    /**
+     * Two centroids are equal when they are of the same class and have equal coordinates.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        Centroid centroid = (Centroid) o;
+        return Objects.equals(getCoordinates(), centroid.getCoordinates());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(getCoordinates());
+    }
+
+    @Override
+    public String toString() {
+        return "Centroid " + coordinates;
+    }
+}
diff --git a/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Distance.java b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Distance.java
new file mode 100644
index 0000000000..30723cb6b3
--- /dev/null
+++ b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Distance.java
@@ -0,0 +1,20 @@
+package com.baeldung.algorithms.kmeans;
+
+import java.util.Map;
+
+/**
+ * Defines a contract to calculate the distance between two feature vectors. The smaller
+ * the calculated distance, the more similar the two items are to each other.
+ */
+public interface Distance {
+
+    /**
+     * Calculates the distance between two feature vectors.
+     *
+     * @param f1 The first set of features, keyed by feature name.
+     * @param f2 The second set of features, keyed by feature name.
+     * @return Calculated distance.
+     * @throws IllegalArgumentException If the given feature vectors are invalid.
+     */
+    double calculate(Map<String, Double> f1, Map<String, Double> f2);
+}
diff --git a/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Errors.java b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Errors.java
new file mode 100644
index 0000000000..3228876051
--- /dev/null
+++ b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/Errors.java
@@ -0,0 +1,31 @@
+package com.baeldung.algorithms.kmeans;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Encapsulates methods to calculate errors between a centroid and its cluster members.
+ */
+public class Errors {
+
+    /**
+     * Calculates the sum of squared errors (SSE) of a cluster configuration: for every
+     * record, the distance to the centroid of its own cluster is squared and added up.
+     *
+     * @param clustered The cluster configuration, mapping each centroid to its members.
+     * @param distance To calculate the distance between a centroid and a record.
+     * @return The sum of the squared distances; zero for an empty configuration.
+     */
+    public static double sse(Map<Centroid, List<Record>> clustered, Distance distance) {
+        double sum = 0;
+        for (Map.Entry<Centroid, List<Record>> entry : clustered.entrySet()) {
+            Centroid centroid = entry.getKey();
+            for (Record record : entry.getValue()) {
+                double d = distance.calculate(centroid.getCoordinates(), record.getFeatures());
+                sum += Math.pow(d, 2);
+            }
+        }
+
+        return sum;
+    }
+}
diff --git a/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/EuclideanDistance.java b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/EuclideanDistance.java
new file mode 100644
index 0000000000..193d9afed1
--- /dev/null
+++ b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/EuclideanDistance.java
@@ -0,0 +1,38 @@
+package com.baeldung.algorithms.kmeans;
+
+import java.util.Map;
+
+/**
+ * Calculates the distance between two items using the Euclidean formula.
+ */
+public class EuclideanDistance implements Distance {
+
+    /**
+     * Calculates the Euclidean distance over the features both vectors have a value for.
+     * A feature that is missing (or mapped to {@code null}) in either vector does not
+     * contribute to the result.
+     *
+     * @param f1 The first set of features.
+     * @param f2 The second set of features.
+     * @return The square root of the summed squared differences.
+     * @throws IllegalArgumentException If either feature vector is {@code null}.
+     */
+    @Override
+    public double calculate(Map<String, Double> f1, Map<String, Double> f2) {
+        if (f1 == null || f2 == null) {
+            throw new IllegalArgumentException("Feature vectors can't be null");
+        }
+
+        double sum = 0;
+        for (Map.Entry<String, Double> feature : f1.entrySet()) {
+            Double v1 = feature.getValue();
+            Double v2 = f2.get(feature.getKey());
+
+            if (v1 != null && v2 != null) {
+                sum += Math.pow(v1 - v2, 2);
+            }
+        }
+
+        return Math.sqrt(sum);
+    }
+}
diff --git a/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/KMeans.java b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/KMeans.java
new file mode 100644
index 0000000000..1fb8541ff9
--- /dev/null
+++ b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/KMeans.java
@@ -0,0 +1,244 @@
+package com.baeldung.algorithms.kmeans;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toSet;
+
+/**
+ * Encapsulates an implementation of KMeans clustering algorithm.
+ *
+ * @author Ali Dehghani
+ */
+public class KMeans {
+
+    private KMeans() {
+        throw new IllegalAccessError("You shouldn't call this constructor");
+    }
+
+    /**
+     * Will be used to generate random numbers.
+     */
+    private static final Random random = new Random();
+
+    /**
+     * Performs the K-Means clustering algorithm on the given dataset.
+     *
+     * @param records The dataset.
+     * @param k Number of Clusters.
+     * @param distance To calculate the distance between two items.
+     * @param maxIterations Upper bound for the number of iterations.
+     * @return The final cluster configuration: every centroid that has at least one
+     * record assigned to it, along with those records.
+     * @throws IllegalArgumentException If any of the arguments violates the preconditions.
+     */
+    public static Map<Centroid, List<Record>> fit(List<Record> records, int k, Distance distance, int maxIterations) {
+        applyPreconditions(records, k, distance, maxIterations);
+
+        List<Centroid> centroids = randomCentroids(records, k);
+        Map<Centroid, List<Record>> clusters = new HashMap<>();
+        Map<Centroid, List<Record>> lastState = new HashMap<>();
+
+        // iterate for a pre-defined number of times
+        for (int i = 0; i < maxIterations; i++) {
+            boolean isLastIteration = i == maxIterations - 1;
+
+            // in each iteration we should find the nearest centroid for each record
+            for (Record record : records) {
+                Centroid centroid = nearestCentroid(record, centroids, distance);
+                assignToCluster(clusters, record, centroid);
+            }
+
+            // if the assignment does not change, then the algorithm terminates
+            boolean shouldTerminate = isLastIteration || clusters.equals(lastState);
+            lastState = clusters;
+            if (shouldTerminate) {
+                break;
+            }
+
+            // at the end of each iteration we should relocate the centroids
+            centroids = relocateCentroids(clusters);
+            clusters = new HashMap<>();
+        }
+
+        return lastState;
+    }
+
+    /**
+     * Move all cluster centroids to the average of all assigned features. Only centroids
+     * that are keys of the given configuration are relocated and returned.
+     *
+     * @param clusters The current cluster configuration.
+     * @return Collection of new and relocated centroids.
+     */
+    private static List<Centroid> relocateCentroids(Map<Centroid, List<Record>> clusters) {
+        return clusters
+                .entrySet()
+                .stream()
+                .map(e -> average(e.getKey(), e.getValue()))
+                .collect(toList());
+    }
+
+    /**
+     * Moves the given centroid to the average position of all assigned features. If
+     * the centroid has no feature in its cluster, then there would be no need for a
+     * relocation. Otherwise, for each entry we calculate the average of all records
+     * first by summing all the entries and then dividing the final summation value by
+     * the number of records. The given centroid itself is never modified.
+     *
+     * @param centroid The centroid to move.
+     * @param records The assigned features.
+     * @return The moved centroid, or the given centroid if there are no records.
+     */
+    private static Centroid average(Centroid centroid, List<Record> records) {
+        // if this cluster is empty, then we shouldn't move the centroid
+        if (records == null || records.isEmpty()) {
+            return centroid;
+        }
+
+        // Since some records don't have all possible attributes, we start from the current
+        // centroid coordinates. We work on a copy because the centroid is still a key of the
+        // cluster map, and changing its coordinates in place would change its hash code.
+        Map<String, Double> average = new HashMap<>(centroid.getCoordinates());
+
+        // The average function works correctly if we clear all coordinates corresponding
+        // to present record attributes
+        records
+                .stream()
+                .flatMap(e -> e
+                        .getFeatures()
+                        .keySet()
+                        .stream())
+                .forEach(k -> average.put(k, 0.0));
+
+        for (Record record : records) {
+            record
+                    .getFeatures()
+                    .forEach((k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue));
+        }
+
+        // every coordinate, including those carried over from the centroid, is divided
+        // by the number of records
+        average.replaceAll((k, v) -> v / records.size());
+
+        return new Centroid(average);
+    }
+
+    /**
+     * Assigns a feature vector to the given centroid. If this is the first assignment for this centroid,
+     * first we should create the list.
+     *
+     * @param clusters The current cluster configuration.
+     * @param record The feature vector.
+     * @param centroid The centroid.
+     */
+    private static void assignToCluster(Map<Centroid, List<Record>> clusters, Record record, Centroid centroid) {
+        clusters
+                .computeIfAbsent(centroid, key -> new ArrayList<>())
+                .add(record);
+    }
+
+    /**
+     * With the help of the given distance calculator, iterates through centroids and finds the
+     * nearest one to the given record.
+     *
+     * @param record The feature vector to find a centroid for.
+     * @param centroids Collection of all centroids.
+     * @param distance To calculate the distance between two items.
+     * @return The nearest centroid to the given feature vector, or {@code null} if no centroid
+     * is closer than {@link Double#MAX_VALUE}.
+     */
+    private static Centroid nearestCentroid(Record record, List<Centroid> centroids, Distance distance) {
+        double minimumDistance = Double.MAX_VALUE;
+        Centroid nearest = null;
+
+        for (Centroid centroid : centroids) {
+            double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates());
+
+            if (currentDistance < minimumDistance) {
+                minimumDistance = currentDistance;
+                nearest = centroid;
+            }
+        }
+
+        return nearest;
+    }
+
+    /**
+     * Generates k random centroids. Before kicking-off the centroid generation process,
+     * first we calculate the possible value range for each attribute. Then when
+     * we're going to generate the centroids, we generate random coordinates in
+     * the [min, max] range for each attribute.
+     *
+     * @param records The dataset which helps to calculate the [min, max] range for
+     * each attribute.
+     * @param k Number of clusters.
+     * @return Collections of randomly generated centroids.
+     */
+    private static List<Centroid> randomCentroids(List<Record> records, int k) {
+        List<Centroid> centroids = new ArrayList<>();
+        Map<String, Double> maxs = new HashMap<>();
+        Map<String, Double> mins = new HashMap<>();
+
+        for (Record record : records) {
+            record
+                    .getFeatures()
+                    .forEach((key, value) -> {
+                        // compares the value with the current max and choose the bigger value between them
+                        maxs.compute(key, (k1, max) -> max == null || value > max ? value : max);
+
+                        // compare the value with the current min and choose the smaller value between them
+                        mins.compute(key, (k1, min) -> min == null || value < min ? value : min);
+                    });
+        }
+
+        Set<String> attributes = records
+                .stream()
+                .flatMap(e -> e
+                        .getFeatures()
+                        .keySet()
+                        .stream())
+                .collect(toSet());
+        for (int i = 0; i < k; i++) {
+            Map<String, Double> coordinates = new HashMap<>();
+            for (String attribute : attributes) {
+                double max = maxs.get(attribute);
+                double min = mins.get(attribute);
+                coordinates.put(attribute, random.nextDouble() * (max - min) + min);
+            }
+
+            centroids.add(new Centroid(coordinates));
+        }
+
+        return centroids;
+    }
+
+    /**
+     * Validates the arguments passed to {@link #fit}.
+     *
+     * @throws IllegalArgumentException If the dataset is null or empty, k is not greater than 1,
+     * the distance calculator is null or maxIterations is not positive.
+     */
+    private static void applyPreconditions(List<Record> records, int k, Distance distance, int maxIterations) {
+        if (records == null || records.isEmpty()) {
+            throw new IllegalArgumentException("The dataset can't be empty");
+        }
+
+        if (k <= 1) {
+            throw new IllegalArgumentException("It doesn't make sense to have less than or equal to 1 cluster");
+        }
+
+        if (distance == null) {
+            throw new IllegalArgumentException("The distance calculator is required");
+        }
+
+        if (maxIterations <= 0) {
+            throw new IllegalArgumentException("Max iterations should be a positive number");
+        }
+    }
+}
diff --git a/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/LastFm.java b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/LastFm.java
new file mode 100644
index 0000000000..4694a845af
--- /dev/null
+++ b/algorithms-miscellaneous-3/src/main/java/com/baeldung/algorithms/kmeans/LastFm.java
@@ -0,0 +1,144 @@
+package com.baeldung.algorithms.kmeans;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import okhttp3.OkHttpClient;
+import retrofit2.Retrofit;
+import retrofit2.converter.jackson.JacksonConverterFactory;
+
+import static java.util.stream.Collectors.toSet;
+
+public class LastFm {
+
+ private static OkHttpClient okHttp = new OkHttpClient.Builder()
+ .addInterceptor(new LastFmService.Authenticator("put your API key here"))
+ .build();
+
+ private static Retrofit retrofit = new Retrofit.Builder()
+ .client(okHttp)
+ .addConverterFactory(JacksonConverterFactory.create())
+ .baseUrl("http://ws.audioscrobbler.com/")
+ .build();
+
+ private static LastFmService lastFm = retrofit.create(LastFmService.class);
+
+ private static ObjectMapper mapper = new ObjectMapper();
+
+    /** Builds the dataset via the helpers, fits 7 clusters with KMeans, prints each cluster and then its JSON form. */
+    public static void main(String[] args) throws IOException {
+        // NOTE(review): artists/tags stay raw here; parameterize them to match the helper signatures.
+        List artists = getTop100Artists();
+        Set tags = getTop100Tags();
+        List<Record> records = datasetWithTaggedArtists(artists, tags);
+
+        Map<Centroid, List<Record>> clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000);
+        // Print the cluster configuration
+        clusters.forEach((key, value) -> {
+            System.out.println("------------------------------ CLUSTER -----------------------------------");
+            System.out.println(sortedCentroid(key));
+            String members = String.join(", ", value
+                    .stream()
+                    .map(Record::getDescription)
+                    .collect(toSet()));
+            System.out.print(members);
+
+            System.out.println();
+            System.out.println();
+        });
+
+        Map json = convertToD3CompatibleMap(clusters);
+        System.out.println(mapper.writeValueAsString(json));
+    }
+
+ private static Map convertToD3CompatibleMap(Map> clusters) {
+ Map json = new HashMap<>();
+ json.put("name", "Musicians");
+ List