Merge pull request #197 from Doha2012/master

modify reddit classifier
This commit is contained in:
Eugen 2015-04-21 00:50:52 +03:00
commit f880c5f3aa
3 changed files with 58 additions and 50 deletions

View File

@ -25,19 +25,24 @@ public class RedditClassifier {
public static int GOOD = 0; public static int GOOD = 0;
public static int BAD = 1; public static int BAD = 1;
public static int MIN_SCORE = 10; public static int MIN_SCORE = 7;
public static int NUM_OF_FEATURES = 1000;
private final AdaptiveLogisticRegression classifier; private final AdaptiveLogisticRegression classifier;
private final FeatureVectorEncoder titleEncoder; private final FeatureVectorEncoder titleEncoder;
private final FeatureVectorEncoder domainEncoder; private final FeatureVectorEncoder domainEncoder;
private CrossFoldLearner learner; private CrossFoldLearner learner;
private double accuracy;
private final int[] trainCount = { 0, 0 }; private final int[] trainCount = { 0, 0 };
private final int[] evalCount = { 0, 0 }; private final int[] evalCount = { 0, 0 };
private final int[] correctCount = { 0, 0 };
public RedditClassifier() { public RedditClassifier() {
classifier = new AdaptiveLogisticRegression(2, 4, new L2()); classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2());
classifier.setPoolSize(25); classifier.setPoolSize(50);
titleEncoder = new AdaptiveWordValueEncoder("title"); titleEncoder = new AdaptiveWordValueEncoder("title");
titleEncoder.setProbes(1); titleEncoder.setProbes(1);
domainEncoder = new StaticWordValueEncoder("domain"); domainEncoder = new StaticWordValueEncoder("domain");
@ -46,34 +51,17 @@ public class RedditClassifier {
public void trainClassifier(String fileName) throws IOException { public void trainClassifier(String fileName) throws IOException {
final List<NamedVector> vectors = extractVectors(readDataFile(fileName)); final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
final int noOfTraining = (int) (RedditDataCollector.DATA_SIZE * 0.8);
final List<NamedVector> trainingData = vectors.subList(0, noOfTraining);
final List<NamedVector> testData = vectors.subList(noOfTraining, RedditDataCollector.DATA_SIZE);
int category; int category;
for (final NamedVector vector : vectors) { for (final NamedVector vector : trainingData) {
category = (vector.getName() == "GOOD") ? GOOD : BAD; category = (vector.getName() == "GOOD") ? GOOD : BAD;
classifier.train(category, vector); classifier.train(category, vector);
trainCount[category]++; trainCount[category]++;
} }
System.out.println("Training count ========= " + trainCount[0] + "___" + trainCount[1]); System.out.println("Training count ========= Good = " + trainCount[0] + " ___ Bad = " + trainCount[1]);
} evaluateClassifier(testData);
public double evaluateClassifier() throws IOException {
final List<NamedVector> vectors = extractVectors(readDataFile(RedditDataCollector.TEST_FILE));
int category, result;
int correct = 0;
int wrong = 0;
for (final NamedVector vector : vectors) {
category = (vector.getName() == "GOOD") ? GOOD : BAD;
result = classify(vector);
evalCount[category]++;
if (category == result) {
correct++;
} else {
wrong++;
}
}
System.out.println(correct + " ----- " + wrong);
System.out.println("Eval count ========= " + evalCount[0] + "___" + evalCount[1]);
return correct / (wrong + correct + 0.0);
} }
public Vector convertPost(String title, String domain, int hour) { public Vector convertPost(String title, String domain, int hour) {
@ -93,8 +81,34 @@ public class RedditClassifier {
return learner.classifyFull(features).maxValueIndex(); return learner.classifyFull(features).maxValueIndex();
} }
/**
 * Returns the accuracy value currently stored on this classifier.
 *
 * @return the stored accuracy; {@code 0.0} (the field default) if nothing has assigned it yet
 */
public double getAccuracy() {
    return this.accuracy;
}
// ==== Private methods // ==== Private methods
/**
 * Classifies every given vector, compares the prediction with the category encoded in the
 * vector's name, prints the tallies and stores the resulting accuracy.
 *
 * <p>A vector named {@code "GOOD"} is treated as {@link #GOOD}; any other name is treated as
 * {@link #BAD}. {@code evalCount} and {@code correctCount} are instance fields that are only
 * incremented here, so repeated calls accumulate.
 *
 * <p>If {@code vectors} is empty the accuracy is computed as {@code 0 / 0.0}, i.e. {@code NaN}.
 *
 * @param vectors the labelled vectors to evaluate
 * @throws IOException declared for compatibility with existing callers
 */
private void evaluateClassifier(List<NamedVector> vectors) throws IOException {
    int correct = 0;
    int wrong = 0;

    for (final NamedVector vector : vectors) {
        // Compare string contents, not references: "==" only succeeds when both sides are
        // the same interned instance, which would silently label every other vector as BAD.
        final int category = "GOOD".equals(vector.getName()) ? GOOD : BAD;
        final int result = classify(vector);

        evalCount[category]++;
        if (category == result) {
            correct++;
            correctCount[result]++;
        } else {
            wrong++;
        }
    }

    System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]);
    System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong);
    System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]);

    // Adding 0.0 forces floating-point division.
    this.accuracy = correct / (wrong + correct + 0.0);
}
private List<String> readDataFile(String fileName) throws IOException { private List<String> readDataFile(String fileName) throws IOException {
List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8")); List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8"));
if ((lines == null) || (lines.size() == 0)) { if ((lines == null) || (lines.size() == 0)) {
@ -116,15 +130,18 @@ public class RedditClassifier {
private NamedVector extractVector(String line) { private NamedVector extractVector(String line) {
final String[] items = line.split(","); final String[] items = line.split(",");
final String category = extractCategory(Integer.parseInt(items[0])); final String category = extractCategory(Integer.parseInt(items[0]));
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(4), category); final NamedVector vector = new NamedVector(new RandomAccessSparseVector(NUM_OF_FEATURES), category);
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
cal.setTimeInMillis(Long.parseLong(items[1]) * 1000); cal.setTimeInMillis(Long.parseLong(items[1]) * 1000);
titleEncoder.addToVector(items[3], vector); vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
vector.set(1, Integer.parseInt(items[2])); // number of words in the title
domainEncoder.addToVector(items[4], vector); domainEncoder.addToVector(items[4], vector);
vector.set(2, cal.get(Calendar.HOUR_OF_DAY)); // hour of day final String[] words = items[3].split(" ");
vector.set(3, Integer.parseInt(items[2])); // number of words in the title // titleEncoder.setProbes(words.length);
for (final String word : words) {
titleEncoder.addToVector(word, vector);
}
return vector; return vector;
} }

View File

@ -17,7 +17,7 @@ import com.google.common.base.Splitter;
public class RedditDataCollector { public class RedditDataCollector {
public static final String TRAINING_FILE = "src/main/resources/train.csv"; public static final String TRAINING_FILE = "src/main/resources/train.csv";
public static final String TEST_FILE = "src/main/resources/test.csv"; public static final int DATA_SIZE = 8000;
public static final int LIMIT = 100; public static final int LIMIT = 100;
public static final Long YEAR = 31536000L; public static final Long YEAR = 31536000L;
private final Logger logger = LoggerFactory.getLogger(getClass()); private final Logger logger = LoggerFactory.getLogger(getClass());
@ -42,24 +42,15 @@ public class RedditDataCollector {
this.subreddit = subreddit; this.subreddit = subreddit;
} }
public void collectData() { public void collectData() throws IOException {
final int noOfRounds = 80; final int noOfRounds = DATA_SIZE / LIMIT;
timestamp = System.currentTimeMillis() / 1000; timestamp = System.currentTimeMillis() / 1000;
try {
final FileWriter writer = new FileWriter(TRAINING_FILE); final FileWriter writer = new FileWriter(TRAINING_FILE);
writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n"); writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
for (int i = 0; i < noOfRounds; i++) { for (int i = 0; i < noOfRounds; i++) {
getPosts(writer); getPosts(writer);
} }
writer.close(); writer.close();
final FileWriter testWriter = new FileWriter(TEST_FILE);
testWriter.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
getPosts(testWriter);
testWriter.close();
} catch (final Exception e) {
logger.error("write to file error", e);
}
} }
// ==== Private // ==== Private
@ -93,7 +84,7 @@ public class RedditDataCollector {
} }
} }
public static void main(String[] args) { public static void main(String[] args) throws IOException {
final RedditDataCollector collector = new RedditDataCollector(); final RedditDataCollector collector = new RedditDataCollector();
collector.collectData(); collector.collectData();
} }

View File

@ -22,7 +22,7 @@ public class RedditClassifierTest {
@Test @Test
public void testClassifier() throws IOException { public void testClassifier() throws IOException {
final double result = classifier.evaluateClassifier(); final double result = classifier.getAccuracy();
System.out.println("Accuracy = " + result); System.out.println("Accuracy = " + result);
assertTrue(result > 0.8); assertTrue(result > 0.8);
} }