[LUCENE-4345] - first impl of a classification module for lucene

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1384219 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2012-09-13 07:09:58 +00:00
parent aed9e0f350
commit 4711f774ee
6 changed files with 390 additions and 0 deletions

View File

@ -0,0 +1,26 @@
<?xml version="1.0"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project name="classification" default="default">
<description>
Classification module for Lucene
</description>
<import file="../module-build.xml"/>
</project>

View File

@ -0,0 +1,21 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<ivy-module version="2.0">
<info organisation="org.apache.lucene" module="classification"/>
</ivy-module>

View File

@ -0,0 +1,49 @@
package org.apache.lucene.classification;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.AtomicReader;
import java.io.IOException;
/**
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>
*/
public interface Classifier {
/**
* Assign a class to the given text String
* @param text a String containing text to be classified
* @return a String representing a class
* @throws IOException
*/
public String assignClass(String text) throws IOException;
/**
* Train the classifier using the underlying Lucene index
* @param atomicReader the reader to use to access the Lucene index
* @param textFieldName the name of the field used to compare documents
* @param classFieldName the name of the field containing the class assigned to documents
* @param analyzer the analyzer used to tokenize / filter the unseen text
* @throws IOException
*/
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
throws IOException;
}

View File

@ -0,0 +1,142 @@
package org.apache.lucene.classification;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import java.io.IOException;
import java.io.StringReader;
import java.util.Collection;
import java.util.LinkedList;
/**
* A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
*/
public class SimpleNaiveBayesClassifier implements Classifier {
private AtomicReader atomicReader;
private String textFieldName;
private String classFieldName;
private int docsWithClassSize;
private Analyzer analyzer;
private IndexSearcher indexSearcher;
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
throws IOException {
this.atomicReader = atomicReader;
this.indexSearcher = new IndexSearcher(this.atomicReader);
this.textFieldName = textFieldName;
this.classFieldName = classFieldName;
this.analyzer = analyzer;
this.docsWithClassSize = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
}
private String[] tokenizeDoc(String doc) throws IOException {
Collection<String> result = new LinkedList<String>();
TokenStream tokenStream = analyzer.tokenStream(textFieldName, new StringReader(doc));
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
result.add(charTermAttribute.toString());
}
tokenStream.end();
tokenStream.close();
return result.toArray(new String[result.size()]);
}
public String assignClass(String inputDocument) throws IOException {
if (atomicReader == null) {
throw new RuntimeException("need to train the classifier first");
}
Double max = 0d;
String foundClass = null;
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
TermsEnum termsEnum = terms.iterator(null);
BytesRef t = termsEnum.next();
while (t != null) {
String classValue = t.utf8ToString();
// TODO : turn it to be in log scale
Double clVal = calculatePrior(classValue) * calculateLikelihood(inputDocument, classValue);
if (clVal > max) {
max = clVal;
foundClass = classValue;
}
t = termsEnum.next();
}
return foundClass;
}
private Double calculateLikelihood(String document, String c) throws IOException {
// for each word
Double result = 1d;
for (String word : tokenizeDoc(document)) {
// search with text:word AND class:c
int hits = getWordFreqForClass(word, c);
// num : count the no of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
double den = getTextTermFreqForClass(c) + docsWithClassSize;
// P(w|c) = num/den
double wordProbability = num / den;
result *= wordProbability;
}
// P(d|c) = P(w1|c)*...*P(wn|c)
return result;
}
private double getTextTermFreqForClass(String c) throws IOException {
Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
int docsWithC = atomicReader.docFreq(classFieldName, new BytesRef(c));
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
}
private int getWordFreqForClass(String word, String c) throws IOException {
BooleanQuery booleanQuery = new BooleanQuery();
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
return indexSearcher.search(booleanQuery, 1).totalHits;
}
private Double calculatePrior(String currentClass) throws IOException {
return (double) docCount(currentClass) / docsWithClassSize;
}
private int docCount(String countedClass) throws IOException {
return atomicReader.docFreq(new Term(classFieldName, countedClass));
}
}

View File

@ -0,0 +1,130 @@
package org.apache.lucene.classification;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
/**
* Testcase for {@link SimpleNaiveBayesClassifier}
*/
public class SimpleNaiveBayesClassifierTest extends LuceneTestCase {
private RandomIndexWriter indexWriter;
private String textFieldName;
private String classFieldName;
private Analyzer analyzer;
private Directory dir;
@Before
public void setUp() throws Exception {
super.setUp();
analyzer = new MockAnalyzer(random());
dir = newDirectory();
indexWriter = new RandomIndexWriter(random(), dir);
textFieldName = "text";
classFieldName = "cat";
}
@After
public void tearDown() throws Exception {
super.tearDown();
indexWriter.close();
dir.close();
}
@Test
public void testBasicUsage() throws Exception {
SlowCompositeReaderWrapper compositeReaderWrapper = null;
try {
populateIndex();
SimpleNaiveBayesClassifier simpleNaiveBayesClassifier = new SimpleNaiveBayesClassifier();
compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader());
simpleNaiveBayesClassifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more. ";
assertEquals("technology", simpleNaiveBayesClassifier.assignClass(newText));
} finally {
if (compositeReaderWrapper != null)
compositeReaderWrapper.close();
}
}
private void populateIndex() throws Exception {
Document doc = new Document();
doc.add(new TextField(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
"the Unknown Soldier in Warsaw Tuesday.", Field.Store.YES));
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
doc.add(new TextField(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", Field.Store.YES));
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
doc.add(new TextField(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", Field.Store.YES));
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
doc.add(new TextField(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
"Albany's School of Criminal Justice.", Field.Store.YES));
doc.add(new TextField(classFieldName, "politics", Field.Store.YES));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
doc.add(new TextField(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
"world through the Internet.", Field.Store.YES));
doc.add(new TextField(classFieldName, "technology", Field.Store.YES));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
doc.add(new TextField(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", Field.Store.YES));
doc.add(new TextField(classFieldName, "technology", Field.Store.YES));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
doc.add(new TextField(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
"generally transfer or store huge volumes of personal data online.", Field.Store.YES));
doc.add(new TextField(classFieldName, "technology", Field.Store.YES));
indexWriter.addDocument(doc, analyzer);
indexWriter.commit();
}
}

View File

@ -178,6 +178,28 @@
</ant>
<property name="queries-javadocs.uptodate" value="true"/>
</target>
<property name="classification.jar" value="${common.dir}/build/classification/lucene-classification-${version}.jar"/>
<target name="check-classification-uptodate" unless="classification.uptodate">
<module-uptodate name="classification" jarfile="${classification.jar}" property="classification.uptodate"/>
</target>
<target name="jar-classification" unless="classification.uptodate" depends="check-classification-uptodate">
<ant dir="${common.dir}/classification" target="jar-core" inheritAll="false">
<propertyset refid="uptodate.and.compiled.properties"/>
</ant>
<property name="classification.uptodate" value="true"/>
</target>
<property name="classification-javadoc.jar" value="${common.dir}/build/classification/lucene-classification-${version}-javadoc.jar"/>
<target name="check-classification-javadocs-uptodate" unless="classification-javadocs.uptodate">
<module-uptodate name="classification" jarfile="${classification-javadoc.jar}" property="classification-javadocs.uptodate"/>
</target>
<target name="javadocs-classification" unless="classification-javadocs.uptodate" depends="check-classification-javadocs-uptodate">
<ant dir="${common.dir}/classification" target="javadocs" inheritAll="false">
<propertyset refid="uptodate.and.compiled.properties"/>
</ant>
<property name="classification-javadocs.uptodate" value="true"/>
</target>
<property name="facet.jar" value="${common.dir}/build/facet/lucene-facet-${version}.jar"/>
<target name="check-facet-uptodate" unless="facet.uptodate">