mirror of https://github.com/apache/lucene.git
LUCENE-6045 - immutable ClassificationResult, minor fixes
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1677573 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
11c4a88e23
commit
bf1355346c
|
@ -226,7 +226,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
@Override
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text)
|
||||
throws IOException {
|
||||
throw new RuntimeException("not implemented");
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -235,7 +235,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
@Override
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text, int max)
|
||||
throws IOException {
|
||||
throw new RuntimeException("not implemented");
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -141,12 +141,22 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
|||
double wordProbability = num / den;
|
||||
|
||||
// modify the value in the result list item
|
||||
int removeIdx = -1;
|
||||
int i = 0;
|
||||
for (ClassificationResult<BytesRef> cr : ret) {
|
||||
if (cr.getAssignedClass().equals(cclass)) {
|
||||
cr.setScore(cr.getScore() + Math.log(wordProbability));
|
||||
removeIdx = i;
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
if (removeIdx >= 0) {
|
||||
ClassificationResult<BytesRef> toRemove = ret.get(removeIdx);
|
||||
ret.add(new ClassificationResult<>(toRemove.getAssignedClass(), toRemove.getScore() + Math.log(wordProbability)));
|
||||
ret.remove(removeIdx);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ package org.apache.lucene.classification;
|
|||
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {
|
||||
|
||||
private final T assignedClass;
|
||||
private double score;
|
||||
private final double score;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
|
@ -55,16 +55,6 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
|
|||
return score;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the score value
|
||||
*
|
||||
* @param score the score for the assignedClass as a <code>double</code>
|
||||
*/
|
||||
public void setScore(double score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int compareTo(ClassificationResult<T> o) {
|
||||
return this.getScore() < o.getScore() ? 1 : this.getScore() > o.getScore() ? -1 : 0;
|
||||
|
|
|
@ -153,18 +153,21 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
}
|
||||
}
|
||||
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
|
||||
int sumdoc = 0;
|
||||
for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
|
||||
Integer count = entry.getValue();
|
||||
returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
|
||||
temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
|
||||
sumdoc += count;
|
||||
}
|
||||
|
||||
//correction
|
||||
if (sumdoc < k) {
|
||||
for (ClassificationResult<BytesRef> cr : returnList) {
|
||||
cr.setScore(cr.getScore() * (double) k / (double) sumdoc);
|
||||
for (ClassificationResult<BytesRef> cr : temporaryList) {
|
||||
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
|
||||
}
|
||||
} else {
|
||||
returnList = temporaryList;
|
||||
}
|
||||
return returnList;
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.lucene.index.IndexWriter;
|
|||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.StorableField;
|
||||
import org.apache.lucene.index.StoredDocument;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
|
@ -91,12 +92,16 @@ public class DatasetSplitter {
|
|||
|
||||
// create a new document for indexing
|
||||
Document doc = new Document();
|
||||
StoredDocument document = originalIndex.document(scoreDoc.doc);
|
||||
if (fieldNames != null && fieldNames.length > 0) {
|
||||
for (String fieldName : fieldNames) {
|
||||
doc.add(new Field(fieldName, originalIndex.document(scoreDoc.doc).getField(fieldName).stringValue(), ft));
|
||||
StorableField field = document.getField(fieldName);
|
||||
if (field != null) {
|
||||
doc.add(new Field(fieldName, field.stringValue(), ft));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (StorableField storableField : originalIndex.document(scoreDoc.doc).getFields()) {
|
||||
for (StorableField storableField : document.getFields()) {
|
||||
if (storableField.readerValue() != null) {
|
||||
doc.add(new Field(storableField.name(), storableField.readerValue(), ft));
|
||||
} else if (storableField.binaryValue() != null) {
|
||||
|
|
Loading…
Reference in New Issue