From b4a343d6ba9efbc87e73d0e7689afd4fc6ccc488 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Sat, 2 Nov 2013 15:41:49 +0000 Subject: [PATCH] LUCENE-5311 - added support for training using multiple content fields for knn and naive bayes git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1538205 13f79535-47bb-0310-9956-ffa450edef68 --- .../BooleanPerceptronClassifier.java | 5 ++ .../classification/ClassificationResult.java | 4 +- .../lucene/classification/Classifier.java | 12 +++++ .../KNearestNeighborClassifier.java | 27 ++++++++-- .../SimpleNaiveBayesClassifier.java | 51 ++++++++++++++----- .../classification/utils/DatasetSplitter.java | 6 +-- 6 files changed, 82 insertions(+), 23 deletions(-) diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java index d20f722b8f5..814923b499a 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java @@ -191,6 +191,11 @@ public class BooleanPerceptronClassifier implements Classifier { weights.clear(); // free memory while waiting for GC } + @Override + public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException { + throw new IOException("training with multiple fields not supported by boolean perceptron classifier"); + } + private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse, int docId, Boolean assignedClass, SortedMap weights, double modifier, boolean updateFST) throws IOException { diff --git a/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java b/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java index 9d60e4c3f87..1327308dc60 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java @@ -22,8 +22,8 @@ package org.apache.lucene.classification; */ public class ClassificationResult { - private T assignedClass; - private double score; + private final T assignedClass; + private final double score; /** * Constructor diff --git a/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java b/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java index e5d10973559..b5ae33097d3 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java @@ -60,4 +60,16 @@ public interface Classifier { public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException; + /** + * Train the classifier using the underlying Lucene index + * @param atomicReader the reader to use to access the Lucene index + * @param textFieldNames the names of the fields to be used to compare documents + * @param classFieldName the name of the field containing the class assigned to documents + * @param analyzer the analyzer used to tokenize / filter the unseen text + * @param query the query to filter which documents use for training + * @throws IOException If there is a low-level I/O error. + */ + public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) + throws IOException; + } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index 4084c611f72..2ba9887be10 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -41,10 +41,10 @@ import java.util.Map; public class KNearestNeighborClassifier implements Classifier { private MoreLikeThis mlt; - private String textFieldName; + private String[] textFieldNames; private String classFieldName; private IndexSearcher indexSearcher; - private int k; + private final int k; private Query query; /** @@ -65,14 +65,17 @@ public class KNearestNeighborClassifier implements Classifier { throw new IOException("You must first call Classifier#train"); } Query q; + BooleanQuery mltQuery = new BooleanQuery(); + for (String textFieldName : textFieldNames) { + mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD)); + } if (query != null) { - Query mltQuery = mlt.like(new StringReader(text), textFieldName); BooleanQuery bq = new BooleanQuery(); bq.add(query, BooleanClause.Occur.MUST); bq.add(mltQuery, BooleanClause.Occur.MUST); q = bq; } else { - q = mlt.like(new StringReader(text), textFieldName); + q = mltQuery; } TopDocs topDocs = indexSearcher.search(q, k); return selectClassFromNeighbors(topDocs); @@ -116,7 +119,7 @@ public class KNearestNeighborClassifier implements Classifier { */ @Override public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException { - this.textFieldName = textFieldName; + this.textFieldNames = new String[]{textFieldName}; this.classFieldName = classFieldName; mlt = new MoreLikeThis(atomicReader); mlt.setAnalyzer(analyzer); @@ -124,4 +127,18 @@ public class KNearestNeighborClassifier implements Classifier { indexSearcher = new IndexSearcher(atomicReader); this.query = query; } + + /** + * {@inheritDoc} + */ + @Override + public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException { + this.textFieldNames = textFieldNames; + this.classFieldName = classFieldName; + mlt = new MoreLikeThis(atomicReader); + mlt.setAnalyzer(analyzer); + mlt.setFieldNames(textFieldNames); + indexSearcher = new IndexSearcher(atomicReader); + this.query = query; + } } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index fa6c637a7c9..db644f2ec1e 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -45,7 +45,7 @@ import java.util.LinkedList; public class SimpleNaiveBayesClassifier implements Classifier { private AtomicReader atomicReader; - private String textFieldName; + private String[] textFieldNames; private String classFieldName; private int docsWithClassSize; private Analyzer analyzer; @@ -68,18 +68,36 @@ public class SimpleNaiveBayesClassifier implements Classifier { throws IOException { this.atomicReader = atomicReader; this.indexSearcher = new IndexSearcher(this.atomicReader); - this.textFieldName = textFieldName; + this.textFieldNames = new String[]{textFieldName}; this.classFieldName = classFieldName; this.analyzer = analyzer; this.docsWithClassSize = countDocsWithClass(); this.query = query; } + /** + * {@inheritDoc} + */ @Override public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { train(atomicReader, textFieldName, classFieldName, analyzer, null); } + /** + * {@inheritDoc} + */ + @Override + public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) + throws IOException { + this.atomicReader = atomicReader; + this.indexSearcher = new IndexSearcher(this.atomicReader); + this.textFieldNames = textFieldNames; + this.classFieldName = classFieldName; + this.analyzer = analyzer; + this.docsWithClassSize = countDocsWithClass(); + this.query = query; + } + private int countDocsWithClass() throws IOException { int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount(); if (docCount == -1) { // in case codec doesn't support getDocCount @@ -103,13 +121,15 @@ public class SimpleNaiveBayesClassifier implements Classifier { private String[] tokenizeDoc(String doc) throws IOException { Collection result = new LinkedList(); - try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) { - CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); - tokenStream.reset(); - while (tokenStream.incrementToken()) { - result.add(charTermAttribute.toString()); + for (String textFieldName : textFieldNames) { + try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) { + CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); + tokenStream.reset(); + while (tokenStream.incrementToken()) { + result.add(charTermAttribute.toString()); + } + tokenStream.end(); } - tokenStream.end(); } return result.toArray(new String[result.size()]); } @@ -164,16 +184,23 @@ public class SimpleNaiveBayesClassifier implements Classifier { } private double getTextTermFreqForClass(BytesRef c) throws IOException { - Terms terms = MultiFields.getTerms(atomicReader, textFieldName); - long numPostings = terms.getSumDocFreq(); // number of term/doc pairs - double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc + double avgNumberOfUniqueTerms = 0; + for (String textFieldName : textFieldNames) { + Terms terms = MultiFields.getTerms(atomicReader, textFieldName); + long numPostings = terms.getSumDocFreq(); // number of term/doc pairs + avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc + } int docsWithC = atomicReader.docFreq(new Term(classFieldName, c)); return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c } private int getWordFreqForClass(String word, BytesRef c) throws IOException { BooleanQuery booleanQuery = new BooleanQuery(); - booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST)); + BooleanQuery subQuery = new BooleanQuery(); + for (String textFieldName : textFieldNames) { + subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD)); + } + booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST)); if (query != null) { booleanQuery.add(query, BooleanClause.Occur.MUST); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java index faea4dcc03d..401da814b8d 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java @@ -40,8 +40,8 @@ import java.io.IOException; */ public class DatasetSplitter { - private double crossValidationRatio; - private double testRatio; + private final double crossValidationRatio; + private final double testRatio; /** * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes @@ -68,8 +68,6 @@ public class DatasetSplitter { public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, String... fieldNames) throws IOException { - // TODO : check that the passed fields are stored in the original index - // create IWs for train / test / cv IDXs IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer)); IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));