diff --git a/lucene/classification/build.xml b/lucene/classification/build.xml
index 930b1fa80bd..fd15239fd18 100644
--- a/lucene/classification/build.xml
+++ b/lucene/classification/build.xml
@@ -27,6 +27,7 @@
+
@@ -35,15 +36,16 @@
-
+
-
+
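A possible way to invoke the new split() signature introduced in the DatasetSplitter change below (illustrative sketch only, not part of the patch; the paths, ratios, field names and the use of Lucene 5.x helpers such as SlowCompositeReaderWrapper are assumptions):

    Directory trainDir = FSDirectory.open(Paths.get("train"));
    Directory testDir = FSDirectory.open(Paths.get("test"));
    Directory cvDir = FSDirectory.open(Paths.get("cv"));
    try (DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get("original")))) {
      // split() expects a single LeafReader view over the source index
      LeafReader originalIndex = SlowCompositeReaderWrapper.wrap(reader);
      DatasetSplitter splitter = new DatasetSplitter(0.1d, 0.1d); // 10% test, 10% cross validation
      splitter.split(originalIndex, trainDir, testDir, cvDir,
          new StandardAnalyzer(), true /* termVectors */, "category" /* classFieldName */, "body", "title");
    }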
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
index 0b03b94f56d..fce786bf1e9 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
@@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils;
import java.io.IOException;
+import java.util.HashMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
@@ -28,11 +29,16 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.Terms;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.grouping.GroupDocs;
+import org.apache.lucene.search.grouping.GroupingSearch;
+import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;
+import org.apache.lucene.uninverting.UninvertingReader;
/**
* Utility class for creating training / test / cross validation indexes from the original index.
@@ -61,67 +67,78 @@ public class DatasetSplitter {
* @param testIndex a {@link Directory} used to write the test index
* @param crossValidationIndex a {@link Directory} used to write the cross validation index
* @param analyzer {@link Analyzer} used to create the new docs
+ * @param termVectors {@code true} if term vectors should be kept
+ * @param classFieldName name of the field used as the label for classification
* @param fieldNames names of fields that need to be put in the new indexes or null if all should be used
* @throws IOException if any writing operation fails on any of the indexes
*/
public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
- Analyzer analyzer, String... fieldNames) throws IOException {
+ Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
// create IWs for train / test / cv IDXs
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
- try {
- int size = originalIndex.maxDoc();
+ // try to get the exact no. of existing classes
+ Terms terms = originalIndex.terms(classFieldName);
+ long noOfClasses = -1;
+ if (terms != null) {
+ noOfClasses = terms.size();
- IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
- TopDocs topDocs = indexSearcher.search(new MatchAllDocsQuery(), Integer.MAX_VALUE);
+ }
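+ // Terms#size() returns -1 when the codec does not store the exact term count, hence the fallback below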
+ if (noOfClasses == -1) {
+ noOfClasses = 10000; // fallback
+ }
+
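+ // expose the class field as SortedDocValues via UninvertingReader so GroupingSearch can group on it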
+ HashMap<String, UninvertingReader.Type> mapping = new HashMap<>();
+ mapping.put(classFieldName, UninvertingReader.Type.SORTED);
+ UninvertingReader uninvertingReader = new UninvertingReader(originalIndex, mapping);
+
+ try {
+
+ IndexSearcher indexSearcher = new IndexSearcher(uninvertingReader);
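+ // group all documents by class value, keeping index order both across and within groups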
+ GroupingSearch gs = new GroupingSearch(classFieldName);
+ gs.setGroupSort(Sort.INDEXORDER);
+ gs.setSortWithinGroup(Sort.INDEXORDER);
+ gs.setAllGroups(true);
+ gs.setGroupDocsLimit(originalIndex.maxDoc());
+ TopGroups