Merge pull request #200 from Doha2012/master
minor changes to reddit classifier
This commit is contained in:
commit
799c6120e7
|
@ -24,7 +24,6 @@ import com.google.common.io.Files;
|
||||||
public class RedditClassifier {
|
public class RedditClassifier {
|
||||||
public static int GOOD = 0;
|
public static int GOOD = 0;
|
||||||
public static int BAD = 1;
|
public static int BAD = 1;
|
||||||
public static int MIN_SCORE = 7;
|
|
||||||
|
|
||||||
private final int[] trainCount = { 0, 0 };
|
private final int[] trainCount = { 0, 0 };
|
||||||
private final int[] evalCount = { 0, 0 };
|
private final int[] evalCount = { 0, 0 };
|
||||||
|
@ -34,15 +33,16 @@ public class RedditClassifier {
|
||||||
private final FeatureVectorEncoder titleEncoder;
|
private final FeatureVectorEncoder titleEncoder;
|
||||||
private final FeatureVectorEncoder domainEncoder;
|
private final FeatureVectorEncoder domainEncoder;
|
||||||
private final int noOfFeatures;
|
private final int noOfFeatures;
|
||||||
|
private final int minScore;
|
||||||
private CrossFoldLearner learner;
|
private CrossFoldLearner learner;
|
||||||
private double accuracy;
|
private double accuracy;
|
||||||
|
|
||||||
public RedditClassifier() {
|
public RedditClassifier() {
|
||||||
this(150, 1000);
|
this(150, 1000, 7);
|
||||||
}
|
}
|
||||||
|
|
||||||
public RedditClassifier(final int poolSize, final int noOfFeatures) {
|
public RedditClassifier(final int poolSize, final int noOfFeatures, int minScore) {
|
||||||
|
this.minScore = minScore;
|
||||||
this.noOfFeatures = noOfFeatures;
|
this.noOfFeatures = noOfFeatures;
|
||||||
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
|
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
|
||||||
classifier.setPoolSize(poolSize);
|
classifier.setPoolSize(poolSize);
|
||||||
|
@ -154,20 +154,15 @@ public class RedditClassifier {
|
||||||
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
|
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
|
||||||
|
|
||||||
domainEncoder.addToVector(theRootDomain, internalVector);
|
domainEncoder.addToVector(theRootDomain, internalVector);
|
||||||
final String[] words = title.split(" ");
|
final List<String> words = Splitter.on(' ').splitToList(title);
|
||||||
// titleEncoder.setProbes(words.length);
|
words.stream().filter(word -> word.length() > 2).forEach(word -> titleEncoder.addToVector(word, internalVector));
|
||||||
|
|
||||||
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
|
|
||||||
for (final String word : words) {
|
|
||||||
titleEncoder.addToVector(word, internalVector);
|
|
||||||
}
|
|
||||||
|
|
||||||
final String category = extractCategory(Integer.parseInt(numberOfVotes));
|
final String category = extractCategory(Integer.parseInt(numberOfVotes));
|
||||||
return new NamedVector(internalVector, category);
|
return new NamedVector(internalVector, category);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String extractCategory(final int score) {
|
private String extractCategory(final int score) {
|
||||||
return (score < MIN_SCORE) ? "BAD" : "GOOD";
|
return (score < minScore) ? "BAD" : "GOOD";
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,11 +27,7 @@ public class RedditDataCollector {
|
||||||
private final String subreddit;
|
private final String subreddit;
|
||||||
|
|
||||||
public RedditDataCollector() {
|
public RedditDataCollector() {
|
||||||
restTemplate = new RestTemplate();
|
this("java");
|
||||||
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
|
|
||||||
list.add(new UserAgentInterceptor());
|
|
||||||
restTemplate.setInterceptors(list);
|
|
||||||
subreddit = "java";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public RedditDataCollector(String subreddit) {
|
public RedditDataCollector(String subreddit) {
|
||||||
|
|
|
@ -23,7 +23,7 @@ public class RedditClassifierTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
||||||
final RedditClassifier classifier = new RedditClassifier(100, 500);
|
final RedditClassifier classifier = new RedditClassifier(100, 500, 7);
|
||||||
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
||||||
final double result = classifier.getAccuracy();
|
final double result = classifier.getAccuracy();
|
||||||
System.out.println("==== Custom Classifier (small) Accuracy = " + result);
|
System.out.println("==== Custom Classifier (small) Accuracy = " + result);
|
||||||
|
@ -33,7 +33,7 @@ public class RedditClassifierTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
||||||
final RedditClassifier classifier = new RedditClassifier(250, 2500);
|
final RedditClassifier classifier = new RedditClassifier(250, 2500, 7);
|
||||||
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
||||||
final double result = classifier.getAccuracy();
|
final double result = classifier.getAccuracy();
|
||||||
System.out.println("==== Custom Classifier (large) Accuracy = " + result);
|
System.out.println("==== Custom Classifier (large) Accuracy = " + result);
|
||||||
|
|
Loading…
Reference in New Issue