Merge pull request #197 from Doha2012/master

modify reddit classifier
This commit is contained in:
Eugen 2015-04-21 00:50:52 +03:00
commit f880c5f3aa
3 changed files with 58 additions and 50 deletions

View File

@ -25,19 +25,24 @@ public class RedditClassifier {
public static int GOOD = 0;
public static int BAD = 1;
public static int MIN_SCORE = 10;
public static int MIN_SCORE = 7;
public static int NUM_OF_FEATURES = 1000;
private final AdaptiveLogisticRegression classifier;
private final FeatureVectorEncoder titleEncoder;
private final FeatureVectorEncoder domainEncoder;
private CrossFoldLearner learner;
private double accuracy;
private final int[] trainCount = { 0, 0 };
private final int[] evalCount = { 0, 0 };
private final int[] correctCount = { 0, 0 };
public RedditClassifier() {
classifier = new AdaptiveLogisticRegression(2, 4, new L2());
classifier.setPoolSize(25);
classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2());
classifier.setPoolSize(50);
titleEncoder = new AdaptiveWordValueEncoder("title");
titleEncoder.setProbes(1);
domainEncoder = new StaticWordValueEncoder("domain");
@ -46,34 +51,17 @@ public class RedditClassifier {
public void trainClassifier(String fileName) throws IOException {
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
final int noOfTraining = (int) (RedditDataCollector.DATA_SIZE * 0.8);
final List<NamedVector> trainingData = vectors.subList(0, noOfTraining);
final List<NamedVector> testData = vectors.subList(noOfTraining, RedditDataCollector.DATA_SIZE);
int category;
for (final NamedVector vector : vectors) {
for (final NamedVector vector : trainingData) {
category = (vector.getName() == "GOOD") ? GOOD : BAD;
classifier.train(category, vector);
trainCount[category]++;
}
System.out.println("Training count ========= " + trainCount[0] + "___" + trainCount[1]);
}
public double evaluateClassifier() throws IOException {
final List<NamedVector> vectors = extractVectors(readDataFile(RedditDataCollector.TEST_FILE));
int category, result;
int correct = 0;
int wrong = 0;
for (final NamedVector vector : vectors) {
category = (vector.getName() == "GOOD") ? GOOD : BAD;
result = classify(vector);
evalCount[category]++;
if (category == result) {
correct++;
} else {
wrong++;
}
}
System.out.println(correct + " ----- " + wrong);
System.out.println("Eval count ========= " + evalCount[0] + "___" + evalCount[1]);
return correct / (wrong + correct + 0.0);
System.out.println("Training count ========= Good = " + trainCount[0] + " ___ Bad = " + trainCount[1]);
evaluateClassifier(testData);
}
public Vector convertPost(String title, String domain, int hour) {
@ -93,8 +81,34 @@ public class RedditClassifier {
return learner.classifyFull(features).maxValueIndex();
}
public double getAccuracy() {
return accuracy;
}
// ==== Private methods
private void evaluateClassifier(List<NamedVector> vectors) throws IOException {
int category, result;
int correct = 0;
int wrong = 0;
for (final NamedVector vector : vectors) {
category = (vector.getName() == "GOOD") ? GOOD : BAD;
result = classify(vector);
evalCount[category]++;
if (category == result) {
correct++;
correctCount[result]++;
} else {
wrong++;
}
}
System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]);
System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong);
System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]);
this.accuracy = correct / (wrong + correct + 0.0);
}
private List<String> readDataFile(String fileName) throws IOException {
List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8"));
if ((lines == null) || (lines.size() == 0)) {
@ -116,15 +130,18 @@ public class RedditClassifier {
private NamedVector extractVector(String line) {
final String[] items = line.split(",");
final String category = extractCategory(Integer.parseInt(items[0]));
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(4), category);
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(NUM_OF_FEATURES), category);
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
cal.setTimeInMillis(Long.parseLong(items[1]) * 1000);
titleEncoder.addToVector(items[3], vector);
vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
vector.set(1, Integer.parseInt(items[2])); // number of words in the title
domainEncoder.addToVector(items[4], vector);
vector.set(2, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
vector.set(3, Integer.parseInt(items[2])); // number of words in the title
final String[] words = items[3].split(" ");
// titleEncoder.setProbes(words.length);
for (final String word : words) {
titleEncoder.addToVector(word, vector);
}
return vector;
}

View File

@ -17,7 +17,7 @@ import com.google.common.base.Splitter;
public class RedditDataCollector {
public static final String TRAINING_FILE = "src/main/resources/train.csv";
public static final String TEST_FILE = "src/main/resources/test.csv";
public static final int DATA_SIZE = 8000;
public static final int LIMIT = 100;
public static final Long YEAR = 31536000L;
private final Logger logger = LoggerFactory.getLogger(getClass());
@ -42,24 +42,15 @@ public class RedditDataCollector {
this.subreddit = subreddit;
}
public void collectData() {
final int noOfRounds = 80;
public void collectData() throws IOException {
final int noOfRounds = DATA_SIZE / LIMIT;
timestamp = System.currentTimeMillis() / 1000;
try {
final FileWriter writer = new FileWriter(TRAINING_FILE);
writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
for (int i = 0; i < noOfRounds; i++) {
getPosts(writer);
}
writer.close();
final FileWriter testWriter = new FileWriter(TEST_FILE);
testWriter.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
getPosts(testWriter);
testWriter.close();
} catch (final Exception e) {
logger.error("write to file error", e);
final FileWriter writer = new FileWriter(TRAINING_FILE);
writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
for (int i = 0; i < noOfRounds; i++) {
getPosts(writer);
}
writer.close();
}
// ==== Private
@ -93,7 +84,7 @@ public class RedditDataCollector {
}
}
public static void main(String[] args) {
public static void main(String[] args) throws IOException {
final RedditDataCollector collector = new RedditDataCollector();
collector.collectData();
}

View File

@ -22,7 +22,7 @@ public class RedditClassifierTest {
@Test
public void testClassifier() throws IOException {
final double result = classifier.evaluateClassifier();
final double result = classifier.getAccuracy();
System.out.println("Accuracy = " + result);
assertTrue(result > 0.8);
}