From debb363f6e3b5ec9b52af156da515735b29628d6 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Mon, 11 Nov 2013 13:17:36 +0000 Subject: [PATCH] LUCENE-5338 - avoid considering unlabeled documents for training git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1540703 13f79535-47bb-0310-9956-ffa450edef68 --- .../BooleanPerceptronClassifier.java | 15 ++++++++------- .../KNearestNeighborClassifier.java | 14 ++++++-------- .../SimpleNaiveBayesClassifier.java | 13 ++++--------- .../classification/ClassificationTestBase.java | 8 ++++++-- 4 files changed, 24 insertions(+), 26 deletions(-) diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java index 814923b499a..5eddf473418 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java @@ -23,12 +23,15 @@ import org.apache.lucene.index.AtomicReader; import org.apache.lucene.index.MultiFields; import org.apache.lucene.index.StorableField; import org.apache.lucene.index.StoredDocument; +import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.WildcardQuery; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IntsRef; import org.apache.lucene.util.fst.Builder; @@ -160,14 +163,12 @@ public class BooleanPerceptronClassifier implements Classifier { int batchCount = 0; - Query q; + BooleanQuery q = new BooleanQuery(); + q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST)); if (query != null) { - q = query; + q.add(new BooleanClause(query, BooleanClause.Occur.MUST)); } - else { - q = new MatchAllDocsQuery(); - } - // do a *:* search and use stored field values + // run the search and use stored field values for (ScoreDoc scoreDoc : indexSearcher.search(q, Integer.MAX_VALUE).scoreDocs) { StoredDocument doc = indexSearcher.doc(scoreDoc.doc); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index 2ba9887be10..e21e670d2d5 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -18,6 +18,7 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.index.AtomicReader; +import org.apache.lucene.index.Term; import org.apache.lucene.queries.mlt.MoreLikeThis; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -25,6 +26,7 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.WildcardQuery; import org.apache.lucene.util.BytesRef; import java.io.IOException; @@ -64,20 +66,16 @@ public class KNearestNeighborClassifier implements Classifier { if (mlt == null) { throw new IOException("You must first call Classifier#train"); } - Query q; BooleanQuery mltQuery = new BooleanQuery(); for (String textFieldName : textFieldNames) { mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD)); } + Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*")); + mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST)); if (query != null) { - BooleanQuery bq = new BooleanQuery(); - bq.add(query, BooleanClause.Occur.MUST); - bq.add(mltQuery, BooleanClause.Occur.MUST); - q = bq; - } else { - q = mltQuery; + mltQuery.add(query, BooleanClause.Occur.MUST); } - TopDocs topDocs = indexSearcher.search(q, k); + TopDocs topDocs = indexSearcher.search(mltQuery, k); return selectClassFromNeighbors(topDocs); } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index db644f2ec1e..d1393523c09 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -102,15 +102,10 @@ public class SimpleNaiveBayesClassifier implements Classifier { int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount(); if (docCount == -1) { // in case codec doesn't support getDocCount TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); - Query q; + BooleanQuery q = new BooleanQuery(); + q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST)); if (query != null) { - BooleanQuery bq = new BooleanQuery(); - WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))); - bq.add(wq, BooleanClause.Occur.MUST); - bq.add(query, BooleanClause.Occur.MUST); - q = bq; - } else { - q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))); + q.add(query, BooleanClause.Occur.MUST); } indexSearcher.search(q, totalHitCountCollector); @@ -191,7 +186,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc } int docsWithC = atomicReader.docFreq(new Term(classFieldName, c)); - return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c + return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c } private int getWordFreqForClass(String word, BytesRef c) throws IOException { diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java index 1b20be924a7..cd488043fd5 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java @@ -94,12 +94,11 @@ public abstract class ClassificationTestBase extends LuceneTestCase { protected void checkPerformance(Classifier classifier, Analyzer analyzer, String classFieldName) throws Exception { AtomicReader atomicReader = null; long trainStart = System.currentTimeMillis(); - long trainEnd = 0l; try { populatePerformanceIndex(analyzer); atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); classifier.train(atomicReader, textFieldName, classFieldName, analyzer); - trainEnd = System.currentTimeMillis(); + long trainEnd = System.currentTimeMillis(); long trainTime = trainEnd - trainStart; assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000); } finally { @@ -212,6 +211,11 @@ public abstract class ClassificationTestBase extends LuceneTestCase { doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); + doc = new Document(); + text = "unlabeled doc"; + doc.add(new Field(textFieldName, text, ft)); + indexWriter.addDocument(doc, analyzer); + indexWriter.commit(); } }