LUCENE-7196 - guaranteed class coverage in split indexes through grouping by class

This commit is contained in:
Tommaso Teofili 2016-04-08 16:59:41 +02:00
parent d7867b80f8
commit 112078eaf9
3 changed files with 106 additions and 59 deletions

View File

@ -27,6 +27,7 @@
<path id="classpath"> <path id="classpath">
<path refid="base.classpath"/> <path refid="base.classpath"/>
<pathelement path="${queries.jar}"/> <pathelement path="${queries.jar}"/>
<pathelement path="${grouping.jar}"/>
</path> </path>
<path id="test.classpath"> <path id="test.classpath">
@ -35,15 +36,16 @@
<path refid="test.base.classpath"/> <path refid="test.base.classpath"/>
</path> </path>
<target name="compile-core" depends="jar-queries,jar-analyzers-common,common.compile-core" /> <target name="compile-core" depends="jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
<target name="jar-core" depends="common.jar-core" /> <target name="jar-core" depends="common.jar-core" />
<target name="javadocs" depends="javadocs-queries,compile-core,check-javadocs-uptodate" <target name="javadocs" depends="javadocs-grouping,javadocs-misc,compile-core,check-javadocs-uptodate"
unless="javadocs-uptodate-${name}"> unless="javadocs-uptodate-${name}">
<invoke-module-javadoc> <invoke-module-javadoc>
<links> <links>
<link href="../queries"/> <link href="../queries"/>
<link href="../group"/>
</links> </links>
</invoke-module-javadoc> </invoke-module-javadoc>
</target> </target>

View File

@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
@ -28,11 +29,16 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.uninverting.UninvertingReader;
/** /**
* Utility class for creating training / test / cross validation indexes from the original index. * Utility class for creating training / test / cross validation indexes from the original index.
@ -61,35 +67,100 @@ public class DatasetSplitter {
* @param testIndex a {@link Directory} used to write the test index * @param testIndex a {@link Directory} used to write the test index
* @param crossValidationIndex a {@link Directory} used to write the cross validation index * @param crossValidationIndex a {@link Directory} used to write the cross validation index
* @param analyzer {@link Analyzer} used to create the new docs * @param analyzer {@link Analyzer} used to create the new docs
* @param termVectors {@code true} if term vectors should be kept
* @param classFieldName names of the field used as the label for classification
* @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used * @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used
* @throws IOException if any writing operation fails on any of the indexes * @throws IOException if any writing operation fails on any of the indexes
*/ */
public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
Analyzer analyzer, String... fieldNames) throws IOException { Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
// create IWs for train / test / cv IDXs // create IWs for train / test / cv IDXs
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer)); IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer)); IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer)); IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
try { // try to get the exact no. of existing classes
int size = originalIndex.maxDoc(); Terms terms = originalIndex.terms(classFieldName);
long noOfClasses = -1;
if (terms != null) {
noOfClasses = terms.size();
IndexSearcher indexSearcher = new IndexSearcher(originalIndex); }
TopDocs topDocs = indexSearcher.search(new MatchAllDocsQuery(), Integer.MAX_VALUE); if (noOfClasses == -1) {
noOfClasses = 10000; // fallback
}
HashMap<String, UninvertingReader.Type> mapping = new HashMap<>();
mapping.put(classFieldName, UninvertingReader.Type.SORTED);
UninvertingReader uninvertingReader = new UninvertingReader(originalIndex, mapping);
try {
IndexSearcher indexSearcher = new IndexSearcher(uninvertingReader);
GroupingSearch gs = new GroupingSearch(classFieldName);
gs.setGroupSort(Sort.INDEXORDER);
gs.setSortWithinGroup(Sort.INDEXORDER);
gs.setAllGroups(true);
gs.setGroupDocsLimit(originalIndex.maxDoc());
TopGroups<Object> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, (int) noOfClasses);
// set the type to be indexed, stored, with term vectors // set the type to be indexed, stored, with term vectors
FieldType ft = new FieldType(TextField.TYPE_STORED); FieldType ft = new FieldType(TextField.TYPE_STORED);
if (termVectors) {
ft.setStoreTermVectors(true); ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true); ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true); ft.setStoreTermVectorPositions(true);
}
int b = 0; int b = 0;
// iterate over existing documents // iterate over existing documents
for (ScoreDoc scoreDoc : topDocs.scoreDocs) { for (GroupDocs group : topGroups.groups) {
int totalHits = group.totalHits;
double testSize = totalHits * testRatio;
int tc = 0;
double cvSize = totalHits * crossValidationRatio;
int cvc = 0;
for (ScoreDoc scoreDoc : group.scoreDocs) {
// create a new document for indexing // create a new document for indexing
Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames);
// add it to one of the IDXs
if (b % 2 == 0 && tc < testSize) {
testWriter.addDocument(doc);
tc++;
} else if (cvc < cvSize) {
cvWriter.addDocument(doc);
cvc++;
} else {
trainingWriter.addDocument(doc);
}
b++;
}
}
// commit
testWriter.commit();
cvWriter.commit();
trainingWriter.commit();
// merge
testWriter.forceMerge(3);
cvWriter.forceMerge(3);
trainingWriter.forceMerge(3);
} catch (Exception e) {
throw new IOException(e);
} finally {
// close IWs
testWriter.close();
cvWriter.close();
trainingWriter.close();
uninvertingReader.close();
}
}
private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
Document doc = new Document(); Document doc = new Document();
Document document = originalIndex.document(scoreDoc.doc); Document document = originalIndex.document(scoreDoc.doc);
if (fieldNames != null && fieldNames.length > 0) { if (fieldNames != null && fieldNames.length > 0) {
@ -112,34 +183,7 @@ public class DatasetSplitter {
} }
} }
} }
return doc;
// add it to one of the IDXs
if (b % 2 == 0 && testWriter.maxDoc() < size * testRatio) {
testWriter.addDocument(doc);
} else if (cvWriter.maxDoc() < size * crossValidationRatio) {
cvWriter.addDocument(doc);
} else {
trainingWriter.addDocument(doc);
}
b++;
}
// commit
testWriter.commit();
cvWriter.commit();
trainingWriter.commit();
// merge
testWriter.forceMerge(3);
cvWriter.forceMerge(3);
trainingWriter.forceMerge(3);
} catch (Exception e) {
throw new IOException(e);
} finally {
// close IWs
testWriter.close();
cvWriter.close();
trainingWriter.close();
}
} }
} }

View File

@ -17,27 +17,28 @@
package org.apache.lucene.classification.utils; package org.apache.lucene.classification.utils;
import org.apache.lucene.analysis.Analyzer; import java.io.IOException;
import java.util.Random;
import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType; import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.store.BaseDirectoryWrapper; import org.apache.lucene.store.BaseDirectoryWrapper;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestUtil; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.io.IOException;
import java.util.Random;
/** /**
* Testcase for {@link org.apache.lucene.classification.utils.DatasetSplitter} * Testcase for {@link org.apache.lucene.classification.utils.DatasetSplitter}
*/ */
@ -47,9 +48,9 @@ public class DataSplitterTest extends LuceneTestCase {
private RandomIndexWriter indexWriter; private RandomIndexWriter indexWriter;
private Directory dir; private Directory dir;
private String textFieldName = "text"; private static final String textFieldName = "text";
private String classFieldName = "class"; private static final String classFieldName = "class";
private String idFieldName = "id"; private static final String idFieldName = "id";
@Override @Override
@Before @Before
@ -65,11 +66,11 @@ public class DataSplitterTest extends LuceneTestCase {
Document doc; Document doc;
Random rnd = random(); Random rnd = random();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 1000; i++) {
doc = new Document(); doc = new Document();
doc.add(new Field(idFieldName, Integer.toString(i), ft)); doc.add(new Field(idFieldName, "id" + Integer.toString(i), ft));
doc.add(new Field(textFieldName, TestUtil.randomUnicodeString(rnd, 1024), ft)); doc.add(new Field(textFieldName, TestUtil.randomUnicodeString(rnd, 1024), ft));
doc.add(new Field(classFieldName, TestUtil.randomUnicodeString(rnd, 10), ft)); doc.add(new Field(classFieldName, Integer.toString(rnd.nextInt(10)), ft));
indexWriter.addDocument(doc); indexWriter.addDocument(doc);
} }
@ -108,18 +109,18 @@ public class DataSplitterTest extends LuceneTestCase {
try { try {
DatasetSplitter datasetSplitter = new DatasetSplitter(testRatio, crossValidationRatio); DatasetSplitter datasetSplitter = new DatasetSplitter(testRatio, crossValidationRatio);
datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), fieldNames); datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), true, classFieldName, fieldNames);
assertNotNull(trainingIndex); assertNotNull(trainingIndex);
assertNotNull(testIndex); assertNotNull(testIndex);
assertNotNull(crossValidationIndex); assertNotNull(crossValidationIndex);
DirectoryReader trainingReader = DirectoryReader.open(trainingIndex); DirectoryReader trainingReader = DirectoryReader.open(trainingIndex);
assertTrue((int) (originalIndex.maxDoc() * (1d - testRatio - crossValidationRatio)) == trainingReader.maxDoc()); assertEquals((int) (originalIndex.maxDoc() * (1d - testRatio - crossValidationRatio)), trainingReader.maxDoc(), 20);
DirectoryReader testReader = DirectoryReader.open(testIndex); DirectoryReader testReader = DirectoryReader.open(testIndex);
assertTrue((int) (originalIndex.maxDoc() * testRatio) == testReader.maxDoc()); assertEquals((int) (originalIndex.maxDoc() * testRatio), testReader.maxDoc(), 20);
DirectoryReader cvReader = DirectoryReader.open(crossValidationIndex); DirectoryReader cvReader = DirectoryReader.open(crossValidationIndex);
assertTrue((int) (originalIndex.maxDoc() * crossValidationRatio) == cvReader.maxDoc()); assertEquals((int) (originalIndex.maxDoc() * crossValidationRatio), cvReader.maxDoc(), 20);
trainingReader.close(); trainingReader.close();
testReader.close(); testReader.close();