Merge pull request #200 from Doha2012/master

minor changes to reddit classifier
This commit is contained in:
Eugen 2015-04-23 01:29:30 +03:00
commit 799c6120e7
3 changed files with 10 additions and 19 deletions

View File

@ -24,7 +24,6 @@ import com.google.common.io.Files;
public class RedditClassifier {
public static int GOOD = 0;
public static int BAD = 1;
public static int MIN_SCORE = 7;
private final int[] trainCount = { 0, 0 };
private final int[] evalCount = { 0, 0 };
@ -34,15 +33,16 @@ public class RedditClassifier {
private final FeatureVectorEncoder titleEncoder;
private final FeatureVectorEncoder domainEncoder;
private final int noOfFeatures;
private final int minScore;
private CrossFoldLearner learner;
private double accuracy;
/**
 * Creates a classifier with default settings by delegating to the
 * parameterized constructor: pool size 150, 1000 features and, in the
 * three-argument form, a minimum score of 7.
 */
// NOTE(review): the two delegation lines below look like the before/after
// versions of the same statement in this change (two-arg vs. three-arg
// constructor) - confirm against the actual file; only one can be present.
public RedditClassifier() {
this(150, 1000);
this(150, 1000, 7);
}
public RedditClassifier(final int poolSize, final int noOfFeatures) {
public RedditClassifier(final int poolSize, final int noOfFeatures, int minScore) {
this.minScore = minScore;
this.noOfFeatures = noOfFeatures;
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
classifier.setPoolSize(poolSize);
@ -154,20 +154,15 @@ public class RedditClassifier {
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
domainEncoder.addToVector(theRootDomain, internalVector);
final String[] words = title.split(" ");
// titleEncoder.setProbes(words.length);
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
for (final String word : words) {
titleEncoder.addToVector(word, internalVector);
}
final List<String> words = Splitter.on(' ').splitToList(title);
words.stream().filter(word -> word.length() > 2).forEach(word -> titleEncoder.addToVector(word, internalVector));
final String category = extractCategory(Integer.parseInt(numberOfVotes));
return new NamedVector(internalVector, category);
}
/**
 * Maps a numeric score to a category label.
 *
 * @param score the score to classify
 * @return "BAD" when the score is strictly below the threshold, otherwise "GOOD"
 */
// NOTE(review): the two return statements appear to be the before/after
// versions of one line in this change - the threshold moves from the static
// MIN_SCORE to the per-instance minScore field; confirm against the actual file.
private String extractCategory(final int score) {
return (score < MIN_SCORE) ? "BAD" : "GOOD";
return (score < minScore) ? "BAD" : "GOOD";
}
}

View File

@ -27,11 +27,7 @@ public class RedditDataCollector {
private final String subreddit;
/**
 * Creates a collector for the "java" subreddit.
 */
// NOTE(review): this block seems to show both versions of the constructor body
// from this change. The statements that build a RestTemplate, register a
// UserAgentInterceptor and assign subreddit = "java" look like the old inline
// setup; the trailing this("java") looks like its replacement, delegating to
// the String-argument constructor. Confirm against the actual file.
public RedditDataCollector() {
restTemplate = new RestTemplate();
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
list.add(new UserAgentInterceptor());
restTemplate.setInterceptors(list);
subreddit = "java";
this("java");
}
public RedditDataCollector(String subreddit) {

View File

@ -23,7 +23,7 @@ public class RedditClassifierTest {
@Test
public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
final RedditClassifier classifier = new RedditClassifier(100, 500);
final RedditClassifier classifier = new RedditClassifier(100, 500, 7);
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
final double result = classifier.getAccuracy();
System.out.println("==== Custom Classifier (small) Accuracy = " + result);
@ -33,7 +33,7 @@ public class RedditClassifierTest {
@Test
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
final RedditClassifier classifier = new RedditClassifier(250, 2500);
final RedditClassifier classifier = new RedditClassifier(250, 2500, 7);
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
final double result = classifier.getAccuracy();
System.out.println("==== Custom Classifier (large) Accuracy = " + result);