minor changes to reddit classifier
This commit is contained in:
		
							parent
							
								
									424c7709ae
								
							
						
					
					
						commit
						3ffe8425f3
					
				| @ -24,7 +24,6 @@ import com.google.common.io.Files; | |||||||
| public class RedditClassifier { | public class RedditClassifier { | ||||||
|     public static int GOOD = 0; |     public static int GOOD = 0; | ||||||
|     public static int BAD = 1; |     public static int BAD = 1; | ||||||
|     public static int MIN_SCORE = 7; |  | ||||||
| 
 | 
 | ||||||
|     private final int[] trainCount = { 0, 0 }; |     private final int[] trainCount = { 0, 0 }; | ||||||
|     private final int[] evalCount = { 0, 0 }; |     private final int[] evalCount = { 0, 0 }; | ||||||
| @ -34,15 +33,16 @@ public class RedditClassifier { | |||||||
|     private final FeatureVectorEncoder titleEncoder; |     private final FeatureVectorEncoder titleEncoder; | ||||||
|     private final FeatureVectorEncoder domainEncoder; |     private final FeatureVectorEncoder domainEncoder; | ||||||
|     private final int noOfFeatures; |     private final int noOfFeatures; | ||||||
| 
 |     private final int minScore; | ||||||
|     private CrossFoldLearner learner; |     private CrossFoldLearner learner; | ||||||
|     private double accuracy; |     private double accuracy; | ||||||
| 
 | 
 | ||||||
|     public RedditClassifier() { |     public RedditClassifier() { | ||||||
|         this(150, 1000); |         this(150, 1000, 7); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public RedditClassifier(final int poolSize, final int noOfFeatures) { |     public RedditClassifier(final int poolSize, final int noOfFeatures, int minScore) { | ||||||
|  |         this.minScore = minScore; | ||||||
|         this.noOfFeatures = noOfFeatures; |         this.noOfFeatures = noOfFeatures; | ||||||
|         classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2()); |         classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2()); | ||||||
|         classifier.setPoolSize(poolSize); |         classifier.setPoolSize(poolSize); | ||||||
| @ -154,20 +154,15 @@ public class RedditClassifier { | |||||||
|         internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title |         internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title | ||||||
| 
 | 
 | ||||||
|         domainEncoder.addToVector(theRootDomain, internalVector); |         domainEncoder.addToVector(theRootDomain, internalVector); | ||||||
|         final String[] words = title.split(" "); |         final List<String> words = Splitter.on(' ').splitToList(title); | ||||||
|         // titleEncoder.setProbes(words.length); |         words.stream().filter(word -> word.length() > 2).forEach(word -> titleEncoder.addToVector(word, internalVector)); | ||||||
| 
 |  | ||||||
|         // TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to" |  | ||||||
|         for (final String word : words) { |  | ||||||
|             titleEncoder.addToVector(word, internalVector); |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         final String category = extractCategory(Integer.parseInt(numberOfVotes)); |         final String category = extractCategory(Integer.parseInt(numberOfVotes)); | ||||||
|         return new NamedVector(internalVector, category); |         return new NamedVector(internalVector, category); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private String extractCategory(final int score) { |     private String extractCategory(final int score) { | ||||||
|         return (score < MIN_SCORE) ? "BAD" : "GOOD"; |         return (score < minScore) ? "BAD" : "GOOD"; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | |||||||
| @ -27,11 +27,7 @@ public class RedditDataCollector { | |||||||
|     private final String subreddit; |     private final String subreddit; | ||||||
| 
 | 
 | ||||||
|     public RedditDataCollector() { |     public RedditDataCollector() { | ||||||
|         restTemplate = new RestTemplate(); |         this("java"); | ||||||
|         final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>(); |  | ||||||
|         list.add(new UserAgentInterceptor()); |  | ||||||
|         restTemplate.setInterceptors(list); |  | ||||||
|         subreddit = "java"; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public RedditDataCollector(String subreddit) { |     public RedditDataCollector(String subreddit) { | ||||||
|  | |||||||
| @ -23,7 +23,7 @@ public class RedditClassifierTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { |     public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { | ||||||
|         final RedditClassifier classifier = new RedditClassifier(100, 500); |         final RedditClassifier classifier = new RedditClassifier(100, 500, 7); | ||||||
|         classifier.trainClassifier(RedditDataCollector.DATA_FILE); |         classifier.trainClassifier(RedditDataCollector.DATA_FILE); | ||||||
|         final double result = classifier.getAccuracy(); |         final double result = classifier.getAccuracy(); | ||||||
|         System.out.println("==== Custom Classifier (small) Accuracy = " + result); |         System.out.println("==== Custom Classifier (small) Accuracy = " + result); | ||||||
| @ -33,7 +33,7 @@ public class RedditClassifierTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { |     public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { | ||||||
|         final RedditClassifier classifier = new RedditClassifier(250, 2500); |         final RedditClassifier classifier = new RedditClassifier(250, 2500, 7); | ||||||
|         classifier.trainClassifier(RedditDataCollector.DATA_FILE); |         classifier.trainClassifier(RedditDataCollector.DATA_FILE); | ||||||
|         final double result = classifier.getAccuracy(); |         final double result = classifier.getAccuracy(); | ||||||
|         System.out.println("==== Custom Classifier (large) Accuracy = " + result); |         System.out.println("==== Custom Classifier (large) Accuracy = " + result); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user