LUCENE-9905: Make sure to use configured vector format when merging (#176)

Before when creating a VectorWriter for merging, we would always load the
default implementation. So if the format was configured with parameters, they
were ignored.

This issue was caught by `TestKnnGraph#testMergeProducesSameGraph`.
This commit is contained in:
Julie Tibshirani 2021-06-08 08:07:35 -07:00 committed by GitHub
parent 1ec2a715a2
commit e9339253f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 107 additions and 30 deletions

View File

@ -114,14 +114,7 @@ public abstract class PerFieldVectorFormat extends VectorFormat {
}
private VectorWriter getInstance(FieldInfo field) throws IOException {
VectorFormat format = null;
String fieldFormatName = field.getAttribute(PER_FIELD_FORMAT_KEY);
if (fieldFormatName != null) {
format = VectorFormat.forName(fieldFormatName);
}
if (format == null) {
format = getVectorFormatForField(field.name);
}
VectorFormat format = getVectorFormatForField(field.name);
if (format == null) {
throw new IllegalStateException(
"invalid null VectorFormat for field=\"" + field.name + "\"");

View File

@ -16,26 +16,36 @@
*/
package org.apache.lucene.codecs.perfield;
import static org.hamcrest.Matchers.equalTo;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Random;
import org.apache.lucene.analysis.Analyzer;
import java.util.Set;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.VectorFormat;
import org.apache.lucene.codecs.VectorReader;
import org.apache.lucene.codecs.VectorWriter;
import org.apache.lucene.codecs.asserting.AssertingCodec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.BaseVectorFormatTestCase;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.RandomCodec;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestUtil;
import org.hamcrest.MatcherAssert;
/** Basic tests of PerFieldDocValuesFormat */
public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
@ -53,21 +63,21 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
}
public void testTwoFieldsTwoFormats() throws IOException {
Analyzer analyzer = new MockAnalyzer(random());
try (Directory directory = newDirectory()) {
// we don't use RandomIndexWriter because it might add more values than we expect !!!!1
IndexWriterConfig iwc = newIndexWriterConfig(analyzer);
VectorFormat defaultFormat = TestUtil.getDefaultVectorFormat();
VectorFormat emptyFormat = VectorFormat.EMPTY;
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
WriteRecordingVectorFormat format1 =
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
WriteRecordingVectorFormat format2 =
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
iwc.setCodec(
new AssertingCodec() {
@Override
public VectorFormat getVectorFormatForField(String field) {
if ("empty".equals(field)) {
return emptyFormat;
if ("field1".equals(field)) {
return format1;
} else {
return defaultFormat;
return format2;
}
}
});
@ -75,32 +85,107 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
Document doc = new Document();
doc.add(newTextField("id", "1", Field.Store.YES));
doc.add(new VectorField("field", new float[] {1, 2, 3}));
doc.add(new VectorField("field1", new float[] {1, 2, 3}));
iwriter.addDocument(doc);
iwriter.commit();
// Check that we use the empty vector format, which doesn't support writes
doc.clear();
doc.add(newTextField("id", "2", Field.Store.YES));
doc.add(new VectorField("empty", new float[] {4, 5, 6}));
expectThrows(
RuntimeException.class,
() -> {
iwriter.addDocument(doc);
iwriter.commit();
});
doc.add(new VectorField("field2", new float[] {4, 5, 6}));
iwriter.addDocument(doc);
}
// Now search for the field that was successfully indexed
// Check that each format was used to write the expected field
MatcherAssert.assertThat(format1.fieldsWritten, equalTo(Set.of("field1")));
MatcherAssert.assertThat(format2.fieldsWritten, equalTo(Set.of("field2")));
// Double-check the vectors were written
try (IndexReader ireader = DirectoryReader.open(directory)) {
TopDocs hits1 =
ireader
.leaves()
.get(0)
.reader()
.searchNearestVectors("field", new float[] {1, 2, 3}, 10, 1);
.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, 1);
assertEquals(1, hits1.scoreDocs.length);
TopDocs hits2 =
ireader
.leaves()
.get(0)
.reader()
.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, 1);
assertEquals(1, hits2.scoreDocs.length);
}
}
}
public void testMergeUsesNewFormat() throws IOException {
try (Directory directory = newDirectory()) {
IndexWriterConfig initialConfig = newIndexWriterConfig(new MockAnalyzer(random()));
try (IndexWriter iw = new IndexWriter(directory, initialConfig)) {
for (int i = 0; i < 3; i++) {
Document doc = new Document();
doc.add(newTextField("id", "1", Field.Store.YES));
doc.add(new VectorField("field", new float[] {1, 2, 3}));
iw.addDocument(doc);
iw.commit();
}
}
IndexWriterConfig newConfig = newIndexWriterConfig(new MockAnalyzer(random()));
WriteRecordingVectorFormat newFormat =
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
newConfig.setCodec(
new AssertingCodec() {
@Override
public VectorFormat getVectorFormatForField(String field) {
return newFormat;
}
});
try (IndexWriter iw = new IndexWriter(directory, newConfig)) {
iw.forceMerge(1);
}
// Check that the new format was used while merging
MatcherAssert.assertThat(newFormat.fieldsWritten, equalTo(Set.of("field")));
}
}
private static class WriteRecordingVectorFormat extends VectorFormat {
private final VectorFormat delegate;
private final Set<String> fieldsWritten;
public WriteRecordingVectorFormat(VectorFormat delegate) {
super(delegate.getName());
this.delegate = delegate;
this.fieldsWritten = new HashSet<>();
}
@Override
public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
VectorWriter writer = delegate.fieldsWriter(state);
return new VectorWriter() {
@Override
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
fieldsWritten.add(fieldInfo.name);
writer.writeField(fieldInfo, values);
}
@Override
public void finish() throws IOException {
writer.finish();
}
@Override
public void close() throws IOException {
writer.close();
}
};
}
@Override
public VectorReader fieldsReader(SegmentReadState state) throws IOException {
return delegate.fieldsReader(state);
}
}
}

View File

@ -147,7 +147,6 @@ public class TestKnnGraph extends LuceneTestCase {
* Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
* and merging.
*/
@AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/LUCENE-9905")
public void testMergeProducesSameGraph() throws Exception {
long seed = random().nextLong();
int numDoc = atLeast(100);