mirror of https://github.com/apache/lucene.git

LUCENE-5548 - minor fixes (imports, comments, method names)

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1638718 13f79535-47bb-0310-9956-ffa450edef68

parent a1f3cebe50
commit 88f2ebd5d0
@@ -16,6 +16,14 @@
  */
 package org.apache.lucene.classification;
 
+import java.io.IOException;
+import java.io.StringReader;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.Term;
@@ -29,14 +37,6 @@ import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.WildcardQuery;
 import org.apache.lucene.util.BytesRef;
 
-import java.io.IOException;
-import java.io.StringReader;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
 /**
  * A k-Nearest Neighbor classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
  * on {@link MoreLikeThis}
@@ -82,14 +82,14 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    */
   @Override
   public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
-    TopDocs topDocs=knnSearcher(text);
-    List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
-    ClassificationResult<BytesRef> retval=null;
-    double maxscore=-Double.MAX_VALUE;
-    for(ClassificationResult<BytesRef> element:doclist){
-      if(element.getScore()>maxscore){
-        retval=element;
-        maxscore=element.getScore();
+    TopDocs topDocs = knnSearch(text);
+    List<ClassificationResult<BytesRef>> doclist = buildListFromTopDocs(topDocs);
+    ClassificationResult<BytesRef> retval = null;
+    double maxscore = -Double.MAX_VALUE;
+    for (ClassificationResult<BytesRef> element : doclist) {
+      if (element.getScore() > maxscore) {
+        retval = element;
+        maxscore = element.getScore();
       }
     }
     return retval;
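
For context (not part of this commit), a minimal, hypothetical usage sketch of the classifier whose private search method is renamed below. The single-argument constructor, the train(LeafReader, String, String, Analyzer) signature and the version-less StandardAnalyzer / IndexWriterConfig constructors are assumptions about the API around this commit; the field names and sample texts are made up:

// Hypothetical usage sketch, not part of this commit.
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.BytesRef;

public class KnnUsageSketch {

  public static void main(String[] args) throws Exception {
    RAMDirectory dir = new RAMDirectory();
    StandardAnalyzer analyzer = new StandardAnalyzer();
    try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(analyzer))) {
      writer.addDocument(labeled("the quick brown fox jumps over the lazy dog", "animals"));
      writer.addDocument(labeled("foxes and wolves are wild canids", "animals"));
      writer.addDocument(labeled("the stock market fell sharply today", "finance"));
    }
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      // the classifier trains against a single leaf view of the index
      LeafReader leafReader = SlowCompositeReaderWrapper.wrap(reader);
      KNearestNeighborClassifier classifier = new KNearestNeighborClassifier(2); // k = 2
      classifier.train(leafReader, "text", "category", analyzer);
      ClassificationResult<BytesRef> result = classifier.assignClass("a brown wolf runs in the woods");
      // MoreLikeThis frequency cut-offs can leave no neighbors on a toy index, hence the null check
      if (result != null) {
        System.out.println(result.getAssignedClass().utf8ToString() + " -> " + result.getScore());
      }
    }
  }

  private static Document labeled(String text, String category) {
    Document doc = new Document();
    doc.add(new TextField("text", text, Field.Store.NO));
    doc.add(new StringField("category", category, Field.Store.YES)); // class field must be stored
    return doc;
  }
}
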
@@ -100,24 +100,24 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    */
   @Override
   public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
-    TopDocs topDocs=knnSearcher(text);
-    List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
+    TopDocs topDocs = knnSearch(text);
+    List<ClassificationResult<BytesRef>> doclist = buildListFromTopDocs(topDocs);
     Collections.sort(doclist);
     return doclist;
   }
 
   /**
    * {@inheritDoc}
    */
   @Override
   public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
-    TopDocs topDocs=knnSearcher(text);
-    List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
+    TopDocs topDocs = knnSearch(text);
+    List<ClassificationResult<BytesRef>> doclist = buildListFromTopDocs(topDocs);
     Collections.sort(doclist);
     return doclist.subList(0, max);
   }
 
-  private TopDocs knnSearcher(String text) throws IOException{
+  private TopDocs knnSearch(String text) throws IOException {
     if (mlt == null) {
       throw new IOException("You must first call Classifier#train");
     }
@@ -132,31 +132,30 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
     }
     return indexSearcher.search(mltQuery, k);
   }
 
   private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
     Map<BytesRef, Integer> classCounts = new HashMap<>();
     for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
       BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
       Integer count = classCounts.get(cl);
       if (count != null) {
         classCounts.put(cl, count + 1);
       } else {
         classCounts.put(cl, 1);
       }
     }
     List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
-    int sumdoc=0;
+    int sumdoc = 0;
     for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
       Integer count = entry.getValue();
       returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
-      sumdoc+=count;
-
+      sumdoc += count;
     }
 
     //correction
-    if(sumdoc<k){
-      for(ClassificationResult<BytesRef> cr:returnList){
-        cr.setScore(cr.getScore()*(double)k/(double)sumdoc);
+    if (sumdoc < k) {
+      for (ClassificationResult<BytesRef> cr : returnList) {
+        cr.setScore(cr.getScore() * (double) k / (double) sumdoc);
       }
     }
     return returnList;
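
The //correction branch above rescales the raw count / k scores whenever the MoreLikeThis search returns fewer than k neighbors (sumdoc < k), so that the per-class scores again sum to 1 over the documents actually retrieved. A small illustration with made-up numbers:

// Illustrative arithmetic only (hypothetical counts); mirrors the correction above.
// k = 10 neighbors requested, but the search returned only sumdoc = 6 documents:
//   class A seen 4 times -> 4 / 10 = 0.4, class B seen 2 times -> 2 / 10 = 0.2 (sum 0.6)
int k = 10, sumdoc = 6;
double scoreA = 4 / (double) k;              // 0.4
double scoreB = 2 / (double) k;              // 0.2
if (sumdoc < k) {                            // rescale by k / sumdoc
  scoreA *= (double) k / (double) sumdoc;    // ~0.667
  scoreB *= (double) k / (double) sumdoc;    // ~0.333
}
// scoreA + scoreB is now ~1.0 (up to floating-point rounding)
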
@@ -17,7 +17,6 @@
 <html>
 <body>
 Uses already seen data (the indexed documents) to classify new documents.
-Currently only contains a (simplistic) Lucene based Naive Bayes classifier,
-a k-Nearest Neighbor classifier and a Perceptron based classifier
+Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest Neighbor classifier and a Perceptron based classifier
 </body>
 </html>
@@ -17,14 +17,16 @@ package org.apache.lucene.classification.utils;
  * limitations under the License.
  */
 
+import java.io.IOException;
+
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.FieldType;
 import org.apache.lucene.document.TextField;
-import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.StorableField;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
@@ -32,8 +34,6 @@ import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 
-import java.io.IOException;
-
 /**
  * Utility class for creating training / test / cross validation indexes from the original index.
  */
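
The utility touched here is described as building training / test / cross-validation indexes from an existing index. As a rough, hypothetical sketch of that idea only (it is not the API of the class changed above, and it is written against the stable Document/IndexWriter API rather than this branch's StorableField-era types), such a splitter can walk every document via a MatchAllDocsQuery and route each one to a target index by proportion:

// Hypothetical sketch of the splitting idea; method name and signature are made up.
import java.io.IOException;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;

public final class SimpleSplitSketch {

  /** Copies every document of the source index into train/test/cv indexes by position. */
  public static void split(Directory source, Directory train, Directory test, Directory cv,
                           Analyzer analyzer, double testRatio, double cvRatio) throws IOException {
    try (IndexReader reader = DirectoryReader.open(source);
         IndexWriter trainWriter = new IndexWriter(train, new IndexWriterConfig(analyzer));
         IndexWriter testWriter = new IndexWriter(test, new IndexWriterConfig(analyzer));
         IndexWriter cvWriter = new IndexWriter(cv, new IndexWriterConfig(analyzer))) {
      if (reader.maxDoc() == 0) {
        return; // nothing to split
      }
      IndexSearcher searcher = new IndexSearcher(reader);
      TopDocs allDocs = searcher.search(new MatchAllDocsQuery(), reader.maxDoc());
      int total = allDocs.scoreDocs.length;
      int testSize = (int) (total * testRatio);
      int cvSize = (int) (total * cvRatio);
      for (int i = 0; i < total; i++) {
        // note: only stored fields survive a copy done this way
        Document doc = searcher.doc(allDocs.scoreDocs[i].doc);
        ScoreDoc sd = allDocs.scoreDocs[i];
        if (i < testSize) {
          testWriter.addDocument(doc);       // first slice -> test index
        } else if (i < testSize + cvSize) {
          cvWriter.addDocument(doc);         // next slice -> cross-validation index
        } else {
          trainWriter.addDocument(doc);      // remainder -> training index
        }
      }
    }
  }
}
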
@@ -16,12 +16,12 @@
  */
 package org.apache.lucene.classification.utils;
 
+import java.io.IOException;
+
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.util.BytesRef;
 
-import java.io.IOException;
-
 /**
  * utility class for converting Lucene {@link org.apache.lucene.document.Document}s to <code>Double</code> vectors.
  */
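
The javadoc above describes turning Lucene documents into Double vectors. A hypothetical sketch of one way to do that with the Terms / TermsEnum / BytesRef types imported here (a dense term-frequency vector over a fixed vocabulary; these are not the actual methods of this utility class):

// Hypothetical sketch only; class, method and parameter names are made up.
import java.io.IOException;
import java.util.List;

import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;

public final class DocVectorSketch {

  /** vocabulary.get(i) gives the term for dimension i; docTerms are one document's terms for a field. */
  public static Double[] toDenseFreqVector(Terms docTerms, List<BytesRef> vocabulary) throws IOException {
    Double[] vector = new Double[vocabulary.size()];
    for (int i = 0; i < vector.length; i++) {
      vector[i] = 0d;
    }
    // Terms#iterator took a reuse argument on this branch; the no-arg form of later versions is assumed here.
    TermsEnum termsEnum = docTerms.iterator();
    BytesRef term;
    while ((term = termsEnum.next()) != null) {
      int dim = vocabulary.indexOf(term); // linear scan is fine for a sketch
      if (dim >= 0) {
        // totalTermFreq is the within-document frequency when docTerms comes from a term vector
        vector[dim] = vector[dim] + termsEnum.totalTermFreq();
      }
    }
    return vector;
  }
}
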
@@ -33,7 +33,6 @@ import java.io.Reader;
 /**
  * Testcase for {@link SimpleNaiveBayesClassifier}
  */
-// TODO : eventually remove this if / when fallback methods exist for all un-supportable codec methods (see LUCENE-4872)
 public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
 
   @Test