From cd4927cdb992bb395882942647c74d310a58a103 Mon Sep 17 00:00:00 2001 From: DOHA Date: Mon, 20 Apr 2015 22:04:25 +0200 Subject: [PATCH] modify reddit classifier --- .../reddit/classifier/RedditClassifier.java | 79 +++++++++++-------- .../classifier/RedditDataCollector.java | 27 +++---- .../classifier/RedditClassifierTest.java | 2 +- 3 files changed, 58 insertions(+), 50 deletions(-) diff --git a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java index 04cddd3776..4772f8927f 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java +++ b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java @@ -25,19 +25,24 @@ public class RedditClassifier { public static int GOOD = 0; public static int BAD = 1; - public static int MIN_SCORE = 10; + public static int MIN_SCORE = 7; + public static int NUM_OF_FEATURES = 1000; + private final AdaptiveLogisticRegression classifier; private final FeatureVectorEncoder titleEncoder; private final FeatureVectorEncoder domainEncoder; private CrossFoldLearner learner; + private double accuracy; private final int[] trainCount = { 0, 0 }; private final int[] evalCount = { 0, 0 }; + private final int[] correctCount = { 0, 0 }; + public RedditClassifier() { - classifier = new AdaptiveLogisticRegression(2, 4, new L2()); - classifier.setPoolSize(25); + classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2()); + classifier.setPoolSize(50); titleEncoder = new AdaptiveWordValueEncoder("title"); titleEncoder.setProbes(1); domainEncoder = new StaticWordValueEncoder("domain"); @@ -46,34 +51,17 @@ public class RedditClassifier { public void trainClassifier(String fileName) throws IOException { final List vectors = extractVectors(readDataFile(fileName)); + final int noOfTraining = (int) (RedditDataCollector.DATA_SIZE * 0.8); + final List trainingData = vectors.subList(0, noOfTraining); + final List testData = vectors.subList(noOfTraining, RedditDataCollector.DATA_SIZE); int category; - for (final NamedVector vector : vectors) { + for (final NamedVector vector : trainingData) { category = (vector.getName() == "GOOD") ? GOOD : BAD; classifier.train(category, vector); trainCount[category]++; } - System.out.println("Training count ========= " + trainCount[0] + "___" + trainCount[1]); - } - - public double evaluateClassifier() throws IOException { - final List vectors = extractVectors(readDataFile(RedditDataCollector.TEST_FILE)); - int category, result; - int correct = 0; - int wrong = 0; - for (final NamedVector vector : vectors) { - category = (vector.getName() == "GOOD") ? GOOD : BAD; - result = classify(vector); - - evalCount[category]++; - if (category == result) { - correct++; - } else { - wrong++; - } - } - System.out.println(correct + " ----- " + wrong); - System.out.println("Eval count ========= " + evalCount[0] + "___" + evalCount[1]); - return correct / (wrong + correct + 0.0); + System.out.println("Training count ========= Good = " + trainCount[0] + " ___ Bad = " + trainCount[1]); + evaluateClassifier(testData); } public Vector convertPost(String title, String domain, int hour) { @@ -93,8 +81,34 @@ public class RedditClassifier { return learner.classifyFull(features).maxValueIndex(); } + public double getAccuracy() { + return accuracy; + } + // ==== Private methods + private void evaluateClassifier(List vectors) throws IOException { + int category, result; + int correct = 0; + int wrong = 0; + for (final NamedVector vector : vectors) { + category = (vector.getName() == "GOOD") ? GOOD : BAD; + result = classify(vector); + + evalCount[category]++; + if (category == result) { + correct++; + correctCount[result]++; + } else { + wrong++; + } + } + System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]); + System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong); + System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]); + this.accuracy = correct / (wrong + correct + 0.0); + } + private List readDataFile(String fileName) throws IOException { List lines = Files.readLines(new File(fileName), Charset.forName("utf-8")); if ((lines == null) || (lines.size() == 0)) { @@ -116,15 +130,18 @@ public class RedditClassifier { private NamedVector extractVector(String line) { final String[] items = line.split(","); final String category = extractCategory(Integer.parseInt(items[0])); - final NamedVector vector = new NamedVector(new RandomAccessSparseVector(4), category); + final NamedVector vector = new NamedVector(new RandomAccessSparseVector(NUM_OF_FEATURES), category); final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); cal.setTimeInMillis(Long.parseLong(items[1]) * 1000); - titleEncoder.addToVector(items[3], vector); + vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day + vector.set(1, Integer.parseInt(items[2])); // number of words in the title domainEncoder.addToVector(items[4], vector); - vector.set(2, cal.get(Calendar.HOUR_OF_DAY)); // hour of day - vector.set(3, Integer.parseInt(items[2])); // number of words in the title - + final String[] words = items[3].split(" "); + // titleEncoder.setProbes(words.length); + for (final String word : words) { + titleEncoder.addToVector(word, vector); + } return vector; } diff --git a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java index fde2e696bb..07315cac59 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java +++ b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java @@ -17,7 +17,7 @@ import com.google.common.base.Splitter; public class RedditDataCollector { public static final String TRAINING_FILE = "src/main/resources/train.csv"; - public static final String TEST_FILE = "src/main/resources/test.csv"; + public static final int DATA_SIZE = 8000; public static final int LIMIT = 100; public static final Long YEAR = 31536000L; private final Logger logger = LoggerFactory.getLogger(getClass()); @@ -42,24 +42,15 @@ public class RedditDataCollector { this.subreddit = subreddit; } - public void collectData() { - final int noOfRounds = 80; + public void collectData() throws IOException { + final int noOfRounds = DATA_SIZE / LIMIT; timestamp = System.currentTimeMillis() / 1000; - try { - final FileWriter writer = new FileWriter(TRAINING_FILE); - writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n"); - for (int i = 0; i < noOfRounds; i++) { - getPosts(writer); - } - writer.close(); - - final FileWriter testWriter = new FileWriter(TEST_FILE); - testWriter.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n"); - getPosts(testWriter); - testWriter.close(); - } catch (final Exception e) { - logger.error("write to file error", e); + final FileWriter writer = new FileWriter(TRAINING_FILE); + writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n"); + for (int i = 0; i < noOfRounds; i++) { + getPosts(writer); } + writer.close(); } // ==== Private @@ -93,7 +84,7 @@ public class RedditDataCollector { } } - public static void main(String[] args) { + public static void main(String[] args) throws IOException { final RedditDataCollector collector = new RedditDataCollector(); collector.collectData(); } diff --git a/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java b/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java index f37208aa11..290ea9164c 100644 --- a/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java +++ b/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java @@ -22,7 +22,7 @@ public class RedditClassifierTest { @Test public void testClassifier() throws IOException { - final double result = classifier.evaluateClassifier(); + final double result = classifier.getAccuracy(); System.out.println("Accuracy = " + result); assertTrue(result > 0.8); }