update classifier

This commit is contained in:
DOHA 2015-04-21 08:30:11 +02:00
parent cd4927cdb9
commit 42369da868
1 changed file with 13 additions and 10 deletions

View File

@@ -25,7 +25,7 @@ public class RedditClassifier {
// NOTE(review): this span reads as a diff hunk rendered without +/- markers; the two
// MIN_SCORE lines are presumably the removed (7) and added (10) versions of a single
// declaration, not two live fields — confirm against the committed file.
// GOOD/BAD: presumably the two category labels of the classifier — verify against callers.
public static int GOOD = 0;
public static int BAD = 1;
// Score threshold; how it is applied is not visible in this excerpt.
public static int MIN_SCORE = 7;
public static int MIN_SCORE = 10;
// Second argument to the AdaptiveLogisticRegression constructor; also used as the
// size argument of the RandomAccessSparseVector built in convertPost.
public static int NUM_OF_FEATURES = 1000;
// The model instance; assigned in the constructor.
private final AdaptiveLogisticRegression classifier;
@@ -42,9 +42,9 @@ public class RedditClassifier {
/**
 * Creates the classifier and the title/domain word encoders.
 *
 * NOTE(review): this span reads as a diff hunk rendered without +/- markers, so each
 * repeated setter pair below is presumably the removed line followed by the added
 * line (pool size 50 -> 150, title probes 1 -> 2), not two consecutive calls —
 * confirm against the committed file.
 */
public RedditClassifier() {
// Two categories, NUM_OF_FEATURES features, L2 prior.
classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2());
// Presumably old value (50) then new value (150).
classifier.setPoolSize(50);
classifier.setPoolSize(150);
titleEncoder = new AdaptiveWordValueEncoder("title");
// Presumably old value (1) then new value (2).
titleEncoder.setProbes(1);
titleEncoder.setProbes(2);
domainEncoder = new StaticWordValueEncoder("domain");
domainEncoder.setProbes(1);
}
@@ -65,13 +65,15 @@ public class RedditClassifier {
}
/**
 * Builds a feature vector for a post from its title, domain and hour.
 *
 * NOTE(review): this span reads as a diff hunk rendered without +/- markers. The
 * lines through "return features;" are presumably the removed implementation and
 * the lines from "final Vector vector" onward the added one — confirm against the
 * committed file. Read literally as one method body, everything after the first
 * return would be unreachable.
 *
 * @param title  post title; split on non-word characters ("\\W") to obtain its words
 * @param domain post domain; added to the vector via domainEncoder
 * @param hour   stored directly as a vector element
 * @return the populated feature vector
 */
public Vector convertPost(String title, String domain, int hour) {
// --- First variant (presumably removed): size-4 vector, whole title encoded in one call.
final Vector features = new RandomAccessSparseVector(4);
// Word count = number of non-empty tokens after splitting the title on non-word characters.
final int noOfWords = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title).size();
titleEncoder.addToVector(title, features);
domainEncoder.addToVector(domain, features);
// Hour at index 2, word count at index 3.
features.set(2, hour);
features.set(3, noOfWords);
return features;
// --- Second variant (presumably added): vector sized NUM_OF_FEATURES, title encoded word by word.
final Vector vector = new RandomAccessSparseVector(NUM_OF_FEATURES);
final List<String> words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title);
// Hour at index 0, word count at index 1.
vector.set(0, hour);
vector.set(1, words.size());
domainEncoder.addToVector(domain, vector);
// Each title word is passed to the encoder separately instead of the title as one string.
for (final String word : words) {
titleEncoder.addToVector(word, vector);
}
return vector;
}
public int classify(Vector features) {
@@ -106,6 +108,7 @@ public class RedditClassifier {
System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]);
System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong);
System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]);
System.out.println("Test result ======== Good accuracy = " + (correctCount[0] / (evalCount[0] + 0.0)) + " ----- Bad accuracy = " + (correctCount[1] / (evalCount[1] + 0.0)));
this.accuracy = correct / (wrong + correct + 0.0);
}