mirror of https://github.com/apache/lucene.git
LUCENE-6433 - using generics in Classifier#getClasses
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1674304 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
44a97f67b7
commit
dbf063d3d1
|
@ -173,21 +173,24 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
Integer.MAX_VALUE).scoreDocs) {
|
Integer.MAX_VALUE).scoreDocs) {
|
||||||
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
||||||
|
|
||||||
// assign class to the doc
|
StorableField textField = doc.getField(textFieldName);
|
||||||
ClassificationResult<Boolean> classificationResult = assignClass(doc
|
|
||||||
.getField(textFieldName).stringValue());
|
|
||||||
Boolean assignedClass = classificationResult.getAssignedClass();
|
|
||||||
|
|
||||||
// get the expected result
|
// get the expected result
|
||||||
StorableField field = doc.getField(classFieldName);
|
StorableField classField = doc.getField(classFieldName);
|
||||||
|
|
||||||
Boolean correctClass = Boolean.valueOf(field.stringValue());
|
if (textField != null && classField != null) {
|
||||||
long modifier = correctClass.compareTo(assignedClass);
|
// assign class to the doc
|
||||||
if (modifier != 0) {
|
ClassificationResult<Boolean> classificationResult = assignClass(textField.stringValue());
|
||||||
updateWeights(leafReader, scoreDoc.doc, assignedClass,
|
Boolean assignedClass = classificationResult.getAssignedClass();
|
||||||
weights, modifier, batchCount % batchSize == 0);
|
|
||||||
|
Boolean correctClass = Boolean.valueOf(classField.stringValue());
|
||||||
|
long modifier = correctClass.compareTo(assignedClass);
|
||||||
|
if (modifier != 0) {
|
||||||
|
updateWeights(leafReader, scoreDoc.doc, assignedClass,
|
||||||
|
weights, modifier, batchCount % batchSize == 0);
|
||||||
|
}
|
||||||
|
batchCount++;
|
||||||
}
|
}
|
||||||
batchCount++;
|
|
||||||
}
|
}
|
||||||
weights.clear(); // free memory while waiting for GC
|
weights.clear(); // free memory while waiting for GC
|
||||||
}
|
}
|
||||||
|
@ -246,18 +249,18 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<ClassificationResult<BytesRef>> getClasses(String text)
|
public List<ClassificationResult<Boolean>> getClasses(String text)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
return null;
|
throw new RuntimeException("not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max)
|
public List<ClassificationResult<Boolean>> getClasses(String text, int max)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
return null;
|
throw new RuntimeException("not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ public interface Classifier<T> {
|
||||||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
|
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
|
||||||
* @throws IOException If there is a low-level I/O error.
|
* @throws IOException If there is a low-level I/O error.
|
||||||
*/
|
*/
|
||||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException;
|
public List<ClassificationResult<T>> getClasses(String text) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
|
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
|
||||||
|
@ -58,7 +58,7 @@ public interface Classifier<T> {
|
||||||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
|
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
|
||||||
* @throws IOException If there is a low-level I/O error.
|
* @throws IOException If there is a low-level I/O error.
|
||||||
*/
|
*/
|
||||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException;
|
public List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train the classifier using the underlying Lucene index
|
* Train the classifier using the underlying Lucene index
|
||||||
|
|
Loading…
Reference in New Issue