minor changes to reddit classifier
This commit is contained in:
parent
424c7709ae
commit
3ffe8425f3
|
@ -24,7 +24,6 @@ import com.google.common.io.Files;
|
|||
public class RedditClassifier {
|
||||
public static int GOOD = 0;
|
||||
public static int BAD = 1;
|
||||
public static int MIN_SCORE = 7;
|
||||
|
||||
private final int[] trainCount = { 0, 0 };
|
||||
private final int[] evalCount = { 0, 0 };
|
||||
|
@ -34,15 +33,16 @@ public class RedditClassifier {
|
|||
private final FeatureVectorEncoder titleEncoder;
|
||||
private final FeatureVectorEncoder domainEncoder;
|
||||
private final int noOfFeatures;
|
||||
|
||||
private final int minScore;
|
||||
private CrossFoldLearner learner;
|
||||
private double accuracy;
|
||||
|
||||
public RedditClassifier() {
|
||||
this(150, 1000);
|
||||
this(150, 1000, 7);
|
||||
}
|
||||
|
||||
public RedditClassifier(final int poolSize, final int noOfFeatures) {
|
||||
public RedditClassifier(final int poolSize, final int noOfFeatures, int minScore) {
|
||||
this.minScore = minScore;
|
||||
this.noOfFeatures = noOfFeatures;
|
||||
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
|
||||
classifier.setPoolSize(poolSize);
|
||||
|
@ -154,20 +154,15 @@ public class RedditClassifier {
|
|||
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
|
||||
|
||||
domainEncoder.addToVector(theRootDomain, internalVector);
|
||||
final String[] words = title.split(" ");
|
||||
// titleEncoder.setProbes(words.length);
|
||||
|
||||
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
|
||||
for (final String word : words) {
|
||||
titleEncoder.addToVector(word, internalVector);
|
||||
}
|
||||
final List<String> words = Splitter.on(' ').splitToList(title);
|
||||
words.stream().filter(word -> word.length() > 2).forEach(word -> titleEncoder.addToVector(word, internalVector));
|
||||
|
||||
final String category = extractCategory(Integer.parseInt(numberOfVotes));
|
||||
return new NamedVector(internalVector, category);
|
||||
}
|
||||
|
||||
private String extractCategory(final int score) {
|
||||
return (score < MIN_SCORE) ? "BAD" : "GOOD";
|
||||
return (score < minScore) ? "BAD" : "GOOD";
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -27,11 +27,7 @@ public class RedditDataCollector {
|
|||
private final String subreddit;
|
||||
|
||||
public RedditDataCollector() {
|
||||
restTemplate = new RestTemplate();
|
||||
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
|
||||
list.add(new UserAgentInterceptor());
|
||||
restTemplate.setInterceptors(list);
|
||||
subreddit = "java";
|
||||
this("java");
|
||||
}
|
||||
|
||||
public RedditDataCollector(String subreddit) {
|
||||
|
|
|
@ -23,7 +23,7 @@ public class RedditClassifierTest {
|
|||
|
||||
@Test
|
||||
public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
||||
final RedditClassifier classifier = new RedditClassifier(100, 500);
|
||||
final RedditClassifier classifier = new RedditClassifier(100, 500, 7);
|
||||
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
||||
final double result = classifier.getAccuracy();
|
||||
System.out.println("==== Custom Classifier (small) Accuracy = " + result);
|
||||
|
@ -33,7 +33,7 @@ public class RedditClassifierTest {
|
|||
|
||||
@Test
|
||||
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
||||
final RedditClassifier classifier = new RedditClassifier(250, 2500);
|
||||
final RedditClassifier classifier = new RedditClassifier(250, 2500, 7);
|
||||
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
||||
final double result = classifier.getAccuracy();
|
||||
System.out.println("==== Custom Classifier (large) Accuracy = " + result);
|
||||
|
|
Loading…
Reference in New Issue