mirror of https://github.com/apache/lucene.git
LUCNE-4818 - added boolean perceptron classifier
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1519590 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
1b9c6a24d4
commit
05940e34ab
|
@ -0,0 +1,226 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.StringReader;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.SortedMap;
|
||||||
|
import java.util.TreeMap;
|
||||||
|
|
||||||
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
|
import org.apache.lucene.analysis.TokenStream;
|
||||||
|
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||||
|
import org.apache.lucene.index.AtomicReader;
|
||||||
|
import org.apache.lucene.index.MultiFields;
|
||||||
|
import org.apache.lucene.index.StorableField;
|
||||||
|
import org.apache.lucene.index.StoredDocument;
|
||||||
|
import org.apache.lucene.index.Terms;
|
||||||
|
import org.apache.lucene.index.TermsEnum;
|
||||||
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.apache.lucene.util.IntsRef;
|
||||||
|
import org.apache.lucene.util.fst.Builder;
|
||||||
|
import org.apache.lucene.util.fst.FST;
|
||||||
|
import org.apache.lucene.util.fst.PositiveIntOutputs;
|
||||||
|
import org.apache.lucene.util.fst.Util;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
|
||||||
|
* <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
|
||||||
|
* weights are calculated using
|
||||||
|
* {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
|
||||||
|
* and a per document basis and then a corresponding
|
||||||
|
* {@link org.apache.lucene.util.fst.FST} is used for class assignment.
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
|
|
||||||
|
private Double threshold;
|
||||||
|
private final Integer batchSize;
|
||||||
|
private Terms textTerms;
|
||||||
|
private Analyzer analyzer;
|
||||||
|
private String textFieldName;
|
||||||
|
private FST<Long> fst;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a {@link BooleanPerceptronClassifier}
|
||||||
|
*
|
||||||
|
* @param threshold
|
||||||
|
* the binary threshold for perceptron output evaluation
|
||||||
|
*/
|
||||||
|
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
|
||||||
|
this.threshold = threshold;
|
||||||
|
this.batchSize = batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default constructor, no batch updates of FST, perceptron threshold is
|
||||||
|
* calculated via underlying index metrics during
|
||||||
|
* {@link #train(org.apache.lucene.index.AtomicReader, String, String, org.apache.lucene.analysis.Analyzer)
|
||||||
|
* training}
|
||||||
|
*/
|
||||||
|
public BooleanPerceptronClassifier() {
|
||||||
|
batchSize = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public ClassificationResult<Boolean> assignClass(String text)
|
||||||
|
throws IOException {
|
||||||
|
if (textTerms == null) {
|
||||||
|
throw new IOException("You must first call Classifier#train");
|
||||||
|
}
|
||||||
|
Long output = 0l;
|
||||||
|
TokenStream tokenStream = analyzer.tokenStream(textFieldName,
|
||||||
|
new StringReader(text));
|
||||||
|
CharTermAttribute charTermAttribute = tokenStream
|
||||||
|
.addAttribute(CharTermAttribute.class);
|
||||||
|
tokenStream.reset();
|
||||||
|
while (tokenStream.incrementToken()) {
|
||||||
|
String s = charTermAttribute.toString();
|
||||||
|
Long d = Util.get(fst, new BytesRef(s));
|
||||||
|
if (d != null) {
|
||||||
|
output += d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokenStream.end();
|
||||||
|
tokenStream.close();
|
||||||
|
|
||||||
|
return new ClassificationResult<>(output >= threshold, output.doubleValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void train(AtomicReader atomicReader, String textFieldName,
|
||||||
|
String classFieldName, Analyzer analyzer) throws IOException {
|
||||||
|
this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
|
||||||
|
|
||||||
|
if (textTerms == null) {
|
||||||
|
throw new IOException(new StringBuilder(
|
||||||
|
"term vectors need to be available for field ").append(textFieldName)
|
||||||
|
.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
this.analyzer = analyzer;
|
||||||
|
this.textFieldName = textFieldName;
|
||||||
|
|
||||||
|
if (threshold == null || threshold == 0d) {
|
||||||
|
// automatic assign a threshold
|
||||||
|
long sumDocFreq = atomicReader.getSumDocFreq(textFieldName);
|
||||||
|
if (sumDocFreq != -1) {
|
||||||
|
this.threshold = (double) sumDocFreq / 2d;
|
||||||
|
} else {
|
||||||
|
throw new IOException(
|
||||||
|
"threshold cannot be assigned since term vectors for field "
|
||||||
|
+ textFieldName + " do not exist");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO : remove this map as soon as we have a writable FST
|
||||||
|
SortedMap<String,Double> weights = new TreeMap<>();
|
||||||
|
|
||||||
|
TermsEnum reuse = textTerms.iterator(null);
|
||||||
|
BytesRef textTerm;
|
||||||
|
while ((textTerm = reuse.next()) != null) {
|
||||||
|
weights.put(textTerm.utf8ToString(), (double) reuse.totalTermFreq());
|
||||||
|
}
|
||||||
|
updateFST(weights);
|
||||||
|
|
||||||
|
IndexSearcher indexSearcher = new IndexSearcher(atomicReader);
|
||||||
|
|
||||||
|
int batchCount = 0;
|
||||||
|
|
||||||
|
// do a *:* search and use stored field values
|
||||||
|
for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(),
|
||||||
|
Integer.MAX_VALUE).scoreDocs) {
|
||||||
|
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
||||||
|
|
||||||
|
// assign class to the doc
|
||||||
|
ClassificationResult<Boolean> classificationResult = assignClass(doc
|
||||||
|
.getField(textFieldName).stringValue());
|
||||||
|
Boolean assignedClass = classificationResult.getAssignedClass();
|
||||||
|
|
||||||
|
// get the expected result
|
||||||
|
StorableField field = doc.getField(classFieldName);
|
||||||
|
|
||||||
|
Boolean correctClass = Boolean.valueOf(field.stringValue());
|
||||||
|
long modifier = correctClass.compareTo(assignedClass);
|
||||||
|
if (modifier != 0) {
|
||||||
|
reuse = updateWeights(atomicReader, reuse, scoreDoc.doc, assignedClass,
|
||||||
|
weights, modifier, batchCount % batchSize == 0);
|
||||||
|
}
|
||||||
|
batchCount++;
|
||||||
|
}
|
||||||
|
weights.clear(); // free memory while waiting for GC
|
||||||
|
}
|
||||||
|
|
||||||
|
private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
|
||||||
|
int docId, Boolean assignedClass, SortedMap<String,Double> weights,
|
||||||
|
double modifier, boolean updateFST) throws IOException {
|
||||||
|
TermsEnum cte = textTerms.iterator(reuse);
|
||||||
|
|
||||||
|
// get the doc term vectors
|
||||||
|
Terms terms = atomicReader.getTermVector(docId, textFieldName);
|
||||||
|
|
||||||
|
if (terms == null) {
|
||||||
|
throw new IOException("term vectors must be stored for field "
|
||||||
|
+ textFieldName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TermsEnum termsEnum = terms.iterator(null);
|
||||||
|
|
||||||
|
BytesRef term;
|
||||||
|
|
||||||
|
while ((term = termsEnum.next()) != null) {
|
||||||
|
cte.seekExact(term);
|
||||||
|
if (assignedClass != null) {
|
||||||
|
long termFreqLocal = termsEnum.totalTermFreq();
|
||||||
|
// update weights
|
||||||
|
Long previousValue = Util.get(fst, term);
|
||||||
|
String termString = term.utf8ToString();
|
||||||
|
weights.put(termString, previousValue + modifier * termFreqLocal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (updateFST) {
|
||||||
|
updateFST(weights);
|
||||||
|
}
|
||||||
|
reuse = cte;
|
||||||
|
return reuse;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFST(SortedMap<String,Double> weights) throws IOException {
|
||||||
|
PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
|
||||||
|
Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
|
||||||
|
BytesRef scratchBytes = new BytesRef();
|
||||||
|
IntsRef scratchInts = new IntsRef();
|
||||||
|
for (Map.Entry<String,Double> entry : weights.entrySet()) {
|
||||||
|
scratchBytes.copyChars(entry.getKey());
|
||||||
|
fstBuilder.add(Util.toIntsRef(scratchBytes, scratchInts), entry
|
||||||
|
.getValue().longValue());
|
||||||
|
}
|
||||||
|
fst = fstBuilder.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -59,7 +59,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
@Override
|
@Override
|
||||||
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
|
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
|
||||||
if (mlt == null) {
|
if (mlt == null) {
|
||||||
throw new IOException("You must first call Classifier#train first");
|
throw new IOException("You must first call Classifier#train");
|
||||||
}
|
}
|
||||||
Query q = mlt.like(new StringReader(text), textFieldName);
|
Query q = mlt.like(new StringReader(text), textFieldName);
|
||||||
TopDocs topDocs = indexSearcher.search(q, k);
|
TopDocs topDocs = indexSearcher.search(q, k);
|
||||||
|
@ -71,13 +71,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
Map<BytesRef, Integer> classCounts = new HashMap<BytesRef, Integer>();
|
Map<BytesRef, Integer> classCounts = new HashMap<BytesRef, Integer>();
|
||||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||||
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
|
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
|
||||||
if (cl != null) {
|
Integer count = classCounts.get(cl);
|
||||||
Integer count = classCounts.get(cl);
|
if (count != null) {
|
||||||
if (count != null) {
|
classCounts.put(cl, count + 1);
|
||||||
classCounts.put(cl, count + 1);
|
} else {
|
||||||
} else {
|
classCounts.put(cl, 1);
|
||||||
classCounts.put(cl, 1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
double max = 0;
|
double max = 0;
|
||||||
|
|
|
@ -102,7 +102,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
@Override
|
@Override
|
||||||
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
||||||
if (atomicReader == null) {
|
if (atomicReader == null) {
|
||||||
throw new IOException("You must first call Classifier#train first");
|
throw new IOException("You must first call Classifier#train");
|
||||||
}
|
}
|
||||||
double max = 0d;
|
double max = 0d;
|
||||||
BytesRef foundClass = new BytesRef();
|
BytesRef foundClass = new BytesRef();
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
Uses already seen data (the indexed documents) to classify new documents.
|
Uses already seen data (the indexed documents) to classify new documents.
|
||||||
Currently only contains a (simplistic) Lucene based Naive Bayes classifier
|
Currently only contains a (simplistic) Lucene based Naive Bayes classifier,
|
||||||
and a k-Nearest Neighbor classifier
|
a k-Nearest Neighbor classifier and a Perceptron based classifier
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Testcase for {@link org.apache.lucene.classification.BooleanPerceptronClassifier}
|
||||||
|
*/
|
||||||
|
public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Boolean> {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicUsage() throws Exception {
|
||||||
|
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testExplicitThreshold() throws Exception {
|
||||||
|
checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPerformance() throws Exception {
|
||||||
|
checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -27,9 +27,13 @@ import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
|
import org.apache.lucene.util._TestUtil;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class for testing {@link Classifier}s
|
* Base class for testing {@link Classifier}s
|
||||||
*/
|
*/
|
||||||
|
@ -41,8 +45,9 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
|
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
|
||||||
|
|
||||||
private RandomIndexWriter indexWriter;
|
private RandomIndexWriter indexWriter;
|
||||||
private String textFieldName;
|
|
||||||
private Directory dir;
|
private Directory dir;
|
||||||
|
|
||||||
|
String textFieldName;
|
||||||
String categoryFieldName;
|
String categoryFieldName;
|
||||||
String booleanFieldName;
|
String booleanFieldName;
|
||||||
|
|
||||||
|
@ -66,82 +71,141 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
|
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
|
||||||
AtomicReader compositeReaderWrapper = null;
|
AtomicReader atomicReader = null;
|
||||||
try {
|
try {
|
||||||
populateIndex(analyzer);
|
populateSampleIndex(analyzer);
|
||||||
compositeReaderWrapper = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
||||||
classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
|
classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
|
||||||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||||
assertNotNull(classificationResult.getAssignedClass());
|
assertNotNull(classificationResult.getAssignedClass());
|
||||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||||
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
|
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
|
||||||
} finally {
|
} finally {
|
||||||
if (compositeReaderWrapper != null)
|
if (atomicReader != null)
|
||||||
compositeReaderWrapper.close();
|
atomicReader.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void populateIndex(Analyzer analyzer) throws Exception {
|
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
|
||||||
|
AtomicReader atomicReader = null;
|
||||||
|
long trainStart = System.currentTimeMillis();
|
||||||
|
long trainEnd = 0l;
|
||||||
|
try {
|
||||||
|
populatePerformanceIndex(analyzer);
|
||||||
|
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
||||||
|
classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
|
||||||
|
trainEnd = System.currentTimeMillis();
|
||||||
|
long trainTime = trainEnd - trainStart;
|
||||||
|
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
|
||||||
|
} finally {
|
||||||
|
if (atomicReader != null)
|
||||||
|
atomicReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
|
||||||
|
indexWriter.deleteAll();
|
||||||
|
indexWriter.commit();
|
||||||
|
|
||||||
|
FieldType ft = new FieldType(TextField.TYPE_STORED);
|
||||||
|
ft.setStoreTermVectors(true);
|
||||||
|
ft.setStoreTermVectorOffsets(true);
|
||||||
|
ft.setStoreTermVectorPositions(true);
|
||||||
|
int docs = 1000;
|
||||||
|
Random random = random();
|
||||||
|
for (int i = 0; i < docs; i++) {
|
||||||
|
boolean b = random.nextBoolean();
|
||||||
|
Document doc = new Document();
|
||||||
|
doc.add(new Field(textFieldName, createRandomString(random), ft));
|
||||||
|
doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
|
||||||
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
}
|
||||||
|
indexWriter.commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String createRandomString(Random random) {
|
||||||
|
StringBuilder builder = new StringBuilder();
|
||||||
|
for (int i = 0; i < 20; i++) {
|
||||||
|
builder.append(_TestUtil.randomSimpleString(random, 5));
|
||||||
|
builder.append(" ");
|
||||||
|
}
|
||||||
|
return builder.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void populateSampleIndex(Analyzer analyzer) throws Exception {
|
||||||
|
|
||||||
|
indexWriter.deleteAll();
|
||||||
|
indexWriter.commit();
|
||||||
|
|
||||||
FieldType ft = new FieldType(TextField.TYPE_STORED);
|
FieldType ft = new FieldType(TextField.TYPE_STORED);
|
||||||
ft.setStoreTermVectors(true);
|
ft.setStoreTermVectors(true);
|
||||||
ft.setStoreTermVectorOffsets(true);
|
ft.setStoreTermVectorOffsets(true);
|
||||||
ft.setStoreTermVectorPositions(true);
|
ft.setStoreTermVectorPositions(true);
|
||||||
|
|
||||||
|
String text;
|
||||||
|
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
|
text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
|
||||||
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
|
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
|
||||||
"the Unknown Soldier in Warsaw Tuesday.", ft));
|
"the Unknown Soldier in Warsaw Tuesday.";
|
||||||
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
doc.add(new Field(booleanFieldName, "false", ft));
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
|
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
|
text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
|
||||||
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft));
|
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
|
||||||
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
doc.add(new Field(booleanFieldName, "false", ft));
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
|
text = "And there's a threshold question that he has to answer for the American people and " +
|
||||||
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
|
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
|
||||||
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft));
|
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
|
||||||
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
doc.add(new Field(booleanFieldName, "false", ft));
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
|
text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
|
||||||
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
|
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
|
||||||
"Albany's School of Criminal Justice.", ft));
|
"Albany's School of Criminal Justice.";
|
||||||
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
|
doc = new Document();
|
||||||
|
text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
|
||||||
|
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
|
||||||
|
"world through the Internet.";
|
||||||
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
doc.add(new Field(booleanFieldName, "false", ft));
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
|
text = "So, about all those experts and analysts who've spent the past year or so saying " +
|
||||||
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
|
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
|
||||||
"world through the Internet.", ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
doc.add(new Field(booleanFieldName, "true", ft));
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
|
text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
|
||||||
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft));
|
|
||||||
doc.add(new Field(categoryFieldName, "technology", ft));
|
|
||||||
doc.add(new Field(booleanFieldName, "true", ft));
|
|
||||||
indexWriter.addDocument(doc, analyzer);
|
|
||||||
|
|
||||||
doc = new Document();
|
|
||||||
doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
|
|
||||||
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
|
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
|
||||||
"generally transfer or store huge volumes of personal data online.", ft));
|
"generally transfer or store huge volumes of personal data online.";
|
||||||
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
doc.add(new Field(booleanFieldName, "true", ft));
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
indexWriter.addDocument(doc, analyzer);
|
indexWriter.addDocument(doc, analyzer);
|
||||||
|
|
||||||
indexWriter.commit();
|
indexWriter.commit();
|
||||||
|
|
|
@ -27,7 +27,12 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
|
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPerformance() throws Exception {
|
||||||
|
checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,11 +21,9 @@ import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
import org.apache.lucene.analysis.Tokenizer;
|
import org.apache.lucene.analysis.Tokenizer;
|
||||||
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
||||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
||||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer;
|
|
||||||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
import org.apache.lucene.util.Version;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.io.Reader;
|
import java.io.Reader;
|
||||||
|
@ -39,13 +37,13 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), categoryFieldName);
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNGramUsage() throws Exception {
|
public void testNGramUsage() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), categoryFieldName);
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
private class NGramAnalyzer extends Analyzer {
|
private class NGramAnalyzer extends Analyzer {
|
||||||
|
@ -56,4 +54,9 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPerformance() throws Exception {
|
||||||
|
checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -131,9 +131,15 @@ public class DataSplitterTest extends LuceneTestCase {
|
||||||
closeQuietly(testReader);
|
closeQuietly(testReader);
|
||||||
closeQuietly(cvReader);
|
closeQuietly(cvReader);
|
||||||
} finally {
|
} finally {
|
||||||
trainingIndex.close();
|
if (trainingIndex != null) {
|
||||||
testIndex.close();
|
trainingIndex.close();
|
||||||
crossValidationIndex.close();
|
}
|
||||||
|
if (testIndex != null) {
|
||||||
|
testIndex.close();
|
||||||
|
}
|
||||||
|
if (crossValidationIndex != null) {
|
||||||
|
crossValidationIndex.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue