From fcf4389d82e440d078f61ed9ad8c6dedce10d124 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 21 Jun 2016 13:10:34 +0200 Subject: [PATCH] LUCENE-7350 - Let classifiers be constructed from IndexReaders --- .../BooleanPerceptronClassifier.java | 19 ++++----- .../CachingNaiveBayesClassifier.java | 8 ++-- .../KNearestNeighborClassifier.java | 9 +++-- .../SimpleNaiveBayesClassifier.java | 24 +++++------ .../KNearestNeighborDocumentClassifier.java | 7 ++-- .../SimpleNaiveBayesDocumentClassifier.java | 15 +++---- .../utils/ConfusionMatrixGenerator.java | 8 ++-- .../classification/utils/DatasetSplitter.java | 20 ++++++---- .../DocumentClassificationTestBase.java | 11 +++-- ...NearestNeighborDocumentClassifierTest.java | 40 +++++++++---------- ...impleNaiveBayesDocumentClassifierTest.java | 32 +++++++-------- .../ClassificationUpdateProcessor.java | 3 +- .../ClassificationUpdateProcessorFactory.java | 5 ++- 13 files changed, 105 insertions(+), 96 deletions(-) diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java index 3d8e75b9cde..760e66d1869 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java @@ -26,6 +26,7 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.MultiFields; @@ -67,7 +68,7 @@ public class BooleanPerceptronClassifier implements Classifier { /** * Creates a {@link BooleanPerceptronClassifier} * - * @param leafReader the reader on the index to be used for classification + * @param indexReader the reader on the index to be used for classification * @param analyzer an {@link Analyzer} used to analyze unseen text * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * if all the indexed docs should be used @@ -78,9 +79,9 @@ public class BooleanPerceptronClassifier implements Classifier { * @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field * cannot be found */ - public BooleanPerceptronClassifier(LeafReader leafReader, Analyzer analyzer, Query query, Integer batchSize, + public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize, Double threshold, String classFieldName, String textFieldName) throws IOException { - this.textTerms = MultiFields.getTerms(leafReader, textFieldName); + this.textTerms = MultiFields.getTerms(indexReader, textFieldName); if (textTerms == null) { throw new IOException("term vectors need to be available for field " + textFieldName); @@ -91,7 +92,7 @@ public class BooleanPerceptronClassifier implements Classifier { if (threshold == null || threshold == 0d) { // automatic assign a threshold - long sumDocFreq = leafReader.getSumDocFreq(textFieldName); + long sumDocFreq = indexReader.getSumDocFreq(textFieldName); if (sumDocFreq != -1) { this.threshold = (double) sumDocFreq / 2d; } else { @@ -113,7 +114,7 @@ public class BooleanPerceptronClassifier implements Classifier { } updateFST(weights); - IndexSearcher indexSearcher = new IndexSearcher(leafReader); + IndexSearcher indexSearcher = new IndexSearcher(indexReader); int batchCount = 0; @@ -140,7 +141,7 @@ public class BooleanPerceptronClassifier implements Classifier { Boolean correctClass = Boolean.valueOf(classField.stringValue()); long modifier = correctClass.compareTo(assignedClass); if (modifier != 0) { - updateWeights(leafReader, scoreDoc.doc, assignedClass, + updateWeights(indexReader, scoreDoc.doc, assignedClass, weights, modifier, batchCount % batchSize == 0); } batchCount++; @@ -149,13 +150,13 @@ public class BooleanPerceptronClassifier implements Classifier { weights.clear(); // free memory while waiting for GC } - private void updateWeights(LeafReader leafReader, + private void updateWeights(IndexReader indexReader, int docId, Boolean assignedClass, SortedMap weights, double modifier, boolean updateFST) throws IOException { TermsEnum cte = textTerms.iterator(); // get the doc term vectors - Terms terms = leafReader.getTermVector(docId, textFieldName); + Terms terms = indexReader.getTermVector(docId, textFieldName); if (terms == null) { throw new IOException("term vectors must be stored for field " @@ -201,7 +202,7 @@ public class BooleanPerceptronClassifier implements Classifier { @Override public ClassificationResult assignClass(String text) throws IOException { - Long output = 0l; + Long output = 0L; try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) { CharTermAttribute charTermAttribute = tokenStream .addAttribute(CharTermAttribute.class); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java index ec56a919fa1..b87b8d82bc7 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java @@ -212,7 +212,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier { // build the cache for the word Map frequencyMap = new HashMap<>(); for (String textFieldName : textFieldNames) { - TermsEnum termsEnum = leafReader.terms(textFieldName).iterator(); + TermsEnum termsEnum = MultiFields.getTerms(indexReader, textFieldName).iterator(); while (termsEnum.next() != null) { BytesRef term = termsEnum.term(); String termText = term.utf8ToString(); @@ -229,7 +229,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier { } // fill the class list - Terms terms = MultiFields.getTerms(leafReader, classFieldName); + Terms terms = MultiFields.getTerms(indexReader, classFieldName); TermsEnum termsEnum = terms.iterator(); while ((termsEnum.next()) != null) { cclasses.add(BytesRef.deepCopyOf(termsEnum.term())); @@ -238,11 +238,11 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier { for (BytesRef cclass : cclasses) { double avgNumberOfUniqueTerms = 0; for (String textFieldName : textFieldNames) { - terms = MultiFields.getTerms(leafReader, textFieldName); + terms = MultiFields.getTerms(indexReader, textFieldName); long numPostings = terms.getSumDocFreq(); // number of term/doc pairs avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); } - int docsWithC = leafReader.docFreq(new Term(classFieldName, cclass)); + int docsWithC = indexReader.docFreq(new Term(classFieldName, cclass)); classTermFreq.put(cclass, avgNumberOfUniqueTerms * docsWithC); } } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index 1d7cf492379..c4f2c2f53a5 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.Term; @@ -82,7 +83,7 @@ public class KNearestNeighborClassifier implements Classifier { /** * Creates a {@link KNearestNeighborClassifier}. * - * @param leafReader the reader on the index to be used for classification + * @param indexReader the reader on the index to be used for classification * @param analyzer an {@link Analyzer} used to analyze unseen text * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null} * (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity}) @@ -94,14 +95,14 @@ public class KNearestNeighborClassifier implements Classifier { * @param classFieldName the name of the field used as the output for the classifier * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10 */ - public KNearestNeighborClassifier(LeafReader leafReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq, + public KNearestNeighborClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq, int minTermFreq, String classFieldName, String... textFieldNames) { this.textFieldNames = textFieldNames; this.classFieldName = classFieldName; - this.mlt = new MoreLikeThis(leafReader); + this.mlt = new MoreLikeThis(indexReader); this.mlt.setAnalyzer(analyzer); this.mlt.setFieldNames(textFieldNames); - this.indexSearcher = new IndexSearcher(leafReader); + this.indexSearcher = new IndexSearcher(indexReader); if (similarity != null) { this.indexSearcher.setSimilarity(similarity); } else { diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index 2514ae1e644..3509df58511 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -26,7 +26,7 @@ import java.util.List; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.MultiFields; import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; @@ -48,10 +48,10 @@ import org.apache.lucene.util.BytesRef; public class SimpleNaiveBayesClassifier implements Classifier { /** - * {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s + * {@link org.apache.lucene.index.IndexReader} used to access the {@link org.apache.lucene.classification.Classifier}'s * index */ - protected final LeafReader leafReader; + protected final IndexReader indexReader; /** * names of the fields to be used as input text @@ -81,7 +81,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { /** * Creates a new NaiveBayes classifier. * - * @param leafReader the reader on the index to be used for classification + * @param indexReader the reader on the index to be used for classification * @param analyzer an {@link Analyzer} used to analyze unseen text * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * if all the indexed docs should be used @@ -89,9 +89,9 @@ public class SimpleNaiveBayesClassifier implements Classifier { * as the returned class will be a token indexed for this field * @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field */ - public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) { - this.leafReader = leafReader; - this.indexSearcher = new IndexSearcher(this.leafReader); + public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) { + this.indexReader = indexReader; + this.indexSearcher = new IndexSearcher(this.indexReader); this.textFieldNames = textFieldNames; this.classFieldName = classFieldName; this.analyzer = analyzer; @@ -144,7 +144,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { protected List> assignClassNormalizedList(String inputDocument) throws IOException { List> assignedClasses = new ArrayList<>(); - Terms classes = MultiFields.getTerms(leafReader, classFieldName); + Terms classes = MultiFields.getTerms(indexReader, classFieldName); if (classes != null) { TermsEnum classesEnum = classes.iterator(); BytesRef next; @@ -169,7 +169,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { * @throws IOException if accessing to term vectors or search fails */ protected int countDocsWithClass() throws IOException { - Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName); + Terms terms = MultiFields.getTerms(this.indexReader, this.classFieldName); int docCount; if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector(); @@ -240,11 +240,11 @@ public class SimpleNaiveBayesClassifier implements Classifier { private double getTextTermFreqForClass(Term term) throws IOException { double avgNumberOfUniqueTerms = 0; for (String textFieldName : textFieldNames) { - Terms terms = MultiFields.getTerms(leafReader, textFieldName); + Terms terms = MultiFields.getTerms(indexReader, textFieldName); long numPostings = terms.getSumDocFreq(); // number of term/doc pairs avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc } - int docsWithC = leafReader.docFreq(term); + int docsWithC = indexReader.docFreq(term); return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c } @@ -277,7 +277,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { } private int docCount(Term term) throws IOException { - return leafReader.docFreq(term); + return indexReader.docFreq(term); } /** diff --git a/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java index c0d709a415d..248b2eb8fcb 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java @@ -27,6 +27,7 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.KNearestNeighborClassifier; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; @@ -54,7 +55,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi /** * Creates a {@link KNearestNeighborClassifier}. * - * @param leafReader the reader on the index to be used for classification + * @param indexReader the reader on the index to be used for classification * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null} * (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity}) * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null} @@ -66,9 +67,9 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi * @param field2analyzer map with key a field name and the related {org.apache.lucene.analysis.Analyzer} * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10 */ - public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq, + public KNearestNeighborDocumentClassifier(IndexReader indexReader, Similarity similarity, Query query, int k, int minDocsFreq, int minTermFreq, String classFieldName, Map field2analyzer, String... textFieldNames) { - super(leafReader,similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames); + super(indexReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames); this.field2analyzer = field2analyzer; } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java index 3fbe556f96b..416d09739f2 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java @@ -32,6 +32,7 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.SimpleNaiveBayesClassifier; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.MultiFields; @@ -59,15 +60,15 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi /** * Creates a new NaiveBayes classifier. * - * @param leafReader the reader on the index to be used for classification + * @param indexReader the reader on the index to be used for classification * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null} * if all the indexed docs should be used * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed * as the returned class will be a token indexed for this field * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10 */ - public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map field2analyzer, String... textFieldNames) { - super(leafReader, null, query, classFieldName, textFieldNames); + public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map field2analyzer, String... textFieldNames) { + super(indexReader, null, query, classFieldName, textFieldNames); this.field2analyzer = field2analyzer; } @@ -112,7 +113,7 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi List> assignedClasses = new ArrayList<>(); Map> fieldName2tokensArray = new LinkedHashMap<>(); Map fieldName2boost = new LinkedHashMap<>(); - Terms classes = MultiFields.getTerms(leafReader, classFieldName); + Terms classes = MultiFields.getTerms(indexReader, classFieldName); TermsEnum classesEnum = classes.iterator(); BytesRef c; @@ -225,10 +226,10 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi */ private double getTextTermFreqForClass(Term term, String fieldName) throws IOException { double avgNumberOfUniqueTerms; - Terms terms = MultiFields.getTerms(leafReader, fieldName); + Terms terms = MultiFields.getTerms(indexReader, fieldName); long numPostings = terms.getSumDocFreq(); // number of term/doc pairs avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc - int docsWithC = leafReader.docFreq(term); + int docsWithC = indexReader.docFreq(term); return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c } @@ -261,6 +262,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi } private int docCount(Term term) throws IOException { - return leafReader.docFreq(term); + return indexReader.docFreq(term); } } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java index 65de8015f20..bd9a0d9bd3c 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java @@ -31,7 +31,7 @@ import java.util.concurrent.TimeoutException; import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.Classifier; import org.apache.lucene.document.Document; -import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TermRangeQuery; @@ -50,9 +50,9 @@ public class ConfusionMatrixGenerator { /** * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier}, - * generated on the given {@link LeafReader}, class and text fields. + * generated on the given {@link IndexReader}, class and text fields. * - * @param reader the {@link LeafReader} containing the index used for creating the {@link Classifier} + * @param reader the {@link IndexReader} containing the index used for creating the {@link Classifier} * @param classifier the {@link Classifier} whose confusion matrix has to be generated * @param classFieldName the name of the Lucene field used as the classifier's output * @param textFieldName the nome the Lucene field used as the classifier's input @@ -61,7 +61,7 @@ public class ConfusionMatrixGenerator { * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} * @throws IOException if problems occurr while reading the index or using the classifier */ - public static ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier classifier, String classFieldName, + public static ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier classifier, String classFieldName, String textFieldName, long timeoutMilliseconds) throws IOException { ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-")); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java index 95239505296..374624b41c6 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java @@ -18,15 +18,17 @@ package org.apache.lucene.classification.utils; import java.io.IOException; + import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexableField; -import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; @@ -69,7 +71,7 @@ public class DatasetSplitter { * @param fieldNames names of fields that need to be put in the new indexes or null if all should be used * @throws IOException if any writing operation fails on any of the indexes */ - public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, + public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException { // create IWs for train / test / cv IDXs @@ -78,13 +80,15 @@ public class DatasetSplitter { IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer)); // get the exact no. of existing classes - SortedDocValues classValues = originalIndex.getSortedDocValues(classFieldName); - if (classValues == null) { - throw new IllegalStateException("the classFieldName \"" + classFieldName + "\" must index sorted doc values"); + int noOfClasses = 0; + for (LeafReaderContext leave : originalIndex.leaves()) { + SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName); + if (classValues == null) { + throw new IllegalStateException("the classFieldName \"" + classFieldName + "\" must index sorted doc values"); + } + noOfClasses += classValues.getValueCount(); } - int noOfClasses = classValues.getValueCount(); - try { IndexSearcher indexSearcher = new IndexSearcher(originalIndex); @@ -150,7 +154,7 @@ public class DatasetSplitter { } } - private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException { + private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException { Document doc = new Document(); Document document = originalIndex.document(scoreDoc.doc); if (fieldNames != null && fieldNames.length > 0) { diff --git a/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java index 4193bde1679..3848151c5f8 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java @@ -27,8 +27,8 @@ import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.ClassificationTestBase; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.util.BytesRef; import org.junit.Before; @@ -47,7 +47,7 @@ public abstract class DocumentClassificationTestBase extends ClassificationTe protected Analyzer analyzer; protected Map field2analyzer; - protected LeafReader leafReader; + protected IndexReader indexReader; @Before public void init() throws IOException { @@ -56,7 +56,7 @@ public abstract class DocumentClassificationTestBase extends ClassificationTe field2analyzer.put(textFieldName, analyzer); field2analyzer.put(titleFieldName, analyzer); field2analyzer.put(authorFieldName, analyzer); - leafReader = populateDocumentClassificationIndex(analyzer); + indexReader = populateDocumentClassificationIndex(analyzer); } protected double checkCorrectDocumentClassification(DocumentClassifier classifier, Document inputDoc, T expectedResult) throws Exception { @@ -68,7 +68,7 @@ public abstract class DocumentClassificationTestBase extends ClassificationTe return score; } - protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException { + protected IndexReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException { indexWriter.close(); indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE)); indexWriter.commit(); @@ -201,8 +201,7 @@ public abstract class DocumentClassificationTestBase extends ClassificationTe indexWriter.addDocument(doc); indexWriter.commit(); - indexWriter.forceMerge(1); - return getOnlyLeafReader(indexWriter.getReader()); + return indexWriter.getReader(); } protected Document getVideoGameDocument() { diff --git a/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java index 74152f6a45c..8c885fb9c72 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java @@ -33,15 +33,15 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati try { Document videoGameDocument = getVideoGameDocument(); Document batmanDocument = getBatmanDocument(); - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); // considering only the text we have wrong classification because the text was ambiguos on purpose - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); } finally { - if (leafReader != null) { - leafReader.close(); + if (indexReader != null) { + indexReader.close(); } } } @@ -51,18 +51,18 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati try { Document videoGameDocument = getVideoGameDocument(); Document batmanDocument = getBatmanDocument(); - double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); + double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); assertEquals(1.0,score1,0); - double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); + double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); assertEquals(1.0,score2,0); // considering only the text we have wrong classification because the text was ambiguos on purpose - double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); + double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); assertEquals(1.0,score3,0); - double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); + double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); assertEquals(1.0,score4,0); } finally { - if (leafReader != null) { - leafReader.close(); + if (indexReader != null) { + indexReader.close(); } } } @@ -70,12 +70,12 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati @Test public void testBoostedDocumentClassification() throws Exception { try { - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT); // considering without boost wrong classification will appear - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT); } finally { - if (leafReader != null) { - leafReader.close(); + if (indexReader != null) { + indexReader.close(); } } } @@ -84,11 +84,11 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati public void testBasicDocumentClassificationWithQuery() throws Exception { try { TermQuery query = new TermQuery(new Term(authorFieldName, "ign")); - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT); - checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT); + checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT); } finally { - if (leafReader != null) { - leafReader.close(); + if (indexReader != null) { + indexReader.close(); } } } diff --git a/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java index 7f6630c212c..b42bcc631d1 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java @@ -28,14 +28,14 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati @Test public void testBasicDocumentClassification() throws Exception { try { - checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT); - checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT); + checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT); + checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT); - checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT); - checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT); + checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT); + checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT); } finally { - if (leafReader != null) { - leafReader.close(); + if (indexReader != null) { + indexReader.close(); } } } @@ -43,18 +43,18 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati @Test public void testBasicDocumentClassificationScore() throws Exception { try { - double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT); + double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT); assertEquals(0.88,score1,0.01); - double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT); + double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT); assertEquals(0.89,score2,0.01); //taking in consideration only the text - double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT); + double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT); assertEquals(0.55,score3,0.01); - double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT); + double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT); assertEquals(0.52,score4,0.01); } finally { - if (leafReader != null) { - leafReader.close(); + if (indexReader != null) { + indexReader.close(); } } } @@ -62,12 +62,12 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati @Test public void testBoostedDocumentClassification() throws Exception { try { - checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT); + checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT); // considering without boost wrong classification will appear - checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT); + checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT); } finally { - if (leafReader != null) { - leafReader.close(); + if (indexReader != null) { + indexReader.close(); } } } diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java index 9344fb9a5fd..125f6d0e301 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java +++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java @@ -10,6 +10,7 @@ import org.apache.lucene.classification.document.DocumentClassifier; import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier; import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.util.BytesRef; import org.apache.solr.common.SolrInputDocument; @@ -60,7 +61,7 @@ class ClassificationUpdateProcessor * @param schema schema */ public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm, - UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) { + UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) { super(next); this.classFieldName = classFieldName; Map field2analyzer = new HashMap(); diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java index a70f21d40be..47a32c0136f 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java @@ -1,5 +1,6 @@ package org.apache.solr.update.processor; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.solr.common.SolrException; import org.apache.solr.common.params.SolrParams; @@ -109,8 +110,8 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor @Override public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { IndexSchema schema = req.getSchema(); - LeafReader leafReader = req.getSearcher().getLeafReader(); - return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema); + IndexReader indexReader = req.getSearcher().getIndexReader(); + return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema); } /**