mirror of https://github.com/apache/lucene.git
LUCENE-9905: Make sure to use configured vector format when merging (#176)
Before when creating a VectorWriter for merging, we would always load the default implementation. So if the format was configured with parameters, they were ignored. This issue was caught by `TestKnnGraph#testMergeProducesSameGraph`.
This commit is contained in:
parent
1ec2a715a2
commit
e9339253f5
|
@ -114,14 +114,7 @@ public abstract class PerFieldVectorFormat extends VectorFormat {
|
|||
}
|
||||
|
||||
private VectorWriter getInstance(FieldInfo field) throws IOException {
|
||||
VectorFormat format = null;
|
||||
String fieldFormatName = field.getAttribute(PER_FIELD_FORMAT_KEY);
|
||||
if (fieldFormatName != null) {
|
||||
format = VectorFormat.forName(fieldFormatName);
|
||||
}
|
||||
if (format == null) {
|
||||
format = getVectorFormatForField(field.name);
|
||||
}
|
||||
VectorFormat format = getVectorFormatForField(field.name);
|
||||
if (format == null) {
|
||||
throw new IllegalStateException(
|
||||
"invalid null VectorFormat for field=\"" + field.name + "\"");
|
||||
|
|
|
@ -16,26 +16,36 @@
|
|||
*/
|
||||
package org.apache.lucene.codecs.perfield;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Random;
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import java.util.Set;
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.VectorFormat;
|
||||
import org.apache.lucene.codecs.VectorReader;
|
||||
import org.apache.lucene.codecs.VectorWriter;
|
||||
import org.apache.lucene.codecs.asserting.AssertingCodec;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.VectorField;
|
||||
import org.apache.lucene.index.BaseVectorFormatTestCase;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.RandomCodec;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.TestUtil;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
/** Basic tests of PerFieldDocValuesFormat */
|
||||
public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
|
||||
|
@ -53,21 +63,21 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
|
|||
}
|
||||
|
||||
public void testTwoFieldsTwoFormats() throws IOException {
|
||||
Analyzer analyzer = new MockAnalyzer(random());
|
||||
|
||||
try (Directory directory = newDirectory()) {
|
||||
// we don't use RandomIndexWriter because it might add more values than we expect !!!!1
|
||||
IndexWriterConfig iwc = newIndexWriterConfig(analyzer);
|
||||
VectorFormat defaultFormat = TestUtil.getDefaultVectorFormat();
|
||||
VectorFormat emptyFormat = VectorFormat.EMPTY;
|
||||
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||
WriteRecordingVectorFormat format1 =
|
||||
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
|
||||
WriteRecordingVectorFormat format2 =
|
||||
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
|
||||
iwc.setCodec(
|
||||
new AssertingCodec() {
|
||||
@Override
|
||||
public VectorFormat getVectorFormatForField(String field) {
|
||||
if ("empty".equals(field)) {
|
||||
return emptyFormat;
|
||||
if ("field1".equals(field)) {
|
||||
return format1;
|
||||
} else {
|
||||
return defaultFormat;
|
||||
return format2;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -75,32 +85,107 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
|
|||
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
|
||||
Document doc = new Document();
|
||||
doc.add(newTextField("id", "1", Field.Store.YES));
|
||||
doc.add(new VectorField("field", new float[] {1, 2, 3}));
|
||||
doc.add(new VectorField("field1", new float[] {1, 2, 3}));
|
||||
iwriter.addDocument(doc);
|
||||
iwriter.commit();
|
||||
|
||||
// Check that we use the empty vector format, which doesn't support writes
|
||||
doc.clear();
|
||||
doc.add(newTextField("id", "2", Field.Store.YES));
|
||||
doc.add(new VectorField("empty", new float[] {4, 5, 6}));
|
||||
expectThrows(
|
||||
RuntimeException.class,
|
||||
() -> {
|
||||
doc.add(new VectorField("field2", new float[] {4, 5, 6}));
|
||||
iwriter.addDocument(doc);
|
||||
iwriter.commit();
|
||||
});
|
||||
}
|
||||
|
||||
// Now search for the field that was successfully indexed
|
||||
// Check that each format was used to write the expected field
|
||||
MatcherAssert.assertThat(format1.fieldsWritten, equalTo(Set.of("field1")));
|
||||
MatcherAssert.assertThat(format2.fieldsWritten, equalTo(Set.of("field2")));
|
||||
|
||||
// Double-check the vectors were written
|
||||
try (IndexReader ireader = DirectoryReader.open(directory)) {
|
||||
TopDocs hits1 =
|
||||
ireader
|
||||
.leaves()
|
||||
.get(0)
|
||||
.reader()
|
||||
.searchNearestVectors("field", new float[] {1, 2, 3}, 10, 1);
|
||||
.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, 1);
|
||||
assertEquals(1, hits1.scoreDocs.length);
|
||||
TopDocs hits2 =
|
||||
ireader
|
||||
.leaves()
|
||||
.get(0)
|
||||
.reader()
|
||||
.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, 1);
|
||||
assertEquals(1, hits2.scoreDocs.length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testMergeUsesNewFormat() throws IOException {
|
||||
try (Directory directory = newDirectory()) {
|
||||
IndexWriterConfig initialConfig = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||
try (IndexWriter iw = new IndexWriter(directory, initialConfig)) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Document doc = new Document();
|
||||
doc.add(newTextField("id", "1", Field.Store.YES));
|
||||
doc.add(new VectorField("field", new float[] {1, 2, 3}));
|
||||
iw.addDocument(doc);
|
||||
iw.commit();
|
||||
}
|
||||
}
|
||||
|
||||
IndexWriterConfig newConfig = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||
WriteRecordingVectorFormat newFormat =
|
||||
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
|
||||
newConfig.setCodec(
|
||||
new AssertingCodec() {
|
||||
@Override
|
||||
public VectorFormat getVectorFormatForField(String field) {
|
||||
return newFormat;
|
||||
}
|
||||
});
|
||||
|
||||
try (IndexWriter iw = new IndexWriter(directory, newConfig)) {
|
||||
iw.forceMerge(1);
|
||||
}
|
||||
|
||||
// Check that the new format was used while merging
|
||||
MatcherAssert.assertThat(newFormat.fieldsWritten, equalTo(Set.of("field")));
|
||||
}
|
||||
}
|
||||
|
||||
private static class WriteRecordingVectorFormat extends VectorFormat {
|
||||
private final VectorFormat delegate;
|
||||
private final Set<String> fieldsWritten;
|
||||
|
||||
public WriteRecordingVectorFormat(VectorFormat delegate) {
|
||||
super(delegate.getName());
|
||||
this.delegate = delegate;
|
||||
this.fieldsWritten = new HashSet<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||
VectorWriter writer = delegate.fieldsWriter(state);
|
||||
return new VectorWriter() {
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
|
||||
fieldsWritten.add(fieldInfo.name);
|
||||
writer.writeField(fieldInfo, values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() throws IOException {
|
||||
writer.finish();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
writer.close();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorReader fieldsReader(SegmentReadState state) throws IOException {
|
||||
return delegate.fieldsReader(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -147,7 +147,6 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
* Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
|
||||
* and merging.
|
||||
*/
|
||||
@AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/LUCENE-9905")
|
||||
public void testMergeProducesSameGraph() throws Exception {
|
||||
long seed = random().nextLong();
|
||||
int numDoc = atLeast(100);
|
||||
|
|
Loading…
Reference in New Issue