cleanup work
This commit is contained in:
parent
cddb1ed328
commit
424c7709ae
@ -22,32 +22,24 @@ import com.google.common.base.Splitter;
|
|||||||
import com.google.common.io.Files;
|
import com.google.common.io.Files;
|
||||||
|
|
||||||
public class RedditClassifier {
|
public class RedditClassifier {
|
||||||
|
|
||||||
public static int GOOD = 0;
|
public static int GOOD = 0;
|
||||||
public static int BAD = 1;
|
public static int BAD = 1;
|
||||||
public static int MIN_SCORE = 10;
|
public static int MIN_SCORE = 7;
|
||||||
|
|
||||||
|
private final int[] trainCount = { 0, 0 };
|
||||||
|
private final int[] evalCount = { 0, 0 };
|
||||||
|
private final int[] correctCount = { 0, 0 };
|
||||||
|
|
||||||
private final AdaptiveLogisticRegression classifier;
|
private final AdaptiveLogisticRegression classifier;
|
||||||
private final FeatureVectorEncoder titleEncoder;
|
private final FeatureVectorEncoder titleEncoder;
|
||||||
private final FeatureVectorEncoder domainEncoder;
|
private final FeatureVectorEncoder domainEncoder;
|
||||||
private CrossFoldLearner learner;
|
|
||||||
private final int noOfFeatures;
|
private final int noOfFeatures;
|
||||||
|
|
||||||
|
private CrossFoldLearner learner;
|
||||||
private double accuracy;
|
private double accuracy;
|
||||||
|
|
||||||
private final int[] trainCount = { 0, 0 };
|
|
||||||
|
|
||||||
private final int[] evalCount = { 0, 0 };
|
|
||||||
|
|
||||||
private final int[] correctCount = { 0, 0 };
|
|
||||||
|
|
||||||
public RedditClassifier() {
|
public RedditClassifier() {
|
||||||
noOfFeatures = 1000;
|
this(150, 1000);
|
||||||
classifier = new AdaptiveLogisticRegression(2, 1000, new L2());
|
|
||||||
classifier.setPoolSize(150);
|
|
||||||
titleEncoder = new AdaptiveWordValueEncoder("title");
|
|
||||||
titleEncoder.setProbes(2);
|
|
||||||
domainEncoder = new StaticWordValueEncoder("domain");
|
|
||||||
domainEncoder.setProbes(1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public RedditClassifier(final int poolSize, final int noOfFeatures) {
|
public RedditClassifier(final int poolSize, final int noOfFeatures) {
|
||||||
@ -60,6 +52,8 @@ public class RedditClassifier {
|
|||||||
domainEncoder.setProbes(1);
|
domainEncoder.setProbes(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// API
|
||||||
|
|
||||||
public void trainClassifier(final String fileName) throws IOException {
|
public void trainClassifier(final String fileName) throws IOException {
|
||||||
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
|
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
|
||||||
final int size = vectors.size();
|
final int size = vectors.size();
|
||||||
@ -151,25 +145,25 @@ public class RedditClassifier {
|
|||||||
final String title = items[3];
|
final String title = items[3];
|
||||||
final String theRootDomain = items[4];
|
final String theRootDomain = items[4];
|
||||||
|
|
||||||
final String category = extractCategory(Integer.parseInt(numberOfVotes));
|
final RandomAccessSparseVector internalVector = new RandomAccessSparseVector(noOfFeatures);
|
||||||
|
|
||||||
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category);
|
|
||||||
|
|
||||||
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
|
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
|
||||||
cal.setTimeInMillis(Long.parseLong(time) * 1000);
|
cal.setTimeInMillis(Long.parseLong(time) * 1000);
|
||||||
vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
|
internalVector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
|
||||||
|
|
||||||
vector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
|
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
|
||||||
|
|
||||||
domainEncoder.addToVector(theRootDomain, vector);
|
domainEncoder.addToVector(theRootDomain, internalVector);
|
||||||
final String[] words = title.split(" ");
|
final String[] words = title.split(" ");
|
||||||
// titleEncoder.setProbes(words.length);
|
// titleEncoder.setProbes(words.length);
|
||||||
|
|
||||||
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
|
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
|
||||||
for (final String word : words) {
|
for (final String word : words) {
|
||||||
titleEncoder.addToVector(word, vector);
|
titleEncoder.addToVector(word, internalVector);
|
||||||
}
|
}
|
||||||
return vector;
|
|
||||||
|
final String category = extractCategory(Integer.parseInt(numberOfVotes));
|
||||||
|
return new NamedVector(internalVector, category);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String extractCategory(final int score) {
|
private String extractCategory(final int score) {
|
||||||
|
@ -33,7 +33,7 @@ public class RedditClassifierTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
|
||||||
final RedditClassifier classifier = new RedditClassifier(200, 2000);
|
final RedditClassifier classifier = new RedditClassifier(250, 2500);
|
||||||
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
|
||||||
final double result = classifier.getAccuracy();
|
final double result = classifier.getAccuracy();
|
||||||
System.out.println("==== Custom Classifier (large) Accuracy = " + result);
|
System.out.println("==== Custom Classifier (large) Accuracy = " + result);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user