minor cleanup
This commit is contained in:
parent
1175f22a99
commit
cddb1ed328
|
@ -50,7 +50,7 @@ public class RedditClassifier {
|
||||||
domainEncoder.setProbes(1);
|
domainEncoder.setProbes(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
public RedditClassifier(int poolSize, int noOfFeatures) {
|
public RedditClassifier(final int poolSize, final int noOfFeatures) {
|
||||||
this.noOfFeatures = noOfFeatures;
|
this.noOfFeatures = noOfFeatures;
|
||||||
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
|
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
|
||||||
classifier.setPoolSize(poolSize);
|
classifier.setPoolSize(poolSize);
|
||||||
|
@ -60,7 +60,7 @@ public class RedditClassifier {
|
||||||
domainEncoder.setProbes(1);
|
domainEncoder.setProbes(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void trainClassifier(String fileName) throws IOException {
|
public void trainClassifier(final String fileName) throws IOException {
|
||||||
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
|
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
|
||||||
final int size = vectors.size();
|
final int size = vectors.size();
|
||||||
final int noOfTraining = (int) (size * 0.8);
|
final int noOfTraining = (int) (size * 0.8);
|
||||||
|
@ -77,7 +77,7 @@ public class RedditClassifier {
|
||||||
evaluateClassifier(testData);
|
evaluateClassifier(testData);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Vector convertPost(String title, String domain, int hour) {
|
public Vector convertPost(final String title, final String domain, final int hour) {
|
||||||
final Vector vector = new RandomAccessSparseVector(noOfFeatures);
|
final Vector vector = new RandomAccessSparseVector(noOfFeatures);
|
||||||
final List<String> words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title);
|
final List<String> words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title);
|
||||||
vector.set(0, hour);
|
vector.set(0, hour);
|
||||||
|
@ -89,7 +89,7 @@ public class RedditClassifier {
|
||||||
return vector;
|
return vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int classify(Vector features) {
|
public int classify(final Vector features) {
|
||||||
if (learner == null) {
|
if (learner == null) {
|
||||||
learner = classifier.getBest().getPayload().getLearner();
|
learner = classifier.getBest().getPayload().getLearner();
|
||||||
}
|
}
|
||||||
|
@ -102,7 +102,7 @@ public class RedditClassifier {
|
||||||
|
|
||||||
// ==== Private methods
|
// ==== Private methods
|
||||||
|
|
||||||
private void evaluateClassifier(List<NamedVector> vectors) throws IOException {
|
private void evaluateClassifier(final List<NamedVector> vectors) throws IOException {
|
||||||
int category, result;
|
int category, result;
|
||||||
int correct = 0;
|
int correct = 0;
|
||||||
int wrong = 0;
|
int wrong = 0;
|
||||||
|
@ -125,7 +125,7 @@ public class RedditClassifier {
|
||||||
this.accuracy = correct / (wrong + correct + 0.0);
|
this.accuracy = correct / (wrong + correct + 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<String> readDataFile(String fileName) throws IOException {
|
private List<String> readDataFile(final String fileName) throws IOException {
|
||||||
List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8"));
|
List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8"));
|
||||||
if ((lines == null) || (lines.size() == 0)) {
|
if ((lines == null) || (lines.size() == 0)) {
|
||||||
new RedditDataCollector().collectData();
|
new RedditDataCollector().collectData();
|
||||||
|
@ -135,7 +135,7 @@ public class RedditClassifier {
|
||||||
return lines;
|
return lines;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<NamedVector> extractVectors(List<String> lines) {
|
private List<NamedVector> extractVectors(final List<String> lines) {
|
||||||
final List<NamedVector> vectors = new ArrayList<NamedVector>(lines.size());
|
final List<NamedVector> vectors = new ArrayList<NamedVector>(lines.size());
|
||||||
for (final String line : lines) {
|
for (final String line : lines) {
|
||||||
vectors.add(extractVector(line));
|
vectors.add(extractVector(line));
|
||||||
|
@ -143,25 +143,36 @@ public class RedditClassifier {
|
||||||
return vectors;
|
return vectors;
|
||||||
}
|
}
|
||||||
|
|
||||||
private NamedVector extractVector(String line) {
|
private NamedVector extractVector(final String line) {
|
||||||
final String[] items = line.split(",");
|
final String[] items = line.split(",");
|
||||||
final String category = extractCategory(Integer.parseInt(items[0]));
|
final String numberOfVotes = items[0];
|
||||||
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category);
|
final String time = items[1];
|
||||||
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
|
final String numberOfWordInTitle = items[2];
|
||||||
cal.setTimeInMillis(Long.parseLong(items[1]) * 1000);
|
final String title = items[3];
|
||||||
|
final String theRootDomain = items[4];
|
||||||
|
|
||||||
|
final String category = extractCategory(Integer.parseInt(numberOfVotes));
|
||||||
|
|
||||||
|
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category);
|
||||||
|
|
||||||
|
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
|
||||||
|
cal.setTimeInMillis(Long.parseLong(time) * 1000);
|
||||||
vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
|
vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
|
||||||
vector.set(1, Integer.parseInt(items[2])); // number of words in the title
|
|
||||||
domainEncoder.addToVector(items[4], vector);
|
vector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
|
||||||
final String[] words = items[3].split(" ");
|
|
||||||
|
domainEncoder.addToVector(theRootDomain, vector);
|
||||||
|
final String[] words = title.split(" ");
|
||||||
// titleEncoder.setProbes(words.length);
|
// titleEncoder.setProbes(words.length);
|
||||||
|
|
||||||
|
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
|
||||||
for (final String word : words) {
|
for (final String word : words) {
|
||||||
titleEncoder.addToVector(word, vector);
|
titleEncoder.addToVector(word, vector);
|
||||||
}
|
}
|
||||||
return vector;
|
return vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
private String extractCategory(int score) {
|
private String extractCategory(final int score) {
|
||||||
return (score < MIN_SCORE) ? "BAD" : "GOOD";
|
return (score < MIN_SCORE) ? "BAD" : "GOOD";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue