LUCENE-6045 - immutable ClassificationResult, minor fixes

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1677573 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-05-04 10:06:32 +00:00
parent 11c4a88e23
commit bf1355346c
5 changed files with 27 additions and 19 deletions

View File

@ -226,7 +226,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
@Override
public List<ClassificationResult<Boolean>> getClasses(String text)
throws IOException {
throw new RuntimeException("not implemented");
return null;
}
/**
@ -235,7 +235,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
@Override
public List<ClassificationResult<Boolean>> getClasses(String text, int max)
throws IOException {
throw new RuntimeException("not implemented");
return null;
}
}

View File

@ -141,12 +141,22 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
double wordProbability = num / den;
// modify the value in the result list item
int removeIdx = -1;
int i = 0;
for (ClassificationResult<BytesRef> cr : ret) {
if (cr.getAssignedClass().equals(cclass)) {
cr.setScore(cr.getScore() + Math.log(wordProbability));
removeIdx = i;
break;
}
i++;
}
if (removeIdx >= 0) {
ClassificationResult<BytesRef> toRemove = ret.get(removeIdx);
ret.add(new ClassificationResult<>(toRemove.getAssignedClass(), toRemove.getScore() + Math.log(wordProbability)));
ret.remove(removeIdx);
}
}
}

View File

@ -24,7 +24,7 @@ package org.apache.lucene.classification;
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {
private final T assignedClass;
private double score;
private final double score;
/**
* Constructor
@ -55,16 +55,6 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
return score;
}
/**
* set the score value
*
* @param score the score for the assignedClass as a <code>double</code>
*/
public void setScore(double score) {
this.score = score;
}
@Override
public int compareTo(ClassificationResult<T> o) {
return this.getScore() < o.getScore() ? 1 : this.getScore() > o.getScore() ? -1 : 0;

View File

@ -153,18 +153,21 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
}
}
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
int sumdoc = 0;
for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
Integer count = entry.getValue();
returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
sumdoc += count;
}
//correction
if (sumdoc < k) {
for (ClassificationResult<BytesRef> cr : returnList) {
cr.setScore(cr.getScore() * (double) k / (double) sumdoc);
for (ClassificationResult<BytesRef> cr : temporaryList) {
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
}
} else {
returnList = temporaryList;
}
return returnList;
}

View File

@ -28,6 +28,7 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StorableField;
import org.apache.lucene.index.StoredDocument;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
@ -91,12 +92,16 @@ public class DatasetSplitter {
// create a new document for indexing
Document doc = new Document();
StoredDocument document = originalIndex.document(scoreDoc.doc);
if (fieldNames != null && fieldNames.length > 0) {
for (String fieldName : fieldNames) {
doc.add(new Field(fieldName, originalIndex.document(scoreDoc.doc).getField(fieldName).stringValue(), ft));
StorableField field = document.getField(fieldName);
if (field != null) {
doc.add(new Field(fieldName, field.stringValue(), ft));
}
}
} else {
for (StorableField storableField : originalIndex.document(scoreDoc.doc).getFields()) {
for (StorableField storableField : document.getFields()) {
if (storableField.readerValue() != null) {
doc.add(new Field(storableField.name(), storableField.readerValue(), ft));
} else if (storableField.binaryValue() != null) {