mirror of https://github.com/apache/lucene.git
LUCENE-6433 - using generics in Classifier#getClasses
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1674304 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
44a97f67b7
commit
dbf063d3d1
|
@ -173,15 +173,17 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
Integer.MAX_VALUE).scoreDocs) {
|
||||
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
||||
|
||||
// assign class to the doc
|
||||
ClassificationResult<Boolean> classificationResult = assignClass(doc
|
||||
.getField(textFieldName).stringValue());
|
||||
Boolean assignedClass = classificationResult.getAssignedClass();
|
||||
StorableField textField = doc.getField(textFieldName);
|
||||
|
||||
// get the expected result
|
||||
StorableField field = doc.getField(classFieldName);
|
||||
StorableField classField = doc.getField(classFieldName);
|
||||
|
||||
Boolean correctClass = Boolean.valueOf(field.stringValue());
|
||||
if (textField != null && classField != null) {
|
||||
// assign class to the doc
|
||||
ClassificationResult<Boolean> classificationResult = assignClass(textField.stringValue());
|
||||
Boolean assignedClass = classificationResult.getAssignedClass();
|
||||
|
||||
Boolean correctClass = Boolean.valueOf(classField.stringValue());
|
||||
long modifier = correctClass.compareTo(assignedClass);
|
||||
if (modifier != 0) {
|
||||
updateWeights(leafReader, scoreDoc.doc, assignedClass,
|
||||
|
@ -189,6 +191,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
}
|
||||
batchCount++;
|
||||
}
|
||||
}
|
||||
weights.clear(); // free memory while waiting for GC
|
||||
}
|
||||
|
||||
|
@ -246,18 +249,18 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text)
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text)
|
||||
throws IOException {
|
||||
return null;
|
||||
throw new RuntimeException("not implemented");
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max)
|
||||
public List<ClassificationResult<Boolean>> getClasses(String text, int max)
|
||||
throws IOException {
|
||||
return null;
|
||||
throw new RuntimeException("not implemented");
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ public interface Classifier<T> {
|
|||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException;
|
||||
public List<ClassificationResult<T>> getClasses(String text) throws IOException;
|
||||
|
||||
/**
|
||||
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
|
||||
|
@ -58,7 +58,7 @@ public interface Classifier<T> {
|
|||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException;
|
||||
public List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
|
||||
|
||||
/**
|
||||
* Train the classifier using the underlying Lucene index
|
||||
|
|
Loading…
Reference in New Issue