cleanup work

This commit is contained in:
eugenp 2015-04-22 13:28:04 +03:00
parent cddb1ed328
commit 424c7709ae
2 changed files with 19 additions and 25 deletions

View File

@ -22,32 +22,24 @@ import com.google.common.base.Splitter;
import com.google.common.io.Files; import com.google.common.io.Files;
public class RedditClassifier { public class RedditClassifier {
public static int GOOD = 0; public static int GOOD = 0;
public static int BAD = 1; public static int BAD = 1;
public static int MIN_SCORE = 10; public static int MIN_SCORE = 7;
private final int[] trainCount = { 0, 0 };
private final int[] evalCount = { 0, 0 };
private final int[] correctCount = { 0, 0 };
private final AdaptiveLogisticRegression classifier; private final AdaptiveLogisticRegression classifier;
private final FeatureVectorEncoder titleEncoder; private final FeatureVectorEncoder titleEncoder;
private final FeatureVectorEncoder domainEncoder; private final FeatureVectorEncoder domainEncoder;
private CrossFoldLearner learner;
private final int noOfFeatures; private final int noOfFeatures;
private CrossFoldLearner learner;
private double accuracy; private double accuracy;
private final int[] trainCount = { 0, 0 };
private final int[] evalCount = { 0, 0 };
private final int[] correctCount = { 0, 0 };
public RedditClassifier() { public RedditClassifier() {
noOfFeatures = 1000; this(150, 1000);
classifier = new AdaptiveLogisticRegression(2, 1000, new L2());
classifier.setPoolSize(150);
titleEncoder = new AdaptiveWordValueEncoder("title");
titleEncoder.setProbes(2);
domainEncoder = new StaticWordValueEncoder("domain");
domainEncoder.setProbes(1);
} }
public RedditClassifier(final int poolSize, final int noOfFeatures) { public RedditClassifier(final int poolSize, final int noOfFeatures) {
@ -60,6 +52,8 @@ public class RedditClassifier {
domainEncoder.setProbes(1); domainEncoder.setProbes(1);
} }
// API
public void trainClassifier(final String fileName) throws IOException { public void trainClassifier(final String fileName) throws IOException {
final List<NamedVector> vectors = extractVectors(readDataFile(fileName)); final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
final int size = vectors.size(); final int size = vectors.size();
@ -151,25 +145,25 @@ public class RedditClassifier {
final String title = items[3]; final String title = items[3];
final String theRootDomain = items[4]; final String theRootDomain = items[4];
final String category = extractCategory(Integer.parseInt(numberOfVotes)); final RandomAccessSparseVector internalVector = new RandomAccessSparseVector(noOfFeatures);
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category);
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
cal.setTimeInMillis(Long.parseLong(time) * 1000); cal.setTimeInMillis(Long.parseLong(time) * 1000);
vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day internalVector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
vector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
domainEncoder.addToVector(theRootDomain, vector); domainEncoder.addToVector(theRootDomain, internalVector);
final String[] words = title.split(" "); final String[] words = title.split(" ");
// titleEncoder.setProbes(words.length); // titleEncoder.setProbes(words.length);
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to" // TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
for (final String word : words) { for (final String word : words) {
titleEncoder.addToVector(word, vector); titleEncoder.addToVector(word, internalVector);
} }
return vector;
final String category = extractCategory(Integer.parseInt(numberOfVotes));
return new NamedVector(internalVector, category);
} }
private String extractCategory(final int score) { private String extractCategory(final int score) {

View File

@ -33,7 +33,7 @@ public class RedditClassifierTest {
@Test @Test
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
final RedditClassifier classifier = new RedditClassifier(200, 2000); final RedditClassifier classifier = new RedditClassifier(250, 2500);
classifier.trainClassifier(RedditDataCollector.DATA_FILE); classifier.trainClassifier(RedditDataCollector.DATA_FILE);
final double result = classifier.getAccuracy(); final double result = classifier.getAccuracy();
System.out.println("==== Custom Classifier (large) Accuracy = " + result); System.out.println("==== Custom Classifier (large) Accuracy = " + result);