Merge pull request #200 from Doha2012/master

minor changes to reddit classifier
This commit is contained in:
Eugen 2015-04-23 01:29:30 +03:00
commit 799c6120e7
3 changed files with 10 additions and 19 deletions

View File

@ -24,7 +24,6 @@ import com.google.common.io.Files;
public class RedditClassifier { public class RedditClassifier {
public static int GOOD = 0; public static int GOOD = 0;
public static int BAD = 1; public static int BAD = 1;
public static int MIN_SCORE = 7;
private final int[] trainCount = { 0, 0 }; private final int[] trainCount = { 0, 0 };
private final int[] evalCount = { 0, 0 }; private final int[] evalCount = { 0, 0 };
@ -34,15 +33,16 @@ public class RedditClassifier {
private final FeatureVectorEncoder titleEncoder; private final FeatureVectorEncoder titleEncoder;
private final FeatureVectorEncoder domainEncoder; private final FeatureVectorEncoder domainEncoder;
private final int noOfFeatures; private final int noOfFeatures;
private final int minScore;
private CrossFoldLearner learner; private CrossFoldLearner learner;
private double accuracy; private double accuracy;
public RedditClassifier() { public RedditClassifier() {
this(150, 1000); this(150, 1000, 7);
} }
public RedditClassifier(final int poolSize, final int noOfFeatures) { public RedditClassifier(final int poolSize, final int noOfFeatures, int minScore) {
this.minScore = minScore;
this.noOfFeatures = noOfFeatures; this.noOfFeatures = noOfFeatures;
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2()); classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
classifier.setPoolSize(poolSize); classifier.setPoolSize(poolSize);
@ -154,20 +154,15 @@ public class RedditClassifier {
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
domainEncoder.addToVector(theRootDomain, internalVector); domainEncoder.addToVector(theRootDomain, internalVector);
final String[] words = title.split(" "); final List<String> words = Splitter.on(' ').splitToList(title);
// titleEncoder.setProbes(words.length); words.stream().filter(word -> word.length() > 2).forEach(word -> titleEncoder.addToVector(word, internalVector));
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
for (final String word : words) {
titleEncoder.addToVector(word, internalVector);
}
final String category = extractCategory(Integer.parseInt(numberOfVotes)); final String category = extractCategory(Integer.parseInt(numberOfVotes));
return new NamedVector(internalVector, category); return new NamedVector(internalVector, category);
} }
private String extractCategory(final int score) { private String extractCategory(final int score) {
return (score < MIN_SCORE) ? "BAD" : "GOOD"; return (score < minScore) ? "BAD" : "GOOD";
} }
} }

View File

@ -27,11 +27,7 @@ public class RedditDataCollector {
private final String subreddit; private final String subreddit;
public RedditDataCollector() { public RedditDataCollector() {
restTemplate = new RestTemplate(); this("java");
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
list.add(new UserAgentInterceptor());
restTemplate.setInterceptors(list);
subreddit = "java";
} }
public RedditDataCollector(String subreddit) { public RedditDataCollector(String subreddit) {

View File

@ -23,7 +23,7 @@ public class RedditClassifierTest {
@Test @Test
public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
final RedditClassifier classifier = new RedditClassifier(100, 500); final RedditClassifier classifier = new RedditClassifier(100, 500, 7);
classifier.trainClassifier(RedditDataCollector.DATA_FILE); classifier.trainClassifier(RedditDataCollector.DATA_FILE);
final double result = classifier.getAccuracy(); final double result = classifier.getAccuracy();
System.out.println("==== Custom Classifier (small) Accuracy = " + result); System.out.println("==== Custom Classifier (small) Accuracy = " + result);
@ -33,7 +33,7 @@ public class RedditClassifierTest {
@Test @Test
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
final RedditClassifier classifier = new RedditClassifier(250, 2500); final RedditClassifier classifier = new RedditClassifier(250, 2500, 7);
classifier.trainClassifier(RedditDataCollector.DATA_FILE); classifier.trainClassifier(RedditDataCollector.DATA_FILE);
final double result = classifier.getAccuracy(); final double result = classifier.getAccuracy();
System.out.println("==== Custom Classifier (large) Accuracy = " + result); System.out.println("==== Custom Classifier (large) Accuracy = " + result);