mirror of https://github.com/apache/lucene.git
LUCENE-7196 - guaranteed class coverage in split indexes through grouping by class
This commit is contained in:
parent
d7867b80f8
commit
112078eaf9
|
@ -27,6 +27,7 @@
|
||||||
<path id="classpath">
|
<path id="classpath">
|
||||||
<path refid="base.classpath"/>
|
<path refid="base.classpath"/>
|
||||||
<pathelement path="${queries.jar}"/>
|
<pathelement path="${queries.jar}"/>
|
||||||
|
<pathelement path="${grouping.jar}"/>
|
||||||
</path>
|
</path>
|
||||||
|
|
||||||
<path id="test.classpath">
|
<path id="test.classpath">
|
||||||
|
@ -35,15 +36,16 @@
|
||||||
<path refid="test.base.classpath"/>
|
<path refid="test.base.classpath"/>
|
||||||
</path>
|
</path>
|
||||||
|
|
||||||
<target name="compile-core" depends="jar-queries,jar-analyzers-common,common.compile-core" />
|
<target name="compile-core" depends="jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
|
||||||
|
|
||||||
<target name="jar-core" depends="common.jar-core" />
|
<target name="jar-core" depends="common.jar-core" />
|
||||||
|
|
||||||
<target name="javadocs" depends="javadocs-queries,compile-core,check-javadocs-uptodate"
|
<target name="javadocs" depends="javadocs-grouping,javadocs-misc,compile-core,check-javadocs-uptodate"
|
||||||
unless="javadocs-uptodate-${name}">
|
unless="javadocs-uptodate-${name}">
|
||||||
<invoke-module-javadoc>
|
<invoke-module-javadoc>
|
||||||
<links>
|
<links>
|
||||||
<link href="../queries"/>
|
<link href="../queries"/>
|
||||||
|
<link href="../group"/>
|
||||||
</links>
|
</links>
|
||||||
</invoke-module-javadoc>
|
</invoke-module-javadoc>
|
||||||
</target>
|
</target>
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils;
|
||||||
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
|
@ -28,11 +29,16 @@ import org.apache.lucene.index.IndexWriter;
|
||||||
import org.apache.lucene.index.IndexWriterConfig;
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
import org.apache.lucene.index.IndexableField;
|
import org.apache.lucene.index.IndexableField;
|
||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.Terms;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.Sort;
|
||||||
|
import org.apache.lucene.search.grouping.GroupDocs;
|
||||||
|
import org.apache.lucene.search.grouping.GroupingSearch;
|
||||||
|
import org.apache.lucene.search.grouping.TopGroups;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.uninverting.UninvertingReader;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Utility class for creating training / test / cross validation indexes from the original index.
|
* Utility class for creating training / test / cross validation indexes from the original index.
|
||||||
|
@ -61,67 +67,78 @@ public class DatasetSplitter {
|
||||||
* @param testIndex a {@link Directory} used to write the test index
|
* @param testIndex a {@link Directory} used to write the test index
|
||||||
* @param crossValidationIndex a {@link Directory} used to write the cross validation index
|
* @param crossValidationIndex a {@link Directory} used to write the cross validation index
|
||||||
* @param analyzer {@link Analyzer} used to create the new docs
|
* @param analyzer {@link Analyzer} used to create the new docs
|
||||||
|
* @param termVectors {@code true} if term vectors should be kept
|
||||||
|
* @param classFieldName names of the field used as the label for classification
|
||||||
* @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used
|
* @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used
|
||||||
* @throws IOException if any writing operation fails on any of the indexes
|
* @throws IOException if any writing operation fails on any of the indexes
|
||||||
*/
|
*/
|
||||||
public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
|
public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
|
||||||
Analyzer analyzer, String... fieldNames) throws IOException {
|
Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
|
||||||
|
|
||||||
// create IWs for train / test / cv IDXs
|
// create IWs for train / test / cv IDXs
|
||||||
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
|
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
|
||||||
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
|
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
|
||||||
IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
|
IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
|
||||||
|
|
||||||
try {
|
// try to get the exact no. of existing classes
|
||||||
int size = originalIndex.maxDoc();
|
Terms terms = originalIndex.terms(classFieldName);
|
||||||
|
long noOfClasses = -1;
|
||||||
|
if (terms != null) {
|
||||||
|
noOfClasses = terms.size();
|
||||||
|
|
||||||
IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
|
}
|
||||||
TopDocs topDocs = indexSearcher.search(new MatchAllDocsQuery(), Integer.MAX_VALUE);
|
if (noOfClasses == -1) {
|
||||||
|
noOfClasses = 10000; // fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
HashMap<String, UninvertingReader.Type> mapping = new HashMap<>();
|
||||||
|
mapping.put(classFieldName, UninvertingReader.Type.SORTED);
|
||||||
|
UninvertingReader uninvertingReader = new UninvertingReader(originalIndex, mapping);
|
||||||
|
|
||||||
|
try {
|
||||||
|
|
||||||
|
IndexSearcher indexSearcher = new IndexSearcher(uninvertingReader);
|
||||||
|
GroupingSearch gs = new GroupingSearch(classFieldName);
|
||||||
|
gs.setGroupSort(Sort.INDEXORDER);
|
||||||
|
gs.setSortWithinGroup(Sort.INDEXORDER);
|
||||||
|
gs.setAllGroups(true);
|
||||||
|
gs.setGroupDocsLimit(originalIndex.maxDoc());
|
||||||
|
TopGroups<Object> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, (int) noOfClasses);
|
||||||
|
|
||||||
// set the type to be indexed, stored, with term vectors
|
// set the type to be indexed, stored, with term vectors
|
||||||
FieldType ft = new FieldType(TextField.TYPE_STORED);
|
FieldType ft = new FieldType(TextField.TYPE_STORED);
|
||||||
ft.setStoreTermVectors(true);
|
if (termVectors) {
|
||||||
ft.setStoreTermVectorOffsets(true);
|
ft.setStoreTermVectors(true);
|
||||||
ft.setStoreTermVectorPositions(true);
|
ft.setStoreTermVectorOffsets(true);
|
||||||
|
ft.setStoreTermVectorPositions(true);
|
||||||
|
}
|
||||||
|
|
||||||
int b = 0;
|
int b = 0;
|
||||||
|
|
||||||
// iterate over existing documents
|
// iterate over existing documents
|
||||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
for (GroupDocs group : topGroups.groups) {
|
||||||
|
int totalHits = group.totalHits;
|
||||||
|
double testSize = totalHits * testRatio;
|
||||||
|
int tc = 0;
|
||||||
|
double cvSize = totalHits * crossValidationRatio;
|
||||||
|
int cvc = 0;
|
||||||
|
for (ScoreDoc scoreDoc : group.scoreDocs) {
|
||||||
|
|
||||||
// create a new document for indexing
|
// create a new document for indexing
|
||||||
Document doc = new Document();
|
Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames);
|
||||||
Document document = originalIndex.document(scoreDoc.doc);
|
|
||||||
if (fieldNames != null && fieldNames.length > 0) {
|
|
||||||
for (String fieldName : fieldNames) {
|
|
||||||
IndexableField field = document.getField(fieldName);
|
|
||||||
if (field != null) {
|
|
||||||
doc.add(new Field(fieldName, field.stringValue(), ft));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (IndexableField field : document.getFields()) {
|
|
||||||
if (field.readerValue() != null) {
|
|
||||||
doc.add(new Field(field.name(), field.readerValue(), ft));
|
|
||||||
} else if (field.binaryValue() != null) {
|
|
||||||
doc.add(new Field(field.name(), field.binaryValue(), ft));
|
|
||||||
} else if (field.stringValue() != null) {
|
|
||||||
doc.add(new Field(field.name(), field.stringValue(), ft));
|
|
||||||
} else if (field.numericValue() != null) {
|
|
||||||
doc.add(new Field(field.name(), field.numericValue().toString(), ft));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// add it to one of the IDXs
|
// add it to one of the IDXs
|
||||||
if (b % 2 == 0 && testWriter.maxDoc() < size * testRatio) {
|
if (b % 2 == 0 && tc < testSize) {
|
||||||
testWriter.addDocument(doc);
|
testWriter.addDocument(doc);
|
||||||
} else if (cvWriter.maxDoc() < size * crossValidationRatio) {
|
tc++;
|
||||||
cvWriter.addDocument(doc);
|
} else if (cvc < cvSize) {
|
||||||
} else {
|
cvWriter.addDocument(doc);
|
||||||
trainingWriter.addDocument(doc);
|
cvc++;
|
||||||
|
} else {
|
||||||
|
trainingWriter.addDocument(doc);
|
||||||
|
}
|
||||||
|
b++;
|
||||||
}
|
}
|
||||||
b++;
|
|
||||||
}
|
}
|
||||||
// commit
|
// commit
|
||||||
testWriter.commit();
|
testWriter.commit();
|
||||||
|
@ -139,7 +156,34 @@ public class DatasetSplitter {
|
||||||
testWriter.close();
|
testWriter.close();
|
||||||
cvWriter.close();
|
cvWriter.close();
|
||||||
trainingWriter.close();
|
trainingWriter.close();
|
||||||
|
uninvertingReader.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
|
||||||
|
Document doc = new Document();
|
||||||
|
Document document = originalIndex.document(scoreDoc.doc);
|
||||||
|
if (fieldNames != null && fieldNames.length > 0) {
|
||||||
|
for (String fieldName : fieldNames) {
|
||||||
|
IndexableField field = document.getField(fieldName);
|
||||||
|
if (field != null) {
|
||||||
|
doc.add(new Field(fieldName, field.stringValue(), ft));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (IndexableField field : document.getFields()) {
|
||||||
|
if (field.readerValue() != null) {
|
||||||
|
doc.add(new Field(field.name(), field.readerValue(), ft));
|
||||||
|
} else if (field.binaryValue() != null) {
|
||||||
|
doc.add(new Field(field.name(), field.binaryValue(), ft));
|
||||||
|
} else if (field.stringValue() != null) {
|
||||||
|
doc.add(new Field(field.name(), field.stringValue(), ft));
|
||||||
|
} else if (field.numericValue() != null) {
|
||||||
|
doc.add(new Field(field.name(), field.numericValue().toString(), ft));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,27 +17,28 @@
|
||||||
package org.apache.lucene.classification.utils;
|
package org.apache.lucene.classification.utils;
|
||||||
|
|
||||||
|
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import java.io.IOException;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.FieldType;
|
import org.apache.lucene.document.FieldType;
|
||||||
|
import org.apache.lucene.document.SortedDocValuesField;
|
||||||
import org.apache.lucene.document.TextField;
|
import org.apache.lucene.document.TextField;
|
||||||
import org.apache.lucene.index.LeafReader;
|
|
||||||
import org.apache.lucene.index.DirectoryReader;
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.index.RandomIndexWriter;
|
import org.apache.lucene.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.store.BaseDirectoryWrapper;
|
import org.apache.lucene.store.BaseDirectoryWrapper;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.util.TestUtil;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
|
import org.apache.lucene.util.TestUtil;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Testcase for {@link org.apache.lucene.classification.utils.DatasetSplitter}
|
* Testcase for {@link org.apache.lucene.classification.utils.DatasetSplitter}
|
||||||
*/
|
*/
|
||||||
|
@ -47,9 +48,9 @@ public class DataSplitterTest extends LuceneTestCase {
|
||||||
private RandomIndexWriter indexWriter;
|
private RandomIndexWriter indexWriter;
|
||||||
private Directory dir;
|
private Directory dir;
|
||||||
|
|
||||||
private String textFieldName = "text";
|
private static final String textFieldName = "text";
|
||||||
private String classFieldName = "class";
|
private static final String classFieldName = "class";
|
||||||
private String idFieldName = "id";
|
private static final String idFieldName = "id";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@Before
|
@Before
|
||||||
|
@ -65,11 +66,11 @@ public class DataSplitterTest extends LuceneTestCase {
|
||||||
|
|
||||||
Document doc;
|
Document doc;
|
||||||
Random rnd = random();
|
Random rnd = random();
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 1000; i++) {
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
doc.add(new Field(idFieldName, Integer.toString(i), ft));
|
doc.add(new Field(idFieldName, "id" + Integer.toString(i), ft));
|
||||||
doc.add(new Field(textFieldName, TestUtil.randomUnicodeString(rnd, 1024), ft));
|
doc.add(new Field(textFieldName, TestUtil.randomUnicodeString(rnd, 1024), ft));
|
||||||
doc.add(new Field(classFieldName, TestUtil.randomUnicodeString(rnd, 10), ft));
|
doc.add(new Field(classFieldName, Integer.toString(rnd.nextInt(10)), ft));
|
||||||
indexWriter.addDocument(doc);
|
indexWriter.addDocument(doc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,18 +109,18 @@ public class DataSplitterTest extends LuceneTestCase {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
DatasetSplitter datasetSplitter = new DatasetSplitter(testRatio, crossValidationRatio);
|
DatasetSplitter datasetSplitter = new DatasetSplitter(testRatio, crossValidationRatio);
|
||||||
datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), fieldNames);
|
datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), true, classFieldName, fieldNames);
|
||||||
|
|
||||||
assertNotNull(trainingIndex);
|
assertNotNull(trainingIndex);
|
||||||
assertNotNull(testIndex);
|
assertNotNull(testIndex);
|
||||||
assertNotNull(crossValidationIndex);
|
assertNotNull(crossValidationIndex);
|
||||||
|
|
||||||
DirectoryReader trainingReader = DirectoryReader.open(trainingIndex);
|
DirectoryReader trainingReader = DirectoryReader.open(trainingIndex);
|
||||||
assertTrue((int) (originalIndex.maxDoc() * (1d - testRatio - crossValidationRatio)) == trainingReader.maxDoc());
|
assertEquals((int) (originalIndex.maxDoc() * (1d - testRatio - crossValidationRatio)), trainingReader.maxDoc(), 20);
|
||||||
DirectoryReader testReader = DirectoryReader.open(testIndex);
|
DirectoryReader testReader = DirectoryReader.open(testIndex);
|
||||||
assertTrue((int) (originalIndex.maxDoc() * testRatio) == testReader.maxDoc());
|
assertEquals((int) (originalIndex.maxDoc() * testRatio), testReader.maxDoc(), 20);
|
||||||
DirectoryReader cvReader = DirectoryReader.open(crossValidationIndex);
|
DirectoryReader cvReader = DirectoryReader.open(crossValidationIndex);
|
||||||
assertTrue((int) (originalIndex.maxDoc() * crossValidationRatio) == cvReader.maxDoc());
|
assertEquals((int) (originalIndex.maxDoc() * crossValidationRatio), cvReader.maxDoc(), 20);
|
||||||
|
|
||||||
trainingReader.close();
|
trainingReader.close();
|
||||||
testReader.close();
|
testReader.close();
|
||||||
|
|
Loading…
Reference in New Issue