modify reddit classifier
This commit is contained in:
		
							parent
							
								
									608d2dbca9
								
							
						
					
					
						commit
						cd4927cdb9
					
				| @ -25,19 +25,24 @@ public class RedditClassifier { | |||||||
| 
 | 
 | ||||||
|     public static int GOOD = 0; |     public static int GOOD = 0; | ||||||
|     public static int BAD = 1; |     public static int BAD = 1; | ||||||
|     public static int MIN_SCORE = 10; |     public static int MIN_SCORE = 7; | ||||||
|  |     public static int NUM_OF_FEATURES = 1000; | ||||||
|  | 
 | ||||||
|     private final AdaptiveLogisticRegression classifier; |     private final AdaptiveLogisticRegression classifier; | ||||||
|     private final FeatureVectorEncoder titleEncoder; |     private final FeatureVectorEncoder titleEncoder; | ||||||
|     private final FeatureVectorEncoder domainEncoder; |     private final FeatureVectorEncoder domainEncoder; | ||||||
|     private CrossFoldLearner learner; |     private CrossFoldLearner learner; | ||||||
|  |     private double accuracy; | ||||||
| 
 | 
 | ||||||
|     private final int[] trainCount = { 0, 0 }; |     private final int[] trainCount = { 0, 0 }; | ||||||
| 
 | 
 | ||||||
|     private final int[] evalCount = { 0, 0 }; |     private final int[] evalCount = { 0, 0 }; | ||||||
| 
 | 
 | ||||||
|  |     private final int[] correctCount = { 0, 0 }; | ||||||
|  | 
 | ||||||
|     public RedditClassifier() { |     public RedditClassifier() { | ||||||
|         classifier = new AdaptiveLogisticRegression(2, 4, new L2()); |         classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2()); | ||||||
|         classifier.setPoolSize(25); |         classifier.setPoolSize(50); | ||||||
|         titleEncoder = new AdaptiveWordValueEncoder("title"); |         titleEncoder = new AdaptiveWordValueEncoder("title"); | ||||||
|         titleEncoder.setProbes(1); |         titleEncoder.setProbes(1); | ||||||
|         domainEncoder = new StaticWordValueEncoder("domain"); |         domainEncoder = new StaticWordValueEncoder("domain"); | ||||||
| @ -46,34 +51,17 @@ public class RedditClassifier { | |||||||
| 
 | 
 | ||||||
|     public void trainClassifier(String fileName) throws IOException { |     public void trainClassifier(String fileName) throws IOException { | ||||||
|         final List<NamedVector> vectors = extractVectors(readDataFile(fileName)); |         final List<NamedVector> vectors = extractVectors(readDataFile(fileName)); | ||||||
|  |         final int noOfTraining = (int) (RedditDataCollector.DATA_SIZE * 0.8); | ||||||
|  |         final List<NamedVector> trainingData = vectors.subList(0, noOfTraining); | ||||||
|  |         final List<NamedVector> testData = vectors.subList(noOfTraining, RedditDataCollector.DATA_SIZE); | ||||||
|         int category; |         int category; | ||||||
|         for (final NamedVector vector : vectors) { |         for (final NamedVector vector : trainingData) { | ||||||
|             category = (vector.getName() == "GOOD") ? GOOD : BAD; |             category = (vector.getName() == "GOOD") ? GOOD : BAD; | ||||||
|             classifier.train(category, vector); |             classifier.train(category, vector); | ||||||
|             trainCount[category]++; |             trainCount[category]++; | ||||||
|         } |         } | ||||||
|         System.out.println("Training count ========= " + trainCount[0] + "___" + trainCount[1]); |         System.out.println("Training count ========= Good = " + trainCount[0] + " ___ Bad = " + trainCount[1]); | ||||||
|     } |         evaluateClassifier(testData); | ||||||
| 
 |  | ||||||
|     public double evaluateClassifier() throws IOException { |  | ||||||
|         final List<NamedVector> vectors = extractVectors(readDataFile(RedditDataCollector.TEST_FILE)); |  | ||||||
|         int category, result; |  | ||||||
|         int correct = 0; |  | ||||||
|         int wrong = 0; |  | ||||||
|         for (final NamedVector vector : vectors) { |  | ||||||
|             category = (vector.getName() == "GOOD") ? GOOD : BAD; |  | ||||||
|             result = classify(vector); |  | ||||||
| 
 |  | ||||||
|             evalCount[category]++; |  | ||||||
|             if (category == result) { |  | ||||||
|                 correct++; |  | ||||||
|             } else { |  | ||||||
|                 wrong++; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         System.out.println(correct + " ----- " + wrong); |  | ||||||
|         System.out.println("Eval count ========= " + evalCount[0] + "___" + evalCount[1]); |  | ||||||
|         return correct / (wrong + correct + 0.0); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public Vector convertPost(String title, String domain, int hour) { |     public Vector convertPost(String title, String domain, int hour) { | ||||||
| @ -93,8 +81,34 @@ public class RedditClassifier { | |||||||
|         return learner.classifyFull(features).maxValueIndex(); |         return learner.classifyFull(features).maxValueIndex(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public double getAccuracy() { | ||||||
|  |         return accuracy; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // ==== Private methods |     // ==== Private methods | ||||||
| 
 | 
 | ||||||
|  |     private void evaluateClassifier(List<NamedVector> vectors) throws IOException { | ||||||
|  |         int category, result; | ||||||
|  |         int correct = 0; | ||||||
|  |         int wrong = 0; | ||||||
|  |         for (final NamedVector vector : vectors) { | ||||||
|  |             category = (vector.getName() == "GOOD") ? GOOD : BAD; | ||||||
|  |             result = classify(vector); | ||||||
|  | 
 | ||||||
|  |             evalCount[category]++; | ||||||
|  |             if (category == result) { | ||||||
|  |                 correct++; | ||||||
|  |                 correctCount[result]++; | ||||||
|  |             } else { | ||||||
|  |                 wrong++; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]); | ||||||
|  |         System.out.println("Test result ======== Correct prediction = " + correct + " -----  Wrong prediction = " + wrong); | ||||||
|  |         System.out.println("Test result ======== Correct Good = " + correctCount[0] + " -----  Correct Bad = " + correctCount[1]); | ||||||
|  |         this.accuracy = correct / (wrong + correct + 0.0); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     private List<String> readDataFile(String fileName) throws IOException { |     private List<String> readDataFile(String fileName) throws IOException { | ||||||
|         List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8")); |         List<String> lines = Files.readLines(new File(fileName), Charset.forName("utf-8")); | ||||||
|         if ((lines == null) || (lines.size() == 0)) { |         if ((lines == null) || (lines.size() == 0)) { | ||||||
| @ -116,15 +130,18 @@ public class RedditClassifier { | |||||||
|     private NamedVector extractVector(String line) { |     private NamedVector extractVector(String line) { | ||||||
|         final String[] items = line.split(","); |         final String[] items = line.split(","); | ||||||
|         final String category = extractCategory(Integer.parseInt(items[0])); |         final String category = extractCategory(Integer.parseInt(items[0])); | ||||||
|         final NamedVector vector = new NamedVector(new RandomAccessSparseVector(4), category); |         final NamedVector vector = new NamedVector(new RandomAccessSparseVector(NUM_OF_FEATURES), category); | ||||||
|         final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); |         final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); | ||||||
|         cal.setTimeInMillis(Long.parseLong(items[1]) * 1000); |         cal.setTimeInMillis(Long.parseLong(items[1]) * 1000); | ||||||
| 
 | 
 | ||||||
|         titleEncoder.addToVector(items[3], vector); |         vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day | ||||||
|  |         vector.set(1, Integer.parseInt(items[2])); // number of words in the title | ||||||
|         domainEncoder.addToVector(items[4], vector); |         domainEncoder.addToVector(items[4], vector); | ||||||
|         vector.set(2, cal.get(Calendar.HOUR_OF_DAY)); // hour of day |         final String[] words = items[3].split(" "); | ||||||
|         vector.set(3, Integer.parseInt(items[2])); // number of words in the title |         // titleEncoder.setProbes(words.length); | ||||||
| 
 |         for (final String word : words) { | ||||||
|  |             titleEncoder.addToVector(word, vector); | ||||||
|  |         } | ||||||
|         return vector; |         return vector; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ import com.google.common.base.Splitter; | |||||||
| 
 | 
 | ||||||
| public class RedditDataCollector { | public class RedditDataCollector { | ||||||
|     public static final String TRAINING_FILE = "src/main/resources/train.csv"; |     public static final String TRAINING_FILE = "src/main/resources/train.csv"; | ||||||
|     public static final String TEST_FILE = "src/main/resources/test.csv"; |     public static final int DATA_SIZE = 8000; | ||||||
|     public static final int LIMIT = 100; |     public static final int LIMIT = 100; | ||||||
|     public static final Long YEAR = 31536000L; |     public static final Long YEAR = 31536000L; | ||||||
|     private final Logger logger = LoggerFactory.getLogger(getClass()); |     private final Logger logger = LoggerFactory.getLogger(getClass()); | ||||||
| @ -42,24 +42,15 @@ public class RedditDataCollector { | |||||||
|         this.subreddit = subreddit; |         this.subreddit = subreddit; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public void collectData() { |     public void collectData() throws IOException { | ||||||
|         final int noOfRounds = 80; |         final int noOfRounds = DATA_SIZE / LIMIT; | ||||||
|         timestamp = System.currentTimeMillis() / 1000; |         timestamp = System.currentTimeMillis() / 1000; | ||||||
|         try { |         final FileWriter writer = new FileWriter(TRAINING_FILE); | ||||||
|             final FileWriter writer = new FileWriter(TRAINING_FILE); |         writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n"); | ||||||
|             writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n"); |         for (int i = 0; i < noOfRounds; i++) { | ||||||
|             for (int i = 0; i < noOfRounds; i++) { |             getPosts(writer); | ||||||
|                 getPosts(writer); |  | ||||||
|             } |  | ||||||
|             writer.close(); |  | ||||||
| 
 |  | ||||||
|             final FileWriter testWriter = new FileWriter(TEST_FILE); |  | ||||||
|             testWriter.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n"); |  | ||||||
|             getPosts(testWriter); |  | ||||||
|             testWriter.close(); |  | ||||||
|         } catch (final Exception e) { |  | ||||||
|             logger.error("write to file error", e); |  | ||||||
|         } |         } | ||||||
|  |         writer.close(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // ==== Private |     // ==== Private | ||||||
| @ -93,7 +84,7 @@ public class RedditDataCollector { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static void main(String[] args) { |     public static void main(String[] args) throws IOException { | ||||||
|         final RedditDataCollector collector = new RedditDataCollector(); |         final RedditDataCollector collector = new RedditDataCollector(); | ||||||
|         collector.collectData(); |         collector.collectData(); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ public class RedditClassifierTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testClassifier() throws IOException { |     public void testClassifier() throws IOException { | ||||||
|         final double result = classifier.evaluateClassifier(); |         final double result = classifier.getAccuracy(); | ||||||
|         System.out.println("Accuracy = " + result); |         System.out.println("Accuracy = " + result); | ||||||
|         assertTrue(result > 0.8); |         assertTrue(result > 0.8); | ||||||
|     } |     } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user