commit
f880c5f3aa
|
@ -25,19 +25,24 @@ public class RedditClassifier {
|
|||
|
||||
public static int GOOD = 0;
|
||||
public static int BAD = 1;
|
||||
public static int MIN_SCORE = 10;
|
||||
public static int MIN_SCORE = 7;
|
||||
public static int NUM_OF_FEATURES = 1000;
|
||||
|
||||
private final AdaptiveLogisticRegression classifier;
|
||||
private final FeatureVectorEncoder titleEncoder;
|
||||
private final FeatureVectorEncoder domainEncoder;
|
||||
private CrossFoldLearner learner;
|
||||
private double accuracy;
|
||||
|
||||
private final int[] trainCount = { 0, 0 };
|
||||
|
||||
private final int[] evalCount = { 0, 0 };
|
||||
|
||||
private final int[] correctCount = { 0, 0 };
|
||||
|
||||
public RedditClassifier() {
|
||||
classifier = new AdaptiveLogisticRegression(2, 4, new L2());
|
||||
classifier.setPoolSize(25);
|
||||
classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2());
|
||||
classifier.setPoolSize(50);
|
||||
titleEncoder = new AdaptiveWordValueEncoder("title");
|
||||
titleEncoder.setProbes(1);
|
||||
domainEncoder = new StaticWordValueEncoder("domain");
|
||||
|
@ -46,34 +51,17 @@ public class RedditClassifier {
|
|||
|
||||
public void trainClassifier(String fileName) throws IOException {
|
||||
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
|
||||
final int noOfTraining = (int) (RedditDataCollector.DATA_SIZE * 0.8);
|
||||
final List<NamedVector> trainingData = vectors.subList(0, noOfTraining);
|
||||
final List<NamedVector> testData = vectors.subList(noOfTraining, RedditDataCollector.DATA_SIZE);
|
||||
int category;
|
||||
for (final NamedVector vector : vectors) {
|
||||
for (final NamedVector vector : trainingData) {
|
||||
category = (vector.getName() == "GOOD") ? GOOD : BAD;
|
||||
classifier.train(category, vector);
|
||||
trainCount[category]++;
|
||||
}
|
||||
System.out.println("Training count ========= " + trainCount[0] + "___" + trainCount[1]);
|
||||
}
|
||||
|
||||
public double evaluateClassifier() throws IOException {
|
||||
final List<NamedVector> vectors = extractVectors(readDataFile(RedditDataCollector.TEST_FILE));
|
||||
int category, result;
|
||||
int correct = 0;
|
||||
int wrong = 0;
|
||||
for (final NamedVector vector : vectors) {
|
||||
category = (vector.getName() == "GOOD") ? GOOD : BAD;
|
||||
result = classify(vector);
|
||||
|
||||
evalCount[category]++;
|
||||
if (category == result) {
|
||||
correct++;
|
||||
} else {
|
||||
wrong++;
|
||||
}
|
||||
}
|
||||
System.out.println(correct + " ----- " + wrong);
|
||||
System.out.println("Eval count ========= " + evalCount[0] + "___" + evalCount[1]);
|
||||
return correct / (wrong + correct + 0.0);
|
||||
System.out.println("Training count ========= Good = " + trainCount[0] + " ___ Bad = " + trainCount[1]);
|
||||
evaluateClassifier(testData);
|
||||
}
|
||||
|
||||
public Vector convertPost(String title, String domain, int hour) {
|
||||
|
@ -93,8 +81,34 @@ public class RedditClassifier {
|
|||
return learner.classifyFull(features).maxValueIndex();
|
||||
}
|
||||
|
||||
public double getAccuracy() {
|
||||
return accuracy;
|
||||
}
|
||||
|
||||
// ==== Private methods
|
||||
|
||||
private void evaluateClassifier(List<NamedVector> vectors) throws IOException {
|
||||
int category, result;
|
||||
int correct = 0;
|
||||
int wrong = 0;
|
||||
for (final NamedVector vector : vectors) {
|
||||
category = (vector.getName() == "GOOD") ? GOOD : BAD;
|
||||
result = classify(vector);
|
||||
|
||||
evalCount[category]++;
|
||||
if (category == result) {
|
||||
correct++;
|
||||
correctCount[result]++;
|
||||
} else {
|
||||
wrong++;
|
||||
}
|
||||
}
|
||||
System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]);
|
||||
System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong);
|
||||
System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]);
|
||||
this.accuracy = correct / (wrong + correct + 0.0);
|
||||
}
|
||||
|
||||
private List<String> readDataFile(String fileName) throws IOException {
|
||||
List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8"));
|
||||
if ((lines == null) || (lines.size() == 0)) {
|
||||
|
@ -116,15 +130,18 @@ public class RedditClassifier {
|
|||
private NamedVector extractVector(String line) {
|
||||
final String[] items = line.split(",");
|
||||
final String category = extractCategory(Integer.parseInt(items[0]));
|
||||
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(4), category);
|
||||
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(NUM_OF_FEATURES), category);
|
||||
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
|
||||
cal.setTimeInMillis(Long.parseLong(items[1]) * 1000);
|
||||
|
||||
titleEncoder.addToVector(items[3], vector);
|
||||
vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
|
||||
vector.set(1, Integer.parseInt(items[2])); // number of words in the title
|
||||
domainEncoder.addToVector(items[4], vector);
|
||||
vector.set(2, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
|
||||
vector.set(3, Integer.parseInt(items[2])); // number of words in the title
|
||||
|
||||
final String[] words = items[3].split(" ");
|
||||
// titleEncoder.setProbes(words.length);
|
||||
for (final String word : words) {
|
||||
titleEncoder.addToVector(word, vector);
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import com.google.common.base.Splitter;
|
|||
|
||||
public class RedditDataCollector {
|
||||
public static final String TRAINING_FILE = "src/main/resources/train.csv";
|
||||
public static final String TEST_FILE = "src/main/resources/test.csv";
|
||||
public static final int DATA_SIZE = 8000;
|
||||
public static final int LIMIT = 100;
|
||||
public static final Long YEAR = 31536000L;
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
@ -42,24 +42,15 @@ public class RedditDataCollector {
|
|||
this.subreddit = subreddit;
|
||||
}
|
||||
|
||||
public void collectData() {
|
||||
final int noOfRounds = 80;
|
||||
public void collectData() throws IOException {
|
||||
final int noOfRounds = DATA_SIZE / LIMIT;
|
||||
timestamp = System.currentTimeMillis() / 1000;
|
||||
try {
|
||||
final FileWriter writer = new FileWriter(TRAINING_FILE);
|
||||
writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
|
||||
for (int i = 0; i < noOfRounds; i++) {
|
||||
getPosts(writer);
|
||||
}
|
||||
writer.close();
|
||||
|
||||
final FileWriter testWriter = new FileWriter(TEST_FILE);
|
||||
testWriter.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
|
||||
getPosts(testWriter);
|
||||
testWriter.close();
|
||||
} catch (final Exception e) {
|
||||
logger.error("write to file error", e);
|
||||
final FileWriter writer = new FileWriter(TRAINING_FILE);
|
||||
writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
|
||||
for (int i = 0; i < noOfRounds; i++) {
|
||||
getPosts(writer);
|
||||
}
|
||||
writer.close();
|
||||
}
|
||||
|
||||
// ==== Private
|
||||
|
@ -93,7 +84,7 @@ public class RedditDataCollector {
|
|||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws IOException {
|
||||
final RedditDataCollector collector = new RedditDataCollector();
|
||||
collector.collectData();
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ public class RedditClassifierTest {
|
|||
|
||||
@Test
|
||||
public void testClassifier() throws IOException {
|
||||
final double result = classifier.evaluateClassifier();
|
||||
final double result = classifier.getAccuracy();
|
||||
System.out.println("Accuracy = " + result);
|
||||
assertTrue(result > 0.8);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue