LUCENE-5548 - minor fixes (imports, comments, method names)

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1638718 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2014-11-12 08:40:02 +00:00
parent a1f3cebe50
commit 88f2ebd5d0
5 changed files with 44 additions and 47 deletions

View File

@ -16,6 +16,14 @@
*/
package org.apache.lucene.classification;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
@ -29,14 +37,6 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* A k-Nearest Neighbor classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
* on {@link MoreLikeThis}
@ -82,14 +82,14 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
*/
@Override
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
TopDocs topDocs=knnSearcher(text);
List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
ClassificationResult<BytesRef> retval=null;
double maxscore=-Double.MAX_VALUE;
for(ClassificationResult<BytesRef> element:doclist){
if(element.getScore()>maxscore){
retval=element;
maxscore=element.getScore();
TopDocs topDocs = knnSearch(text);
List<ClassificationResult<BytesRef>> doclist = buildListFromTopDocs(topDocs);
ClassificationResult<BytesRef> retval = null;
double maxscore = -Double.MAX_VALUE;
for (ClassificationResult<BytesRef> element : doclist) {
if (element.getScore() > maxscore) {
retval = element;
maxscore = element.getScore();
}
}
return retval;
@ -100,24 +100,24 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
TopDocs topDocs=knnSearcher(text);
List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
TopDocs topDocs = knnSearch(text);
List<ClassificationResult<BytesRef>> doclist = buildListFromTopDocs(topDocs);
Collections.sort(doclist);
return doclist;
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
TopDocs topDocs=knnSearcher(text);
List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
TopDocs topDocs = knnSearch(text);
List<ClassificationResult<BytesRef>> doclist = buildListFromTopDocs(topDocs);
Collections.sort(doclist);
return doclist.subList(0, max);
}
private TopDocs knnSearcher(String text) throws IOException{
private TopDocs knnSearch(String text) throws IOException {
if (mlt == null) {
throw new IOException("You must first call Classifier#train");
}
@ -132,31 +132,30 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
}
return indexSearcher.search(mltQuery, k);
}
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
Map<BytesRef, Integer> classCounts = new HashMap<>();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
Integer count = classCounts.get(cl);
if (count != null) {
classCounts.put(cl, count + 1);
} else {
classCounts.put(cl, 1);
}
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
Integer count = classCounts.get(cl);
if (count != null) {
classCounts.put(cl, count + 1);
} else {
classCounts.put(cl, 1);
}
}
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
int sumdoc=0;
int sumdoc = 0;
for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
Integer count = entry.getValue();
returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
sumdoc+=count;
Integer count = entry.getValue();
returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
sumdoc += count;
}
//correction
if(sumdoc<k){
for(ClassificationResult<BytesRef> cr:returnList){
cr.setScore(cr.getScore()*(double)k/(double)sumdoc);
if (sumdoc < k) {
for (ClassificationResult<BytesRef> cr : returnList) {
cr.setScore(cr.getScore() * (double) k / (double) sumdoc);
}
}
return returnList;

View File

@ -17,7 +17,6 @@
<html>
<body>
Uses already seen data (the indexed documents) to classify new documents.
Currently only contains a (simplistic) Lucene based Naive Bayes classifier,
a k-Nearest Neighbor classifier and a Perceptron based classifier
Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest Neighbor classifier and a Perceptron based classifier
</body>
</html>

View File

@ -17,14 +17,16 @@ package org.apache.lucene.classification.utils;
* limitations under the License.
*/
import java.io.IOException;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StorableField;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
@ -32,8 +34,6 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import java.io.IOException;
/**
* Utility class for creating training / test / cross validation indexes from the original index.
*/

View File

@ -16,12 +16,12 @@
*/
package org.apache.lucene.classification.utils;
import java.io.IOException;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;
import java.io.IOException;
/**
* utility class for converting Lucene {@link org.apache.lucene.document.Document}s to <code>Double</code> vectors.
*/

View File

@ -33,7 +33,6 @@ import java.io.Reader;
/**
* Testcase for {@link SimpleNaiveBayesClassifier}
*/
// TODO : eventually remove this if / when fallback methods exist for all un-supportable codec methods (see LUCENE-4872)
public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
@Test