diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
new file mode 100644
index 00000000000..267ac99949c
--- /dev/null
+++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import java.io.IOException;
+import java.io.StringReader;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.index.AtomicReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.StorableField;
+import org.apache.lucene.index.StoredDocument;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.IntsRef;
+import org.apache.lucene.util.fst.Builder;
+import org.apache.lucene.util.fst.FST;
+import org.apache.lucene.util.fst.PositiveIntOutputs;
+import org.apache.lucene.util.fst.Util;
+
+/**
+ * A perceptron (see http://en.wikipedia.org/wiki/Perceptron
) based
+ * Boolean
{@link org.apache.lucene.classification.Classifier}. The
+ * weights are calculated using
+ * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
+ * and a per document basis and then a corresponding
+ * {@link org.apache.lucene.util.fst.FST} is used for class assignment.
+ *
+ * @lucene.experimental
+ */
+public class BooleanPerceptronClassifier implements Classifier {
+
+ private Double threshold;
+ private final Integer batchSize;
+ private Terms textTerms;
+ private Analyzer analyzer;
+ private String textFieldName;
+ private FST fst;
+
+ /**
+ * Create a {@link BooleanPerceptronClassifier}
+ *
+ * @param threshold
+ * the binary threshold for perceptron output evaluation
+ */
+ public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
+ this.threshold = threshold;
+ this.batchSize = batchSize;
+ }
+
+ /**
+ * Default constructor, no batch updates of FST, perceptron threshold is
+ * calculated via underlying index metrics during
+ * {@link #train(org.apache.lucene.index.AtomicReader, String, String, org.apache.lucene.analysis.Analyzer)
+ * training}
+ */
+ public BooleanPerceptronClassifier() {
+ batchSize = 1;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public ClassificationResult assignClass(String text)
+ throws IOException {
+ if (textTerms == null) {
+ throw new IOException("You must first call Classifier#train");
+ }
+ Long output = 0l;
+ TokenStream tokenStream = analyzer.tokenStream(textFieldName,
+ new StringReader(text));
+ CharTermAttribute charTermAttribute = tokenStream
+ .addAttribute(CharTermAttribute.class);
+ tokenStream.reset();
+ while (tokenStream.incrementToken()) {
+ String s = charTermAttribute.toString();
+ Long d = Util.get(fst, new BytesRef(s));
+ if (d != null) {
+ output += d;
+ }
+ }
+ tokenStream.end();
+ tokenStream.close();
+
+ return new ClassificationResult<>(output >= threshold, output.doubleValue());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public void train(AtomicReader atomicReader, String textFieldName,
+ String classFieldName, Analyzer analyzer) throws IOException {
+ this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
+
+ if (textTerms == null) {
+ throw new IOException(new StringBuilder(
+ "term vectors need to be available for field ").append(textFieldName)
+ .toString());
+ }
+
+ this.analyzer = analyzer;
+ this.textFieldName = textFieldName;
+
+ if (threshold == null || threshold == 0d) {
+ // automatic assign a threshold
+ long sumDocFreq = atomicReader.getSumDocFreq(textFieldName);
+ if (sumDocFreq != -1) {
+ this.threshold = (double) sumDocFreq / 2d;
+ } else {
+ throw new IOException(
+ "threshold cannot be assigned since term vectors for field "
+ + textFieldName + " do not exist");
+ }
+ }
+
+ // TODO : remove this map as soon as we have a writable FST
+ SortedMap weights = new TreeMap<>();
+
+ TermsEnum reuse = textTerms.iterator(null);
+ BytesRef textTerm;
+ while ((textTerm = reuse.next()) != null) {
+ weights.put(textTerm.utf8ToString(), (double) reuse.totalTermFreq());
+ }
+ updateFST(weights);
+
+ IndexSearcher indexSearcher = new IndexSearcher(atomicReader);
+
+ int batchCount = 0;
+
+ // do a *:* search and use stored field values
+ for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(),
+ Integer.MAX_VALUE).scoreDocs) {
+ StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
+
+ // assign class to the doc
+ ClassificationResult classificationResult = assignClass(doc
+ .getField(textFieldName).stringValue());
+ Boolean assignedClass = classificationResult.getAssignedClass();
+
+ // get the expected result
+ StorableField field = doc.getField(classFieldName);
+
+ Boolean correctClass = Boolean.valueOf(field.stringValue());
+ long modifier = correctClass.compareTo(assignedClass);
+ if (modifier != 0) {
+ reuse = updateWeights(atomicReader, reuse, scoreDoc.doc, assignedClass,
+ weights, modifier, batchCount % batchSize == 0);
+ }
+ batchCount++;
+ }
+ weights.clear(); // free memory while waiting for GC
+ }
+
+ private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
+ int docId, Boolean assignedClass, SortedMap weights,
+ double modifier, boolean updateFST) throws IOException {
+ TermsEnum cte = textTerms.iterator(reuse);
+
+ // get the doc term vectors
+ Terms terms = atomicReader.getTermVector(docId, textFieldName);
+
+ if (terms == null) {
+ throw new IOException("term vectors must be stored for field "
+ + textFieldName);
+ }
+
+ TermsEnum termsEnum = terms.iterator(null);
+
+ BytesRef term;
+
+ while ((term = termsEnum.next()) != null) {
+ cte.seekExact(term);
+ if (assignedClass != null) {
+ long termFreqLocal = termsEnum.totalTermFreq();
+ // update weights
+ Long previousValue = Util.get(fst, term);
+ String termString = term.utf8ToString();
+ weights.put(termString, previousValue + modifier * termFreqLocal);
+ }
+ }
+ if (updateFST) {
+ updateFST(weights);
+ }
+ reuse = cte;
+ return reuse;
+ }
+
+ private void updateFST(SortedMap weights) throws IOException {
+ PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
+ Builder fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
+ BytesRef scratchBytes = new BytesRef();
+ IntsRef scratchInts = new IntsRef();
+ for (Map.Entry entry : weights.entrySet()) {
+ scratchBytes.copyChars(entry.getKey());
+ fstBuilder.add(Util.toIntsRef(scratchBytes, scratchInts), entry
+ .getValue().longValue());
+ }
+ fst = fstBuilder.finish();
+ }
+
+}
\ No newline at end of file
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
index e8069eea47a..bbaa0566d5b 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
@@ -59,7 +59,7 @@ public class KNearestNeighborClassifier implements Classifier {
@Override
public ClassificationResult assignClass(String text) throws IOException {
if (mlt == null) {
- throw new IOException("You must first call Classifier#train first");
+ throw new IOException("You must first call Classifier#train");
}
Query q = mlt.like(new StringReader(text), textFieldName);
TopDocs topDocs = indexSearcher.search(q, k);
@@ -71,13 +71,11 @@ public class KNearestNeighborClassifier implements Classifier {
Map classCounts = new HashMap();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
- if (cl != null) {
- Integer count = classCounts.get(cl);
- if (count != null) {
- classCounts.put(cl, count + 1);
- } else {
- classCounts.put(cl, 1);
- }
+ Integer count = classCounts.get(cl);
+ if (count != null) {
+ classCounts.put(cl, count + 1);
+ } else {
+ classCounts.put(cl, 1);
}
}
double max = 0;
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
index 2fe5c832b18..74fc631ef28 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
@@ -102,7 +102,7 @@ public class SimpleNaiveBayesClassifier implements Classifier {
@Override
public ClassificationResult assignClass(String inputDocument) throws IOException {
if (atomicReader == null) {
- throw new IOException("You must first call Classifier#train first");
+ throw new IOException("You must first call Classifier#train");
}
double max = 0d;
BytesRef foundClass = new BytesRef();
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/package.html b/lucene/classification/src/java/org/apache/lucene/classification/package.html
index 94b0ddc6b79..b68c1988a90 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/package.html
+++ b/lucene/classification/src/java/org/apache/lucene/classification/package.html
@@ -17,7 +17,7 @@
Uses already seen data (the indexed documents) to classify new documents.
-Currently only contains a (simplistic) Lucene based Naive Bayes classifier
-and a k-Nearest Neighbor classifier
+Currently only contains a (simplistic) Lucene based Naive Bayes classifier,
+a k-Nearest Neighbor classifier and a Perceptron based classifier
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
new file mode 100644
index 00000000000..c6b7b10b543
--- /dev/null
+++ b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.junit.Test;
+
+/**
+ * Testcase for {@link org.apache.lucene.classification.BooleanPerceptronClassifier}
+ */
+public class BooleanPerceptronClassifierTest extends ClassificationTestBase {
+
+ @Test
+ public void testBasicUsage() throws Exception {
+ checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
+ }
+
+ @Test
+ public void testExplicitThreshold() throws Exception {
+ checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
+ }
+
+ @Test
+ public void testPerformance() throws Exception {
+ checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
+ }
+
+}
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
index 2c5b604fcef..fe31b412159 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
@@ -27,9 +27,13 @@ import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util._TestUtil;
import org.junit.After;
import org.junit.Before;
+import java.io.IOException;
+import java.util.Random;
+
/**
* Base class for testing {@link Classifier}s
*/
@@ -41,8 +45,9 @@ public abstract class ClassificationTestBase extends LuceneTestCase {
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
private RandomIndexWriter indexWriter;
- private String textFieldName;
private Directory dir;
+
+ String textFieldName;
String categoryFieldName;
String booleanFieldName;
@@ -66,82 +71,141 @@ public abstract class ClassificationTestBase extends LuceneTestCase {
}
- protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
- AtomicReader compositeReaderWrapper = null;
+ protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
+ AtomicReader atomicReader = null;
try {
- populateIndex(analyzer);
- compositeReaderWrapper = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
- classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
+ populateSampleIndex(analyzer);
+ atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+ classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
ClassificationResult classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
} finally {
- if (compositeReaderWrapper != null)
- compositeReaderWrapper.close();
+ if (atomicReader != null)
+ atomicReader.close();
}
}
- private void populateIndex(Analyzer analyzer) throws Exception {
+ protected void checkPerformance(Classifier classifier, Analyzer analyzer, String classFieldName) throws Exception {
+ AtomicReader atomicReader = null;
+ long trainStart = System.currentTimeMillis();
+ long trainEnd = 0l;
+ try {
+ populatePerformanceIndex(analyzer);
+ atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+ classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
+ trainEnd = System.currentTimeMillis();
+ long trainTime = trainEnd - trainStart;
+ assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
+ } finally {
+ if (atomicReader != null)
+ atomicReader.close();
+ }
+ }
+
+ private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+ indexWriter.deleteAll();
+ indexWriter.commit();
+
+ FieldType ft = new FieldType(TextField.TYPE_STORED);
+ ft.setStoreTermVectors(true);
+ ft.setStoreTermVectorOffsets(true);
+ ft.setStoreTermVectorPositions(true);
+ int docs = 1000;
+ Random random = random();
+ for (int i = 0; i < docs; i++) {
+ boolean b = random.nextBoolean();
+ Document doc = new Document();
+ doc.add(new Field(textFieldName, createRandomString(random), ft));
+ doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
+ doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
+ indexWriter.addDocument(doc, analyzer);
+ }
+ indexWriter.commit();
+ }
+
+ private String createRandomString(Random random) {
+ StringBuilder builder = new StringBuilder();
+ for (int i = 0; i < 20; i++) {
+ builder.append(_TestUtil.randomSimpleString(random, 5));
+ builder.append(" ");
+ }
+ return builder.toString();
+ }
+
+ private void populateSampleIndex(Analyzer analyzer) throws Exception {
+
+ indexWriter.deleteAll();
+ indexWriter.commit();
FieldType ft = new FieldType(TextField.TYPE_STORED);
ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true);
+ String text;
+
Document doc = new Document();
- doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
+ text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
- "the Unknown Soldier in Warsaw Tuesday.", ft));
+ "the Unknown Soldier in Warsaw Tuesday.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
- doc.add(new Field(booleanFieldName, "false", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
- " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft));
+ text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
+ " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
- doc.add(new Field(booleanFieldName, "false", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
+ text = "And there's a threshold question that he has to answer for the American people and " +
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
- "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft));
+ "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
- doc.add(new Field(booleanFieldName, "false", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
+ text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
- "Albany's School of Criminal Justice.", ft));
+ "Albany's School of Criminal Justice.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
+ indexWriter.addDocument(doc, analyzer);
+
+ doc = new Document();
+ text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
+ "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
+ "world through the Internet.";
+ doc.add(new Field(textFieldName, text, ft));
+ doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
- "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
- "world through the Internet.", ft));
+ text = "So, about all those experts and analysts who've spent the past year or so saying " +
+ "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
- doc.add(new Field(booleanFieldName, "true", ft));
+ doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
- "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft));
- doc.add(new Field(categoryFieldName, "technology", ft));
- doc.add(new Field(booleanFieldName, "true", ft));
- indexWriter.addDocument(doc, analyzer);
-
- doc = new Document();
- doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
+ text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
- "generally transfer or store huge volumes of personal data online.", ft));
+ "generally transfer or store huge volumes of personal data online.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
- doc.add(new Field(booleanFieldName, "true", ft));
+ doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
indexWriter.commit();
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
index 2e2b066576c..664750a0f9b 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
@@ -27,7 +27,12 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase