mirror of https://github.com/apache/lucene.git
LUCENE-7838 - added knn classifier based on flt
This commit is contained in:
parent
afd70a48cc
commit
bd9e32d358
|
@ -16,8 +16,9 @@
|
|||
<orderEntry type="module" scope="TEST" module-name="lucene-test-framework" />
|
||||
<orderEntry type="module" module-name="lucene-core" />
|
||||
<orderEntry type="module" module-name="queries" />
|
||||
<orderEntry type="module" scope="TEST" module-name="analysis-common" />
|
||||
<orderEntry type="module" module-name="analysis-common" />
|
||||
<orderEntry type="module" module-name="grouping" />
|
||||
<orderEntry type="module" module-name="misc" />
|
||||
<orderEntry type="module" module-name="sandbox" />
|
||||
</component>
|
||||
</module>
|
||||
|
|
|
@ -28,6 +28,8 @@
|
|||
<path refid="base.classpath"/>
|
||||
<pathelement path="${queries.jar}"/>
|
||||
<pathelement path="${grouping.jar}"/>
|
||||
<pathelement path="${sandbox.jar}"/>
|
||||
<pathelement path="${analyzers-common.jar}"/>
|
||||
</path>
|
||||
|
||||
<path id="test.classpath">
|
||||
|
@ -36,16 +38,18 @@
|
|||
<path refid="test.base.classpath"/>
|
||||
</path>
|
||||
|
||||
<target name="compile-core" depends="jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
|
||||
<target name="compile-core" depends="jar-sandbox,jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
|
||||
|
||||
<target name="jar-core" depends="common.jar-core" />
|
||||
|
||||
<target name="javadocs" depends="javadocs-grouping,compile-core,check-javadocs-uptodate"
|
||||
<target name="javadocs" depends="javadocs-sandbox,javadocs-grouping,compile-core,check-javadocs-uptodate"
|
||||
unless="javadocs-uptodate-${name}">
|
||||
<invoke-module-javadoc>
|
||||
<links>
|
||||
<link href="../queries"/>
|
||||
<link href="../analyzers/common"/>
|
||||
<link href="../grouping"/>
|
||||
<link href="../sandbox"/>
|
||||
</links>
|
||||
</invoke-module-javadoc>
|
||||
</target>
|
||||
|
|
|
@ -0,0 +1,225 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.classification;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.sandbox.queries.FuzzyLikeThisQuery;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.search.similarities.ClassicSimilarity;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* A k-Nearest Neighbor classifier based on {@link FuzzyLikeThisQuery}.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class KNearestFuzzyClassifier implements Classifier<BytesRef> {
|
||||
|
||||
/**
|
||||
* the name of the fields used as the input text
|
||||
*/
|
||||
protected final String[] textFieldNames;
|
||||
|
||||
/**
|
||||
* the name of the field used as the output text
|
||||
*/
|
||||
protected final String classFieldName;
|
||||
|
||||
/**
|
||||
* an {@link IndexSearcher} used to perform queries
|
||||
*/
|
||||
protected final IndexSearcher indexSearcher;
|
||||
|
||||
/**
|
||||
* the no. of docs to compare in order to find the nearest neighbor to the input text
|
||||
*/
|
||||
protected final int k;
|
||||
|
||||
/**
|
||||
* a {@link Query} used to filter the documents that should be used from this classifier's underlying {@link LeafReader}
|
||||
*/
|
||||
protected final Query query;
|
||||
private final Analyzer analyzer;
|
||||
|
||||
/**
|
||||
* Creates a {@link KNearestFuzzyClassifier}.
|
||||
*
|
||||
* @param indexReader the reader on the index to be used for classification
|
||||
* @param analyzer an {@link Analyzer} used to analyze unseen text
|
||||
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
|
||||
* (defaults to {@link BM25Similarity})
|
||||
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
* @param k the no. of docs to select in the MLT results to find the nearest neighbor
|
||||
* @param classFieldName the name of the field used as the output for the classifier
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
|
||||
*/
|
||||
public KNearestFuzzyClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k,
|
||||
String classFieldName, String... textFieldNames) {
|
||||
this.textFieldNames = textFieldNames;
|
||||
this.classFieldName = classFieldName;
|
||||
this.analyzer = analyzer;
|
||||
this.indexSearcher = new IndexSearcher(indexReader);
|
||||
if (similarity != null) {
|
||||
this.indexSearcher.setSimilarity(similarity);
|
||||
} else {
|
||||
this.indexSearcher.setSimilarity(new BM25Similarity());
|
||||
}
|
||||
this.query = query;
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
|
||||
TopDocs knnResults = knnSearch(text);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
ClassificationResult<BytesRef> assignedClass = null;
|
||||
double maxscore = -Double.MAX_VALUE;
|
||||
for (ClassificationResult<BytesRef> cl : assignedClasses) {
|
||||
if (cl.getScore() > maxscore) {
|
||||
assignedClass = cl;
|
||||
maxscore = cl.getScore();
|
||||
}
|
||||
}
|
||||
return assignedClass;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
|
||||
TopDocs knnResults = knnSearch(text);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
|
||||
TopDocs knnResults = knnSearch(text);
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
|
||||
Collections.sort(assignedClasses);
|
||||
return assignedClasses.subList(0, max);
|
||||
}
|
||||
|
||||
private TopDocs knnSearch(String text) throws IOException {
|
||||
BooleanQuery.Builder bq = new BooleanQuery.Builder();
|
||||
FuzzyLikeThisQuery fuzzyLikeThisQuery = new FuzzyLikeThisQuery(300, analyzer);
|
||||
for (String fieldName : textFieldNames) {
|
||||
fuzzyLikeThisQuery.addTerms(text, fieldName, 1f, 2); // TODO: make this parameters configurable
|
||||
}
|
||||
bq.add(fuzzyLikeThisQuery, BooleanClause.Occur.MUST);
|
||||
Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
|
||||
bq.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
bq.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
return indexSearcher.search(bq.build(), k);
|
||||
}
|
||||
|
||||
/**
|
||||
* build a list of classification results from search results
|
||||
*
|
||||
* @param topDocs the search results as a {@link TopDocs} object
|
||||
* @return a {@link List} of {@link ClassificationResult}, one for each existing class
|
||||
* @throws IOException if it's not possible to get the stored value of class field
|
||||
*/
|
||||
protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
|
||||
Map<BytesRef, Integer> classCounts = new HashMap<>();
|
||||
Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
|
||||
float maxScore = topDocs.getMaxScore();
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
IndexableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
|
||||
if (storableField != null) {
|
||||
BytesRef cl = new BytesRef(storableField.stringValue());
|
||||
//update count
|
||||
Integer count = classCounts.get(cl);
|
||||
if (count != null) {
|
||||
classCounts.put(cl, count + 1);
|
||||
} else {
|
||||
classCounts.put(cl, 1);
|
||||
}
|
||||
//update boost, the boost is based on the best score
|
||||
Double totalBoost = classBoosts.get(cl);
|
||||
double singleBoost = scoreDoc.score / maxScore;
|
||||
if (totalBoost != null) {
|
||||
classBoosts.put(cl, totalBoost + singleBoost);
|
||||
} else {
|
||||
classBoosts.put(cl, singleBoost);
|
||||
}
|
||||
}
|
||||
}
|
||||
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
|
||||
int sumdoc = 0;
|
||||
for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
|
||||
Integer count = entry.getValue();
|
||||
Double normBoost = classBoosts.get(entry.getKey()) / count; //the boost is normalized to be 0<b<1
|
||||
temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k));
|
||||
sumdoc += count;
|
||||
}
|
||||
|
||||
//correction
|
||||
if (sumdoc < k) {
|
||||
for (ClassificationResult<BytesRef> cr : temporaryList) {
|
||||
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
|
||||
}
|
||||
} else {
|
||||
returnList = temporaryList;
|
||||
}
|
||||
return returnList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "KNearestFuzzyClassifier{" +
|
||||
"textFieldNames=" + Arrays.toString(textFieldNames) +
|
||||
", classFieldName='" + classFieldName + '\'' +
|
||||
", k=" + k +
|
||||
", query=" + query +
|
||||
", similarity=" + indexSearcher.getSimilarity(true) +
|
||||
'}';
|
||||
}
|
||||
}
|
|
@ -121,7 +121,7 @@ public class DatasetSplitter {
|
|||
int b = 0;
|
||||
|
||||
// iterate over existing documents
|
||||
for (GroupDocs group : topGroups.groups) {
|
||||
for (GroupDocs<Object> group : topGroups.groups) {
|
||||
int totalHits = group.totalHits;
|
||||
double testSize = totalHits * testRatio;
|
||||
int tc = 0;
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.classification;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.analysis.Tokenizer;
|
||||
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
||||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Testcase for {@link KNearestFuzzyClassifier}
|
||||
*/
|
||||
public class KNearestFuzzyClassifierTest extends ClassificationTestBase<BytesRef> {
|
||||
|
||||
@Test
|
||||
public void testBasicUsage() throws Exception {
|
||||
LeafReader leafReader = null;
|
||||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
leafReader = getSampleIndex(analyzer);
|
||||
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, null, 3, categoryFieldName, textFieldName);
|
||||
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
checkCorrectClassification(classifier, POLITICS_INPUT, POLITICS_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicUsageWithQuery() throws Exception {
|
||||
LeafReader leafReader = null;
|
||||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
leafReader = getSampleIndex(analyzer);
|
||||
TermQuery query = new TermQuery(new Term(textFieldName, "not"));
|
||||
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, query, 3, categoryFieldName, textFieldName);
|
||||
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPerformance() throws Exception {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
LeafReader leafReader = getRandomIndex(analyzer, 100);
|
||||
try {
|
||||
long trainStart = System.currentTimeMillis();
|
||||
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, null, 3, categoryFieldName, textFieldName);
|
||||
long trainEnd = System.currentTimeMillis();
|
||||
long trainTime = trainEnd - trainStart;
|
||||
assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
|
||||
|
||||
long evaluationStart = System.currentTimeMillis();
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
long evaluationEnd = System.currentTimeMillis();
|
||||
long evaluationTime = evaluationEnd - evaluationStart;
|
||||
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
|
||||
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
|
||||
TermsEnum iterator = terms.iterator();
|
||||
BytesRef term;
|
||||
while ((term = iterator.next()) != null) {
|
||||
String s = term.utf8ToString();
|
||||
recall = confusionMatrix.getRecall(s);
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
precision = confusionMatrix.getPrecision(s);
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure(s);
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
}
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -21,11 +21,13 @@ import java.io.IOException;
|
|||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.classification.BM25NBClassifier;
|
||||
import org.apache.lucene.classification.BooleanPerceptronClassifier;
|
||||
import org.apache.lucene.classification.CachingNaiveBayesClassifier;
|
||||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.ClassificationTestBase;
|
||||
import org.apache.lucene.classification.Classifier;
|
||||
import org.apache.lucene.classification.KNearestFuzzyClassifier;
|
||||
import org.apache.lucene.classification.KNearestNeighborClassifier;
|
||||
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
|
@ -94,22 +96,43 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
checkCM(confusionMatrix);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void checkCM(ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix) {
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetConfusionMatrixWithBM25NB() throws Exception {
|
||||
LeafReader reader = null;
|
||||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
reader = getSampleIndex(analyzer);
|
||||
Classifier<BytesRef> classifier = new BM25NBClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
checkCM(confusionMatrix);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -126,22 +149,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
checkCM(confusionMatrix);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -158,22 +166,24 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
checkCM(confusionMatrix);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetConfusionMatrixWithFLTKNN() throws Exception {
|
||||
LeafReader reader = null;
|
||||
try {
|
||||
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||
reader = getSampleIndex(analyzer);
|
||||
Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(reader, null, analyzer, null, 1, categoryFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, categoryFieldName, textFieldName, -1);
|
||||
checkCM(confusionMatrix);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -190,22 +200,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
|
||||
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
|
||||
classifier, booleanFieldName, textFieldName, -1);
|
||||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure();
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
checkCM(confusionMatrix);
|
||||
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision("false") >= 0d);
|
||||
|
|
Loading…
Reference in New Issue