diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java new file mode 100644 index 00000000000..267ac99949c --- /dev/null +++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.classification; + +import java.io.IOException; +import java.io.StringReader; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.index.AtomicReader; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.StorableField; +import org.apache.lucene.index.StoredDocument; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.IntsRef; +import org.apache.lucene.util.fst.Builder; +import org.apache.lucene.util.fst.FST; +import org.apache.lucene.util.fst.PositiveIntOutputs; +import org.apache.lucene.util.fst.Util; + +/** + * A perceptron (see http://en.wikipedia.org/wiki/Perceptron) based + * Boolean {@link org.apache.lucene.classification.Classifier}. The + * weights are calculated using + * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field + * and a per document basis and then a corresponding + * {@link org.apache.lucene.util.fst.FST} is used for class assignment. + * + * @lucene.experimental + */ +public class BooleanPerceptronClassifier implements Classifier { + + private Double threshold; + private final Integer batchSize; + private Terms textTerms; + private Analyzer analyzer; + private String textFieldName; + private FST fst; + + /** + * Create a {@link BooleanPerceptronClassifier} + * + * @param threshold + * the binary threshold for perceptron output evaluation + */ + public BooleanPerceptronClassifier(Double threshold, Integer batchSize) { + this.threshold = threshold; + this.batchSize = batchSize; + } + + /** + * Default constructor, no batch updates of FST, perceptron threshold is + * calculated via underlying index metrics during + * {@link #train(org.apache.lucene.index.AtomicReader, String, String, org.apache.lucene.analysis.Analyzer) + * training} + */ + public BooleanPerceptronClassifier() { + batchSize = 1; + } + + /** + * {@inheritDoc} + */ + @Override + public ClassificationResult assignClass(String text) + throws IOException { + if (textTerms == null) { + throw new IOException("You must first call Classifier#train"); + } + Long output = 0l; + TokenStream tokenStream = analyzer.tokenStream(textFieldName, + new StringReader(text)); + CharTermAttribute charTermAttribute = tokenStream + .addAttribute(CharTermAttribute.class); + tokenStream.reset(); + while (tokenStream.incrementToken()) { + String s = charTermAttribute.toString(); + Long d = Util.get(fst, new BytesRef(s)); + if (d != null) { + output += d; + } + } + tokenStream.end(); + tokenStream.close(); + + return new ClassificationResult<>(output >= threshold, output.doubleValue()); + } + + /** + * {@inheritDoc} + */ + @Override + public void train(AtomicReader atomicReader, String textFieldName, + String classFieldName, Analyzer analyzer) throws IOException { + this.textTerms = MultiFields.getTerms(atomicReader, textFieldName); + + if (textTerms == null) { + throw new IOException(new StringBuilder( + "term vectors need to be available for field ").append(textFieldName) + .toString()); + } + + this.analyzer = analyzer; + this.textFieldName = textFieldName; + + if (threshold == null || threshold == 0d) { + // automatic assign a threshold + long sumDocFreq = atomicReader.getSumDocFreq(textFieldName); + if (sumDocFreq != -1) { + this.threshold = (double) sumDocFreq / 2d; + } else { + throw new IOException( + "threshold cannot be assigned since term vectors for field " + + textFieldName + " do not exist"); + } + } + + // TODO : remove this map as soon as we have a writable FST + SortedMap weights = new TreeMap<>(); + + TermsEnum reuse = textTerms.iterator(null); + BytesRef textTerm; + while ((textTerm = reuse.next()) != null) { + weights.put(textTerm.utf8ToString(), (double) reuse.totalTermFreq()); + } + updateFST(weights); + + IndexSearcher indexSearcher = new IndexSearcher(atomicReader); + + int batchCount = 0; + + // do a *:* search and use stored field values + for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(), + Integer.MAX_VALUE).scoreDocs) { + StoredDocument doc = indexSearcher.doc(scoreDoc.doc); + + // assign class to the doc + ClassificationResult classificationResult = assignClass(doc + .getField(textFieldName).stringValue()); + Boolean assignedClass = classificationResult.getAssignedClass(); + + // get the expected result + StorableField field = doc.getField(classFieldName); + + Boolean correctClass = Boolean.valueOf(field.stringValue()); + long modifier = correctClass.compareTo(assignedClass); + if (modifier != 0) { + reuse = updateWeights(atomicReader, reuse, scoreDoc.doc, assignedClass, + weights, modifier, batchCount % batchSize == 0); + } + batchCount++; + } + weights.clear(); // free memory while waiting for GC + } + + private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse, + int docId, Boolean assignedClass, SortedMap weights, + double modifier, boolean updateFST) throws IOException { + TermsEnum cte = textTerms.iterator(reuse); + + // get the doc term vectors + Terms terms = atomicReader.getTermVector(docId, textFieldName); + + if (terms == null) { + throw new IOException("term vectors must be stored for field " + + textFieldName); + } + + TermsEnum termsEnum = terms.iterator(null); + + BytesRef term; + + while ((term = termsEnum.next()) != null) { + cte.seekExact(term); + if (assignedClass != null) { + long termFreqLocal = termsEnum.totalTermFreq(); + // update weights + Long previousValue = Util.get(fst, term); + String termString = term.utf8ToString(); + weights.put(termString, previousValue + modifier * termFreqLocal); + } + } + if (updateFST) { + updateFST(weights); + } + reuse = cte; + return reuse; + } + + private void updateFST(SortedMap weights) throws IOException { + PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(); + Builder fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs); + BytesRef scratchBytes = new BytesRef(); + IntsRef scratchInts = new IntsRef(); + for (Map.Entry entry : weights.entrySet()) { + scratchBytes.copyChars(entry.getKey()); + fstBuilder.add(Util.toIntsRef(scratchBytes, scratchInts), entry + .getValue().longValue()); + } + fst = fstBuilder.finish(); + } + +} \ No newline at end of file diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index e8069eea47a..bbaa0566d5b 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -59,7 +59,7 @@ public class KNearestNeighborClassifier implements Classifier { @Override public ClassificationResult assignClass(String text) throws IOException { if (mlt == null) { - throw new IOException("You must first call Classifier#train first"); + throw new IOException("You must first call Classifier#train"); } Query q = mlt.like(new StringReader(text), textFieldName); TopDocs topDocs = indexSearcher.search(q, k); @@ -71,13 +71,11 @@ public class KNearestNeighborClassifier implements Classifier { Map classCounts = new HashMap(); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue()); - if (cl != null) { - Integer count = classCounts.get(cl); - if (count != null) { - classCounts.put(cl, count + 1); - } else { - classCounts.put(cl, 1); - } + Integer count = classCounts.get(cl); + if (count != null) { + classCounts.put(cl, count + 1); + } else { + classCounts.put(cl, 1); } } double max = 0; diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index 2fe5c832b18..74fc631ef28 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -102,7 +102,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { @Override public ClassificationResult assignClass(String inputDocument) throws IOException { if (atomicReader == null) { - throw new IOException("You must first call Classifier#train first"); + throw new IOException("You must first call Classifier#train"); } double max = 0d; BytesRef foundClass = new BytesRef(); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/package.html b/lucene/classification/src/java/org/apache/lucene/classification/package.html index 94b0ddc6b79..b68c1988a90 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/package.html +++ b/lucene/classification/src/java/org/apache/lucene/classification/package.html @@ -17,7 +17,7 @@ Uses already seen data (the indexed documents) to classify new documents. -Currently only contains a (simplistic) Lucene based Naive Bayes classifier -and a k-Nearest Neighbor classifier +Currently only contains a (simplistic) Lucene based Naive Bayes classifier, +a k-Nearest Neighbor classifier and a Perceptron based classifier diff --git a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java new file mode 100644 index 00000000000..c6b7b10b543 --- /dev/null +++ b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.classification; + +import org.apache.lucene.analysis.MockAnalyzer; +import org.junit.Test; + +/** + * Testcase for {@link org.apache.lucene.classification.BooleanPerceptronClassifier} + */ +public class BooleanPerceptronClassifierTest extends ClassificationTestBase { + + @Test + public void testBasicUsage() throws Exception { + checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName); + } + + @Test + public void testExplicitThreshold() throws Exception { + checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName); + } + + @Test + public void testPerformance() throws Exception { + checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName); + } + +} diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java index 2c5b604fcef..fe31b412159 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java @@ -27,9 +27,13 @@ import org.apache.lucene.index.SlowCompositeReaderWrapper; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util._TestUtil; import org.junit.After; import org.junit.Before; +import java.io.IOException; +import java.util.Random; + /** * Base class for testing {@link Classifier}s */ @@ -41,8 +45,9 @@ public abstract class ClassificationTestBase extends LuceneTestCase { public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology"); private RandomIndexWriter indexWriter; - private String textFieldName; private Directory dir; + + String textFieldName; String categoryFieldName; String booleanFieldName; @@ -66,82 +71,141 @@ public abstract class ClassificationTestBase extends LuceneTestCase { } - protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception { - AtomicReader compositeReaderWrapper = null; + protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception { + AtomicReader atomicReader = null; try { - populateIndex(analyzer); - compositeReaderWrapper = SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); - classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer); + populateSampleIndex(analyzer); + atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); + classifier.train(atomicReader, textFieldName, classFieldName, analyzer); ClassificationResult classificationResult = classifier.assignClass(inputDoc); assertNotNull(classificationResult.getAssignedClass()); assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0); } finally { - if (compositeReaderWrapper != null) - compositeReaderWrapper.close(); + if (atomicReader != null) + atomicReader.close(); } } - private void populateIndex(Analyzer analyzer) throws Exception { + protected void checkPerformance(Classifier classifier, Analyzer analyzer, String classFieldName) throws Exception { + AtomicReader atomicReader = null; + long trainStart = System.currentTimeMillis(); + long trainEnd = 0l; + try { + populatePerformanceIndex(analyzer); + atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); + classifier.train(atomicReader, textFieldName, classFieldName, analyzer); + trainEnd = System.currentTimeMillis(); + long trainTime = trainEnd - trainStart; + assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000); + } finally { + if (atomicReader != null) + atomicReader.close(); + } + } + + private void populatePerformanceIndex(Analyzer analyzer) throws IOException { + indexWriter.deleteAll(); + indexWriter.commit(); + + FieldType ft = new FieldType(TextField.TYPE_STORED); + ft.setStoreTermVectors(true); + ft.setStoreTermVectorOffsets(true); + ft.setStoreTermVectorPositions(true); + int docs = 1000; + Random random = random(); + for (int i = 0; i < docs; i++) { + boolean b = random.nextBoolean(); + Document doc = new Document(); + doc.add(new Field(textFieldName, createRandomString(random), ft)); + doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft)); + doc.add(new Field(booleanFieldName, String.valueOf(b), ft)); + indexWriter.addDocument(doc, analyzer); + } + indexWriter.commit(); + } + + private String createRandomString(Random random) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < 20; i++) { + builder.append(_TestUtil.randomSimpleString(random, 5)); + builder.append(" "); + } + return builder.toString(); + } + + private void populateSampleIndex(Analyzer analyzer) throws Exception { + + indexWriter.deleteAll(); + indexWriter.commit(); FieldType ft = new FieldType(TextField.TYPE_STORED); ft.setStoreTermVectors(true); ft.setStoreTermVectorOffsets(true); ft.setStoreTermVectorPositions(true); + String text; + Document doc = new Document(); - doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " + + text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " + "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " + - "the Unknown Soldier in Warsaw Tuesday.", ft)); + "the Unknown Soldier in Warsaw Tuesday."; + doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(categoryFieldName, "politics", ft)); - doc.add(new Field(booleanFieldName, "false", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); - doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" + - " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft)); + text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" + + " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama."; + doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(categoryFieldName, "politics", ft)); - doc.add(new Field(booleanFieldName, "false", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); - doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " + + text = "And there's a threshold question that he has to answer for the American people and " + "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " + - "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft)); + "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\""; + doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(categoryFieldName, "politics", ft)); - doc.add(new Field(booleanFieldName, "false", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); - doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " + + text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " + "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " + - "Albany's School of Criminal Justice.", ft)); + "Albany's School of Criminal Justice."; + doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); + indexWriter.addDocument(doc, analyzer); + + doc = new Document(); + text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " + + "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " + + "world through the Internet."; + doc.add(new Field(textFieldName, text, ft)); + doc.add(new Field(categoryFieldName, "technology", ft)); doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); - doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " + - "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " + - "world through the Internet.", ft)); + text = "So, about all those experts and analysts who've spent the past year or so saying " + + "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen."; + doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(categoryFieldName, "technology", ft)); - doc.add(new Field(booleanFieldName, "true", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); - doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " + - "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft)); - doc.add(new Field(categoryFieldName, "technology", ft)); - doc.add(new Field(booleanFieldName, "true", ft)); - indexWriter.addDocument(doc, analyzer); - - doc = new Document(); - doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" + + text = "More than 400 million people trust Google with their e-mail, and 50 million store files" + " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " + - "generally transfer or store huge volumes of personal data online.", ft)); + "generally transfer or store huge volumes of personal data online."; + doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(categoryFieldName, "technology", ft)); - doc.add(new Field(booleanFieldName, "true", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); indexWriter.commit(); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java index 2e2b066576c..664750a0f9b 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java @@ -27,7 +27,12 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase