diff --git a/machine-learning/src/main/java/com/baeldung/ml/kmeans/Distance.java b/machine-learning/src/main/java/com/baeldung/ml/kmeans/Distance.java index 37e1c8492e..88275d67bb 100644 --- a/machine-learning/src/main/java/com/baeldung/ml/kmeans/Distance.java +++ b/machine-learning/src/main/java/com/baeldung/ml/kmeans/Distance.java @@ -8,13 +8,13 @@ import java.util.Map; */ public interface Distance { - /** - * Calculates the distance between two feature vectors. - * - * @param f1 The first set of features. - * @param f2 The second set of features. - * @return Calculated distance. - * @throws IllegalArgumentException If the given feature vectors are invalid. - */ - double calculate(Map f1, Map f2); + /** + * Calculates the distance between two feature vectors. + * + * @param f1 The first set of features. + * @param f2 The second set of features. + * @return Calculated distance. + * @throws IllegalArgumentException If the given feature vectors are invalid. + */ + double calculate(Map f1, Map f2); } diff --git a/machine-learning/src/main/java/com/baeldung/ml/kmeans/Errors.java b/machine-learning/src/main/java/com/baeldung/ml/kmeans/Errors.java index 4b973c1146..25e70470c1 100644 --- a/machine-learning/src/main/java/com/baeldung/ml/kmeans/Errors.java +++ b/machine-learning/src/main/java/com/baeldung/ml/kmeans/Errors.java @@ -8,16 +8,16 @@ import java.util.Map; */ public class Errors { - public static double sse(Map> clustered, Distance distance) { - double sum = 0; - for (Map.Entry> entry : clustered.entrySet()) { - Centroid centroid = entry.getKey(); - for (Record record : entry.getValue()) { - double d = distance.calculate(centroid.getCoordinates(), record.getFeatures()); - sum += Math.pow(d, 2); - } - } + public static double sse(Map> clustered, Distance distance) { + double sum = 0; + for (Map.Entry> entry : clustered.entrySet()) { + Centroid centroid = entry.getKey(); + for (Record record : entry.getValue()) { + double d = distance.calculate(centroid.getCoordinates(), record.getFeatures()); + sum += Math.pow(d, 2); + } + } - return sum; - } + return sum; + } } diff --git a/machine-learning/src/main/java/com/baeldung/ml/kmeans/EuclideanDistance.java b/machine-learning/src/main/java/com/baeldung/ml/kmeans/EuclideanDistance.java index 1946c508b4..7efc6e617b 100644 --- a/machine-learning/src/main/java/com/baeldung/ml/kmeans/EuclideanDistance.java +++ b/machine-learning/src/main/java/com/baeldung/ml/kmeans/EuclideanDistance.java @@ -7,19 +7,19 @@ import java.util.Map; */ public class EuclideanDistance implements Distance { - @Override - public double calculate(Map f1, Map f2) { - if (f1 == null || f2 == null) - throw new IllegalArgumentException("Feature vectors can't be null"); + @Override + public double calculate(Map f1, Map f2) { + if (f1 == null || f2 == null) + throw new IllegalArgumentException("Feature vectors can't be null"); - double sum = 0; - for (String key : f1.keySet()) { - Double v1 = f1.get(key); - Double v2 = f2.get(key); + double sum = 0; + for (String key : f1.keySet()) { + Double v1 = f1.get(key); + Double v2 = f2.get(key); - if (v1 != null && v2 != null) sum += Math.pow(v1 - v2, 2); - } + if (v1 != null && v2 != null) sum += Math.pow(v1 - v2, 2); + } - return Math.sqrt(sum); - } + return Math.sqrt(sum); + } } diff --git a/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFm.java b/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFm.java index cdaf2170cd..0ff9d3cff4 100644 --- a/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFm.java +++ b/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFm.java @@ -13,40 +13,40 @@ import static java.util.stream.Collectors.toSet; public class LastFm { - private static OkHttpClient okHttp = new OkHttpClient.Builder() - .addInterceptor(new LastFmService.Authenticator("put your API key here")) - .build(); + private static OkHttpClient okHttp = new OkHttpClient.Builder() + .addInterceptor(new LastFmService.Authenticator("put your API key here")) + .build(); - private static Retrofit retrofit = new Retrofit.Builder().client(okHttp) - .addConverterFactory(JacksonConverterFactory.create()) - .baseUrl("http://ws.audioscrobbler.com/") - .build(); + private static Retrofit retrofit = new Retrofit.Builder().client(okHttp) + .addConverterFactory(JacksonConverterFactory.create()) + .baseUrl("http://ws.audioscrobbler.com/") + .build(); - private static LastFmService lastFm = retrofit.create(LastFmService.class); + private static LastFmService lastFm = retrofit.create(LastFmService.class); - private static ObjectMapper mapper = new ObjectMapper(); + private static ObjectMapper mapper = new ObjectMapper(); - public static void main(String[] args) throws IOException { + public static void main(String[] args) throws IOException { List artists = getTop100Artists(); Set tags = getTop100Tags(); List records = datasetWithTaggedArtists(artists, tags); Map> clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000); - // Print the cluster configuration - clusters.forEach((key, value) -> { - System.out.println("------------------------------ CLUSTER -----------------------------------"); + // Print the cluster configuration + clusters.forEach((key, value) -> { + System.out.println("------------------------------ CLUSTER -----------------------------------"); - System.out.println(sortedCentroid(key)); - String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet())); - System.out.print(members); + System.out.println(sortedCentroid(key)); + String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet())); + System.out.print(members); - System.out.println(); - System.out.println(); - }); + System.out.println(); + System.out.println(); + }); Map json = convertToD3CompatibleMap(clusters); System.out.println(mapper.writeValueAsString(json)); - } + } private static Map convertToD3CompatibleMap(Map> clusters) { Map json = new HashMap<>(); @@ -69,45 +69,45 @@ public class LastFm { } private static String dominantGenre(Centroid centroid) { - return centroid.getCoordinates().keySet().stream().limit(2).collect(Collectors.joining(", ")); - } + return centroid.getCoordinates().keySet().stream().limit(2).collect(Collectors.joining(", ")); + } - private static Centroid sortedCentroid(Centroid key) { - List> entries = new ArrayList<>(key.getCoordinates().entrySet()); - entries.sort((e1, e2) -> e2.getValue().compareTo(e1.getValue())); + private static Centroid sortedCentroid(Centroid key) { + List> entries = new ArrayList<>(key.getCoordinates().entrySet()); + entries.sort((e1, e2) -> e2.getValue().compareTo(e1.getValue())); - Map sorted = new LinkedHashMap<>(); - for (Map.Entry entry : entries) { - sorted.put(entry.getKey(), entry.getValue()); - } + Map sorted = new LinkedHashMap<>(); + for (Map.Entry entry : entries) { + sorted.put(entry.getKey(), entry.getValue()); + } - return new Centroid(sorted); - } + return new Centroid(sorted); + } - private static List datasetWithTaggedArtists(List artists, - Set topTags) throws IOException { - List records = new ArrayList<>(); - for (String artist : artists) { - Map tags = lastFm.topTagsFor(artist).execute().body().all(); + private static List datasetWithTaggedArtists(List artists, + Set topTags) throws IOException { + List records = new ArrayList<>(); + for (String artist : artists) { + Map tags = lastFm.topTagsFor(artist).execute().body().all(); - // Only keep popular tags. - tags.entrySet().removeIf(e -> !topTags.contains(e.getKey())); + // Only keep popular tags. + tags.entrySet().removeIf(e -> !topTags.contains(e.getKey())); - records.add(new Record(artist, tags)); - } - return records; - } + records.add(new Record(artist, tags)); + } + return records; + } - private static Set getTop100Tags() throws IOException { - return lastFm.topTags().execute().body().all(); - } + private static Set getTop100Tags() throws IOException { + return lastFm.topTags().execute().body().all(); + } - private static List getTop100Artists() throws IOException { - List artists = new ArrayList<>(); - for (int i = 1; i <= 2; i++) { - artists.addAll(lastFm.topArtists(i).execute().body().all()); - } + private static List getTop100Artists() throws IOException { + List artists = new ArrayList<>(); + for (int i = 1; i <= 2; i++) { + artists.addAll(lastFm.topArtists(i).execute().body().all()); + } - return artists; - } + return artists; + } } diff --git a/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFmService.java b/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFmService.java index 0cc8e9e285..4e2bf6bd92 100644 --- a/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFmService.java +++ b/machine-learning/src/main/java/com/baeldung/ml/kmeans/LastFmService.java @@ -8,6 +8,8 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonProperty; import okhttp3.HttpUrl; import okhttp3.Interceptor; import okhttp3.Request; @@ -16,86 +18,89 @@ import retrofit2.Call; import retrofit2.http.GET; import retrofit2.http.Query; +import static com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility.ANY; import static java.util.stream.Collectors.toList; public interface LastFmService { - @GET("/2.0/?method=chart.gettopartists&format=json&limit=50") - Call topArtists(@Query("page") int page); + @GET("/2.0/?method=chart.gettopartists&format=json&limit=50") + Call topArtists(@Query("page") int page); - @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1") - Call topTagsFor(@Query("artist") String artist); + @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1") + Call topTagsFor(@Query("artist") String artist); - @GET("/2.0/?method=chart.gettoptags&format=json&limit=100") - Call topTags(); + @GET("/2.0/?method=chart.gettoptags&format=json&limit=100") + Call topTags(); - /** - * HTTP interceptor to intercept all HTTP requests and add the API key to them. - */ - class Authenticator implements Interceptor { + /** + * HTTP interceptor to intercept all HTTP requests and add the API key to them. + */ + class Authenticator implements Interceptor { - private final String apiKey; + private final String apiKey; - Authenticator(String apiKey) { - this.apiKey = apiKey; - } + Authenticator(String apiKey) { + this.apiKey = apiKey; + } - @Override - public Response intercept(Chain chain) throws IOException { - HttpUrl url = chain.request().url().newBuilder().addQueryParameter("api_key", apiKey).build(); - Request request = chain.request().newBuilder().url(url).build(); + @Override + public Response intercept(Chain chain) throws IOException { + HttpUrl url = chain.request().url().newBuilder().addQueryParameter("api_key", apiKey).build(); + Request request = chain.request().newBuilder().url(url).build(); - return chain.proceed(request); - } - } + return chain.proceed(request); + } + } - @JsonAutoDetect(fieldVisibility = ANY) - class TopTags { + @JsonAutoDetect(fieldVisibility = ANY) + class TopTags { - private Map tags; + private Map tags; - @SuppressWarnings("unchecked") - public Set all() { - List> topTags = (List>) tags.get("tag"); - return topTags.stream().map(e -> ((String) e.get("name"))).collect(Collectors.toSet()); - } - } + @SuppressWarnings("unchecked") + public Set all() { + List> topTags = (List>) tags.get("tag"); + return topTags.stream().map(e -> ((String) e.get("name"))).collect(Collectors.toSet()); + } + } - @JsonAutoDetect(fieldVisibility = ANY) - class Tags { + @JsonAutoDetect(fieldVisibility = ANY) + class Tags { - @JsonProperty("toptags") - private Map topTags; + @JsonProperty("toptags") + private Map topTags; - @SuppressWarnings("unchecked") - public Map all() { - try { - Map all = new HashMap<>(); - List> tags = (List>) topTags.get("tag"); - for (Map tag : tags) { - all.put(((String) tag.get("name")), ((Integer) tag.get("count")).doubleValue()); - } + @SuppressWarnings("unchecked") + public Map all() { + try { + Map all = new HashMap<>(); + List> tags = (List>) topTags.get("tag"); + for (Map tag : tags) { + all.put(((String) tag.get("name")), ((Integer) tag.get("count")).doubleValue()); + } - return all; - } catch (Exception e) { - return Collections.emptyMap(); - } - } - } + return all; + } + catch (Exception e) { + return Collections.emptyMap(); + } + } + } - @JsonAutoDetect(fieldVisibility = ANY) - class Artists { + @JsonAutoDetect(fieldVisibility = ANY) + class Artists { - private Map artists; + private Map artists; - @SuppressWarnings("unchecked") - public List all() { - try { - List> artists = (List>) this.artists.get("artist"); - return artists.stream().map(e -> ((String) e.get("name"))).collect(toList()); - } catch (Exception e) { - return Collections.emptyList(); - } - } - } + @SuppressWarnings("unchecked") + public List all() { + try { + List> artists = (List>) this.artists.get("artist"); + return artists.stream().map(e -> ((String) e.get("name"))).collect(toList()); + } + catch (Exception e) { + return Collections.emptyList(); + } + } + } } diff --git a/machine-learning/src/main/java/com/baeldung/ml/kmeans/Record.java b/machine-learning/src/main/java/com/baeldung/ml/kmeans/Record.java index 7208526136..0e936d49e6 100644 --- a/machine-learning/src/main/java/com/baeldung/ml/kmeans/Record.java +++ b/machine-learning/src/main/java/com/baeldung/ml/kmeans/Record.java @@ -9,52 +9,52 @@ import java.util.Objects; */ public class Record { - /** - * The record description. For example, this can be the artist name for the famous musician - * example. - */ - private final String description; + /** + * The record description. For example, this can be the artist name for the famous musician + * example. + */ + private final String description; - /** - * Encapsulates all attributes and their corresponding values, i.e. features. - */ - private final Map features; + /** + * Encapsulates all attributes and their corresponding values, i.e. features. + */ + private final Map features; - public Record(String description, Map features) { - this.description = description; - this.features = features; - } + public Record(String description, Map features) { + this.description = description; + this.features = features; + } - public Record(Map features) { - this("", features); - } + public Record(Map features) { + this("", features); + } - public String getDescription() { - return description; - } + public String getDescription() { + return description; + } - public Map getFeatures() { - return features; - } + public Map getFeatures() { + return features; + } - @Override - public String toString() { - String prefix = description == null || description.trim().isEmpty() ? "Record" : description; + @Override + public String toString() { + String prefix = description == null || description.trim().isEmpty() ? "Record" : description; - return prefix + ": " + features; - } + return prefix + ": " + features; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Record record = (Record) o; - return Objects.equals(getDescription(), record.getDescription()) && - Objects.equals(getFeatures(), record.getFeatures()); - } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Record record = (Record) o; + return Objects.equals(getDescription(), record.getDescription()) && + Objects.equals(getFeatures(), record.getFeatures()); + } - @Override - public int hashCode() { - return Objects.hash(getDescription(), getFeatures()); - } + @Override + public int hashCode() { + return Objects.hash(getDescription(), getFeatures()); + } }