diff --git a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java index 076ac0e65d..36ba657927 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java +++ b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java @@ -24,7 +24,6 @@ import com.google.common.io.Files; public class RedditClassifier { public static int GOOD = 0; public static int BAD = 1; - public static int MIN_SCORE = 7; private final int[] trainCount = { 0, 0 }; private final int[] evalCount = { 0, 0 }; @@ -34,15 +33,16 @@ public class RedditClassifier { private final FeatureVectorEncoder titleEncoder; private final FeatureVectorEncoder domainEncoder; private final int noOfFeatures; - + private final int minScore; private CrossFoldLearner learner; private double accuracy; public RedditClassifier() { - this(150, 1000); + this(150, 1000, 7); } - public RedditClassifier(final int poolSize, final int noOfFeatures) { + public RedditClassifier(final int poolSize, final int noOfFeatures, int minScore) { + this.minScore = minScore; this.noOfFeatures = noOfFeatures; classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2()); classifier.setPoolSize(poolSize); @@ -154,20 +154,15 @@ public class RedditClassifier { internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title domainEncoder.addToVector(theRootDomain, internalVector); - final String[] words = title.split(" "); - // titleEncoder.setProbes(words.length); - - // TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to" - for (final String word : words) { - titleEncoder.addToVector(word, internalVector); - } + final List words = Splitter.on(' ').splitToList(title); + words.stream().filter(word -> word.length() > 2).forEach(word -> titleEncoder.addToVector(word, internalVector)); final String category = extractCategory(Integer.parseInt(numberOfVotes)); return new NamedVector(internalVector, category); } private String extractCategory(final int score) { - return (score < MIN_SCORE) ? "BAD" : "GOOD"; + return (score < minScore) ? "BAD" : "GOOD"; } } diff --git a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java index 2d446df12c..d489b58db7 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java +++ b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java @@ -27,11 +27,7 @@ public class RedditDataCollector { private final String subreddit; public RedditDataCollector() { - restTemplate = new RestTemplate(); - final List list = new ArrayList(); - list.add(new UserAgentInterceptor()); - restTemplate.setInterceptors(list); - subreddit = "java"; + this("java"); } public RedditDataCollector(String subreddit) { diff --git a/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java b/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java index 1bdc843599..d72a689253 100644 --- a/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java +++ b/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java @@ -23,7 +23,7 @@ public class RedditClassifierTest { @Test public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { - final RedditClassifier classifier = new RedditClassifier(100, 500); + final RedditClassifier classifier = new RedditClassifier(100, 500, 7); classifier.trainClassifier(RedditDataCollector.DATA_FILE); final double result = classifier.getAccuracy(); System.out.println("==== Custom Classifier (small) Accuracy = " + result); @@ -33,7 +33,7 @@ public class RedditClassifierTest { @Test public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { - final RedditClassifier classifier = new RedditClassifier(250, 2500); + final RedditClassifier classifier = new RedditClassifier(250, 2500, 7); classifier.trainClassifier(RedditDataCollector.DATA_FILE); final double result = classifier.getAccuracy(); System.out.println("==== Custom Classifier (large) Accuracy = " + result);