LUCENE-9905: Make sure to use configured vector format when merging (#176)

Before, when creating a VectorWriter for merging, we would always load the
default implementation through SPI (`VectorFormat.forName`) instead of using
the format configured on the codec. So if the format was configured with
parameters, they were silently ignored during merges.

This issue was caught by `TestKnnGraph#testMergeProducesSameGraph`.
Author: Julie Tibshirani
Date:   2021-06-08 08:07:35 -07:00
commit e9339253f5 (parent 1ec2a715a2)
3 changed files with 107 additions and 30 deletions
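
For context, a per-field vector format is chosen by overriding the codec's
getVectorFormatForField. Below is a minimal sketch of the kind of configuration
affected by this bug; MyTunedVectorFormat and its maxConn/beamWidth parameters
are hypothetical stand-ins for any format constructed with non-default
parameters, not actual Lucene classes:

    IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
    iwc.setCodec(
        new AssertingCodec() {
          @Override
          public VectorFormat getVectorFormatForField(String field) {
            // Hypothetical parameterized format. Before this fix, these
            // parameters took effect when flushing new segments but were
            // silently replaced by defaults when segments were merged.
            return new MyTunedVectorFormat(/* maxConn= */ 16, /* beamWidth= */ 100);
          }
        });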

PerFieldVectorFormat.java

@@ -114,14 +114,7 @@ public abstract class PerFieldVectorFormat extends VectorFormat {
     }
 
     private VectorWriter getInstance(FieldInfo field) throws IOException {
-      VectorFormat format = null;
-      String fieldFormatName = field.getAttribute(PER_FIELD_FORMAT_KEY);
-      if (fieldFormatName != null) {
-        format = VectorFormat.forName(fieldFormatName);
-      }
-      if (format == null) {
-        format = getVectorFormatForField(field.name);
-      }
+      VectorFormat format = getVectorFormatForField(field.name);
       if (format == null) {
         throw new IllegalStateException(
             "invalid null VectorFormat for field=\"" + field.name + "\"");

TestPerFieldVectorFormat.java

@@ -16,26 +16,36 @@
  */
 package org.apache.lucene.codecs.perfield;
 
+import static org.hamcrest.Matchers.equalTo;
+
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.Random;
-import org.apache.lucene.analysis.Analyzer;
+import java.util.Set;
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.VectorFormat;
+import org.apache.lucene.codecs.VectorReader;
+import org.apache.lucene.codecs.VectorWriter;
 import org.apache.lucene.codecs.asserting.AssertingCodec;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.VectorField;
 import org.apache.lucene.index.BaseVectorFormatTestCase;
 import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.RandomCodec;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.TestUtil;
+import org.hamcrest.MatcherAssert;
 
 /** Basic tests of PerFieldDocValuesFormat */
 public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
@@ -53,21 +63,21 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
   }
 
   public void testTwoFieldsTwoFormats() throws IOException {
-    Analyzer analyzer = new MockAnalyzer(random());
-
     try (Directory directory = newDirectory()) {
       // we don't use RandomIndexWriter because it might add more values than we expect !!!!1
-      IndexWriterConfig iwc = newIndexWriterConfig(analyzer);
-      VectorFormat defaultFormat = TestUtil.getDefaultVectorFormat();
-      VectorFormat emptyFormat = VectorFormat.EMPTY;
+      IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
+      WriteRecordingVectorFormat format1 =
+          new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
+      WriteRecordingVectorFormat format2 =
+          new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
       iwc.setCodec(
           new AssertingCodec() {
             @Override
             public VectorFormat getVectorFormatForField(String field) {
-              if ("empty".equals(field)) {
-                return emptyFormat;
+              if ("field1".equals(field)) {
+                return format1;
               } else {
-                return defaultFormat;
+                return format2;
               }
             }
           });
@@ -75,32 +85,107 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
       try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
         Document doc = new Document();
         doc.add(newTextField("id", "1", Field.Store.YES));
-        doc.add(new VectorField("field", new float[] {1, 2, 3}));
+        doc.add(new VectorField("field1", new float[] {1, 2, 3}));
         iwriter.addDocument(doc);
-        iwriter.commit();
 
-        // Check that we use the empty vector format, which doesn't support writes
         doc.clear();
         doc.add(newTextField("id", "2", Field.Store.YES));
-        doc.add(new VectorField("empty", new float[] {4, 5, 6}));
-        expectThrows(
-            RuntimeException.class,
-            () -> {
-              iwriter.addDocument(doc);
-              iwriter.commit();
-            });
+        doc.add(new VectorField("field2", new float[] {4, 5, 6}));
+        iwriter.addDocument(doc);
       }
 
-      // Now search for the field that was successfully indexed
+      // Check that each format was used to write the expected field
+      MatcherAssert.assertThat(format1.fieldsWritten, equalTo(Set.of("field1")));
+      MatcherAssert.assertThat(format2.fieldsWritten, equalTo(Set.of("field2")));
+
+      // Double-check the vectors were written
       try (IndexReader ireader = DirectoryReader.open(directory)) {
         TopDocs hits1 =
             ireader
                 .leaves()
                 .get(0)
                 .reader()
-                .searchNearestVectors("field", new float[] {1, 2, 3}, 10, 1);
+                .searchNearestVectors("field1", new float[] {1, 2, 3}, 10, 1);
         assertEquals(1, hits1.scoreDocs.length);
+
+        TopDocs hits2 =
+            ireader
+                .leaves()
+                .get(0)
+                .reader()
+                .searchNearestVectors("field2", new float[] {1, 2, 3}, 10, 1);
+        assertEquals(1, hits2.scoreDocs.length);
       }
     }
   }
+
+  public void testMergeUsesNewFormat() throws IOException {
+    try (Directory directory = newDirectory()) {
+      IndexWriterConfig initialConfig = newIndexWriterConfig(new MockAnalyzer(random()));
+      try (IndexWriter iw = new IndexWriter(directory, initialConfig)) {
+        for (int i = 0; i < 3; i++) {
+          Document doc = new Document();
+          doc.add(newTextField("id", "1", Field.Store.YES));
+          doc.add(new VectorField("field", new float[] {1, 2, 3}));
+          iw.addDocument(doc);
+          iw.commit();
+        }
+      }
+
+      IndexWriterConfig newConfig = newIndexWriterConfig(new MockAnalyzer(random()));
+      WriteRecordingVectorFormat newFormat =
+          new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
+      newConfig.setCodec(
+          new AssertingCodec() {
+            @Override
+            public VectorFormat getVectorFormatForField(String field) {
+              return newFormat;
+            }
+          });
+
+      try (IndexWriter iw = new IndexWriter(directory, newConfig)) {
+        iw.forceMerge(1);
+      }
+
+      // Check that the new format was used while merging
+      MatcherAssert.assertThat(newFormat.fieldsWritten, equalTo(Set.of("field")));
+    }
+  }
+
+  private static class WriteRecordingVectorFormat extends VectorFormat {
+    private final VectorFormat delegate;
+    private final Set<String> fieldsWritten;
+
+    public WriteRecordingVectorFormat(VectorFormat delegate) {
+      super(delegate.getName());
+      this.delegate = delegate;
+      this.fieldsWritten = new HashSet<>();
+    }
+
+    @Override
+    public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
+      VectorWriter writer = delegate.fieldsWriter(state);
+      return new VectorWriter() {
+        @Override
+        public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
+          fieldsWritten.add(fieldInfo.name);
+          writer.writeField(fieldInfo, values);
        }

        @Override
        public void finish() throws IOException {
          writer.finish();
        }

        @Override
        public void close() throws IOException {
          writer.close();
        }
      };
    }

    @Override
    public VectorReader fieldsReader(SegmentReadState state) throws IOException {
      return delegate.fieldsReader(state);
    }
  }
 }
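
The WriteRecordingVectorFormat helper introduced above is a delegating spy: it
forwards all reads and writes to the wrapped format while recording the names
of the fields written through it. That lets the tests assert directly on which
format handled each field, and in testMergeUsesNewFormat, that the newly
configured format performed the merge, without inspecting index files.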

TestKnnGraph.java

@@ -147,7 +147,6 @@ public class TestKnnGraph extends LuceneTestCase {
    * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
    * and merging.
    */
-  @AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/LUCENE-9905")
   public void testMergeProducesSameGraph() throws Exception {
     long seed = random().nextLong();
     int numDoc = atLeast(100);