diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java index 3d87a078e69..faea4dcc03d 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java @@ -45,7 +45,8 @@ public class DatasetSplitter { /** * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes - * @param testRatio the ratio of the original index to be used for the test IDX as a double between 0.0 and 1.0 + * + * @param testRatio the ratio of the original index to be used for the test IDX as a double between 0.0 and 1.0 * @param crossValidationRatio the ratio of the original index to be used for the c.v. IDX as a double between 0.0 and 1.0 */ public DatasetSplitter(double testRatio, double crossValidationRatio) { @@ -55,12 +56,13 @@ public class DatasetSplitter { /** * Split a given index into 3 indexes for training, test and cross validation tasks respectively - * @param originalIndex an {@link AtomicReader} on the source index - * @param trainingIndex a {@link Directory} used to write the training index - * @param testIndex a {@link Directory} used to write the test index + * + * @param originalIndex an {@link AtomicReader} on the source index + * @param trainingIndex a {@link Directory} used to write the training index + * @param testIndex a {@link Directory} used to write the test index * @param crossValidationIndex a {@link Directory} used to write the cross validation index - * @param analyzer {@link Analyzer} used to create the new docs - * @param fieldNames names of fields that need to be put in the new indexes or null if all should be used + * @param analyzer {@link Analyzer} used to create the new docs + * @param fieldNames names of fields that need to be put in the new indexes or null if all should be used * @throws IOException if any writing operation fails on any of the indexes */ public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, @@ -98,16 +100,13 @@ public class DatasetSplitter { } } else { for (StorableField storableField : originalIndex.document(scoreDoc.doc).getFields()) { - if (storableField.readerValue()!= null){ + if (storableField.readerValue() != null) { doc.add(new Field(storableField.name(), storableField.readerValue(), ft)); - } - else if (storableField.binaryValue()!= null){ + } else if (storableField.binaryValue() != null) { doc.add(new Field(storableField.name(), storableField.binaryValue(), ft)); - } - else if (storableField.stringValue()!= null){ + } else if (storableField.stringValue() != null) { doc.add(new Field(storableField.name(), storableField.stringValue(), ft)); - } - else if (storableField.numericValue()!= null){ + } else if (storableField.numericValue() != null) { doc.add(new Field(storableField.name(), storableField.numericValue().toString(), ft)); } } @@ -116,19 +115,19 @@ public class DatasetSplitter { // add it to one of the IDXs if (b % 2 == 0 && testWriter.maxDoc() < size * testRatio) { testWriter.addDocument(doc); - testWriter.commit(); } else if (cvWriter.maxDoc() < size * crossValidationRatio) { cvWriter.addDocument(doc); - cvWriter.commit(); } else { trainingWriter.addDocument(doc); - trainingWriter.commit(); } b++; } } catch (Exception e) { throw new IOException(e); } finally { + testWriter.commit(); + cvWriter.commit(); + trainingWriter.commit(); // close IWs testWriter.close(); cvWriter.close(); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java index 1f6d79af911..31a8704dd5d 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java @@ -92,7 +92,7 @@ public class DataSplitterTest extends LuceneTestCase { @Test public void testSplitOnAllFields() throws Exception { - assertSplit(originalIndex, 0.1, 0.1, null); + assertSplit(originalIndex, 0.1, 0.1); }