mirror of https://github.com/apache/lucene.git
[LUCENE-4345] - first impl of a classification module for lucene
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1384219 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
aed9e0f350
commit
4711f774ee
|
@ -0,0 +1,26 @@
|
|||
<?xml version="1.0"?>
|
||||
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
<project name="classification" default="default">
|
||||
<description>
|
||||
Classification module for Lucene
|
||||
</description>
|
||||
|
||||
<import file="../module-build.xml"/>
|
||||
</project>
|
|
@ -0,0 +1,21 @@
|
|||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
<ivy-module version="2.0">
|
||||
<info organisation="org.apache.lucene" module="classification"/>
|
||||
</ivy-module>
|
|
@ -0,0 +1,49 @@
|
|||
package org.apache.lucene.classification;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.index.AtomicReader;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>
|
||||
*/
|
||||
public interface Classifier {
|
||||
|
||||
/**
|
||||
* Assign a class to the given text String
|
||||
* @param text a String containing text to be classified
|
||||
* @return a String representing a class
|
||||
* @throws IOException
|
||||
*/
|
||||
public String assignClass(String text) throws IOException;
|
||||
|
||||
/**
|
||||
* Train the classifier using the underlying Lucene index
|
||||
* @param atomicReader the reader to use to access the Lucene index
|
||||
* @param textFieldName the name of the field used to compare documents
|
||||
* @param classFieldName the name of the field containing the class assigned to documents
|
||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||
* @throws IOException
|
||||
*/
|
||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
|
||||
throws IOException;
|
||||
|
||||
}
|
|
@ -0,0 +1,142 @@
|
|||
package org.apache.lucene.classification;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.index.AtomicReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.StringReader;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedList;
|
||||
|
||||
/**
|
||||
* A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
|
||||
*/
|
||||
public class SimpleNaiveBayesClassifier implements Classifier {
|
||||
|
||||
private AtomicReader atomicReader;
|
||||
private String textFieldName;
|
||||
private String classFieldName;
|
||||
private int docsWithClassSize;
|
||||
private Analyzer analyzer;
|
||||
private IndexSearcher indexSearcher;
|
||||
|
||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
|
||||
throws IOException {
|
||||
this.atomicReader = atomicReader;
|
||||
this.indexSearcher = new IndexSearcher(this.atomicReader);
|
||||
this.textFieldName = textFieldName;
|
||||
this.classFieldName = classFieldName;
|
||||
this.analyzer = analyzer;
|
||||
this.docsWithClassSize = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
|
||||
}
|
||||
|
||||
private String[] tokenizeDoc(String doc) throws IOException {
|
||||
Collection<String> result = new LinkedList<String>();
|
||||
TokenStream tokenStream = analyzer.tokenStream(textFieldName, new StringReader(doc));
|
||||
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
|
||||
tokenStream.reset();
|
||||
while (tokenStream.incrementToken()) {
|
||||
result.add(charTermAttribute.toString());
|
||||
}
|
||||
tokenStream.end();
|
||||
tokenStream.close();
|
||||
return result.toArray(new String[result.size()]);
|
||||
}
|
||||
|
||||
public String assignClass(String inputDocument) throws IOException {
|
||||
if (atomicReader == null) {
|
||||
throw new RuntimeException("need to train the classifier first");
|
||||
}
|
||||
Double max = 0d;
|
||||
String foundClass = null;
|
||||
|
||||
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
|
||||
TermsEnum termsEnum = terms.iterator(null);
|
||||
BytesRef t = termsEnum.next();
|
||||
while (t != null) {
|
||||
String classValue = t.utf8ToString();
|
||||
// TODO : turn it to be in log scale
|
||||
Double clVal = calculatePrior(classValue) * calculateLikelihood(inputDocument, classValue);
|
||||
if (clVal > max) {
|
||||
max = clVal;
|
||||
foundClass = classValue;
|
||||
}
|
||||
t = termsEnum.next();
|
||||
}
|
||||
return foundClass;
|
||||
}
|
||||
|
||||
|
||||
private Double calculateLikelihood(String document, String c) throws IOException {
|
||||
// for each word
|
||||
Double result = 1d;
|
||||
for (String word : tokenizeDoc(document)) {
|
||||
// search with text:word AND class:c
|
||||
int hits = getWordFreqForClass(word, c);
|
||||
|
||||
// num : count the no of times the word appears in documents of class c (+1)
|
||||
double num = hits + 1; // +1 is added because of add 1 smoothing
|
||||
|
||||
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
|
||||
double den = getTextTermFreqForClass(c) + docsWithClassSize;
|
||||
|
||||
// P(w|c) = num/den
|
||||
double wordProbability = num / den;
|
||||
result *= wordProbability;
|
||||
}
|
||||
|
||||
// P(d|c) = P(w1|c)*...*P(wn|c)
|
||||
return result;
|
||||
}
|
||||
|
||||
private double getTextTermFreqForClass(String c) throws IOException {
|
||||
Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
|
||||
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
||||
double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
||||
int docsWithC = atomicReader.docFreq(classFieldName, new BytesRef(c));
|
||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
|
||||
}
|
||||
|
||||
private int getWordFreqForClass(String word, String c) throws IOException {
|
||||
BooleanQuery booleanQuery = new BooleanQuery();
|
||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
|
||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
||||
return indexSearcher.search(booleanQuery, 1).totalHits;
|
||||
}
|
||||
|
||||
private Double calculatePrior(String currentClass) throws IOException {
|
||||
return (double) docCount(currentClass) / docsWithClassSize;
|
||||
}
|
||||
|
||||
private int docCount(String countedClass) throws IOException {
|
||||
return atomicReader.docFreq(new Term(classFieldName, countedClass));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
package org.apache.lucene.classification;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.TextField;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Testcase for {@link SimpleNaiveBayesClassifier}
|
||||
*/
|
||||
public class SimpleNaiveBayesClassifierTest extends LuceneTestCase {
|
||||
|
||||
private RandomIndexWriter indexWriter;
|
||||
private String textFieldName;
|
||||
private String classFieldName;
|
||||
private Analyzer analyzer;
|
||||
private Directory dir;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
analyzer = new MockAnalyzer(random());
|
||||
dir = newDirectory();
|
||||
indexWriter = new RandomIndexWriter(random(), dir);
|
||||
textFieldName = "text";
|
||||
classFieldName = "cat";
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
super.tearDown();
|
||||
indexWriter.close();
|
||||
dir.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicUsage() throws Exception {
|
||||
SlowCompositeReaderWrapper compositeReaderWrapper = null;
|
||||
try {
|
||||
populateIndex();
|
||||
SimpleNaiveBayesClassifier simpleNaiveBayesClassifier = new SimpleNaiveBayesClassifier();
|
||||
compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader());
|
||||
simpleNaiveBayesClassifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
|
||||
String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more. ";
|
||||
assertEquals("technology", simpleNaiveBayesClassifier.assignClass(newText));
|
||||
} finally {
|
||||
if (compositeReaderWrapper != null)
|
||||
compositeReaderWrapper.close();
|
||||
}
|
||||
}
|
||||
|
||||
private void populateIndex() throws Exception {
|
||||
|
||||
Document doc = new Document();
|
||||
doc.add(new TextField(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
|
||||
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
|
||||
"the Unknown Soldier in Warsaw Tuesday.", Field.Store.YES));
|
||||
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
|
||||
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(new TextField(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
|
||||
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", Field.Store.YES));
|
||||
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(new TextField(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
|
||||
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
|
||||
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", Field.Store.YES));
|
||||
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(new TextField(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
|
||||
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
|
||||
"Albany's School of Criminal Justice.", Field.Store.YES));
|
||||
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(new TextField(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
|
||||
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
|
||||
"world through the Internet.", Field.Store.YES));
|
||||
doc.add(new TextField(classFieldName, "technology", Field.Store.YES));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(new TextField(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
|
||||
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", Field.Store.YES));
|
||||
doc.add(new TextField(classFieldName, "technology", Field.Store.YES));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(new TextField(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
|
||||
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
|
||||
"generally transfer or store huge volumes of personal data online.", Field.Store.YES));
|
||||
doc.add(new TextField(classFieldName, "technology", Field.Store.YES));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
indexWriter.commit();
|
||||
}
|
||||
|
||||
}
|
|
@ -178,6 +178,28 @@
|
|||
</ant>
|
||||
<property name="queries-javadocs.uptodate" value="true"/>
|
||||
</target>
|
||||
|
||||
<property name="classification.jar" value="${common.dir}/build/classification/lucene-classification-${version}.jar"/>
|
||||
<target name="check-classification-uptodate" unless="classification.uptodate">
|
||||
<module-uptodate name="classification" jarfile="${classification.jar}" property="classification.uptodate"/>
|
||||
</target>
|
||||
<target name="jar-classification" unless="classification.uptodate" depends="check-classification-uptodate">
|
||||
<ant dir="${common.dir}/classification" target="jar-core" inheritAll="false">
|
||||
<propertyset refid="uptodate.and.compiled.properties"/>
|
||||
</ant>
|
||||
<property name="classification.uptodate" value="true"/>
|
||||
</target>
|
||||
|
||||
<property name="classification-javadoc.jar" value="${common.dir}/build/classification/lucene-classification-${version}-javadoc.jar"/>
|
||||
<target name="check-classification-javadocs-uptodate" unless="classification-javadocs.uptodate">
|
||||
<module-uptodate name="classification" jarfile="${classification-javadoc.jar}" property="classification-javadocs.uptodate"/>
|
||||
</target>
|
||||
<target name="javadocs-classification" unless="classification-javadocs.uptodate" depends="check-classification-javadocs-uptodate">
|
||||
<ant dir="${common.dir}/classification" target="javadocs" inheritAll="false">
|
||||
<propertyset refid="uptodate.and.compiled.properties"/>
|
||||
</ant>
|
||||
<property name="classification-javadocs.uptodate" value="true"/>
|
||||
</target>
|
||||
|
||||
<property name="facet.jar" value="${common.dir}/build/facet/lucene-facet-${version}.jar"/>
|
||||
<target name="check-facet-uptodate" unless="facet.uptodate">
|
||||
|
|
Loading…
Reference in New Issue