From cddb1ed3289437740e9dd522dfb08984e3f756fc Mon Sep 17 00:00:00 2001 From: eugenp Date: Wed, 22 Apr 2015 12:37:17 +0300 Subject: [PATCH] minor cleanup --- .../reddit/classifier/RedditClassifier.java | 43 ++++++++++++------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java index 7473dea865..4c58ff67ca 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java +++ b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java @@ -50,7 +50,7 @@ public class RedditClassifier { domainEncoder.setProbes(1); } - public RedditClassifier(int poolSize, int noOfFeatures) { + public RedditClassifier(final int poolSize, final int noOfFeatures) { this.noOfFeatures = noOfFeatures; classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2()); classifier.setPoolSize(poolSize); @@ -60,7 +60,7 @@ public class RedditClassifier { domainEncoder.setProbes(1); } - public void trainClassifier(String fileName) throws IOException { + public void trainClassifier(final String fileName) throws IOException { final List vectors = extractVectors(readDataFile(fileName)); final int size = vectors.size(); final int noOfTraining = (int) (size * 0.8); @@ -77,7 +77,7 @@ public class RedditClassifier { evaluateClassifier(testData); } - public Vector convertPost(String title, String domain, int hour) { + public Vector convertPost(final String title, final String domain, final int hour) { final Vector vector = new RandomAccessSparseVector(noOfFeatures); final List words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title); vector.set(0, hour); @@ -89,7 +89,7 @@ public class RedditClassifier { return vector; } - public int classify(Vector features) { + public int classify(final Vector features) { if (learner == null) { learner = classifier.getBest().getPayload().getLearner(); } @@ -102,7 +102,7 @@ public class RedditClassifier { // ==== Private methods - private void evaluateClassifier(List vectors) throws IOException { + private void evaluateClassifier(final List vectors) throws IOException { int category, result; int correct = 0; int wrong = 0; @@ -125,7 +125,7 @@ public class RedditClassifier { this.accuracy = correct / (wrong + correct + 0.0); } - private List readDataFile(String fileName) throws IOException { + private List readDataFile(final String fileName) throws IOException { List lines = Files.readLines(new File(fileName), Charset.forName("utf-8")); if ((lines == null) || (lines.size() == 0)) { new RedditDataCollector().collectData(); @@ -135,7 +135,7 @@ public class RedditClassifier { return lines; } - private List extractVectors(List lines) { + private List extractVectors(final List lines) { final List vectors = new ArrayList(lines.size()); for (final String line : lines) { vectors.add(extractVector(line)); @@ -143,25 +143,36 @@ public class RedditClassifier { return vectors; } - private NamedVector extractVector(String line) { + private NamedVector extractVector(final String line) { final String[] items = line.split(","); - final String category = extractCategory(Integer.parseInt(items[0])); - final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category); - final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); - cal.setTimeInMillis(Long.parseLong(items[1]) * 1000); + final String numberOfVotes = items[0]; + final String time = items[1]; + final String numberOfWordInTitle = items[2]; + final String title = items[3]; + final String theRootDomain = items[4]; + final String category = extractCategory(Integer.parseInt(numberOfVotes)); + + final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category); + + final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); + cal.setTimeInMillis(Long.parseLong(time) * 1000); vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day - vector.set(1, Integer.parseInt(items[2])); // number of words in the title - domainEncoder.addToVector(items[4], vector); - final String[] words = items[3].split(" "); + + vector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title + + domainEncoder.addToVector(theRootDomain, vector); + final String[] words = title.split(" "); // titleEncoder.setProbes(words.length); + + // TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to" for (final String word : words) { titleEncoder.addToVector(word, vector); } return vector; } - private String extractCategory(int score) { + private String extractCategory(final int score) { return (score < MIN_SCORE) ? "BAD" : "GOOD"; }