mirror of https://github.com/apache/lucene.git
LUCENE-9905: Make sure to use configured vector format when merging (#176)
Before when creating a VectorWriter for merging, we would always load the default implementation. So if the format was configured with parameters, they were ignored. This issue was caught by `TestKnnGraph#testMergeProducesSameGraph`.
This commit is contained in:
parent
1ec2a715a2
commit
e9339253f5
|
@ -114,14 +114,7 @@ public abstract class PerFieldVectorFormat extends VectorFormat {
|
||||||
}
|
}
|
||||||
|
|
||||||
private VectorWriter getInstance(FieldInfo field) throws IOException {
|
private VectorWriter getInstance(FieldInfo field) throws IOException {
|
||||||
VectorFormat format = null;
|
VectorFormat format = getVectorFormatForField(field.name);
|
||||||
String fieldFormatName = field.getAttribute(PER_FIELD_FORMAT_KEY);
|
|
||||||
if (fieldFormatName != null) {
|
|
||||||
format = VectorFormat.forName(fieldFormatName);
|
|
||||||
}
|
|
||||||
if (format == null) {
|
|
||||||
format = getVectorFormatForField(field.name);
|
|
||||||
}
|
|
||||||
if (format == null) {
|
if (format == null) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"invalid null VectorFormat for field=\"" + field.name + "\"");
|
"invalid null VectorFormat for field=\"" + field.name + "\"");
|
||||||
|
|
|
@ -16,26 +16,36 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.codecs.perfield;
|
package org.apache.lucene.codecs.perfield;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import java.util.Set;
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
import org.apache.lucene.codecs.Codec;
|
import org.apache.lucene.codecs.Codec;
|
||||||
import org.apache.lucene.codecs.VectorFormat;
|
import org.apache.lucene.codecs.VectorFormat;
|
||||||
|
import org.apache.lucene.codecs.VectorReader;
|
||||||
|
import org.apache.lucene.codecs.VectorWriter;
|
||||||
import org.apache.lucene.codecs.asserting.AssertingCodec;
|
import org.apache.lucene.codecs.asserting.AssertingCodec;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.VectorField;
|
import org.apache.lucene.document.VectorField;
|
||||||
import org.apache.lucene.index.BaseVectorFormatTestCase;
|
import org.apache.lucene.index.BaseVectorFormatTestCase;
|
||||||
import org.apache.lucene.index.DirectoryReader;
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.IndexWriter;
|
import org.apache.lucene.index.IndexWriter;
|
||||||
import org.apache.lucene.index.IndexWriterConfig;
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
import org.apache.lucene.index.RandomCodec;
|
import org.apache.lucene.index.RandomCodec;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.util.TestUtil;
|
import org.apache.lucene.util.TestUtil;
|
||||||
|
import org.hamcrest.MatcherAssert;
|
||||||
|
|
||||||
/** Basic tests of PerFieldDocValuesFormat */
|
/** Basic tests of PerFieldDocValuesFormat */
|
||||||
public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
|
public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
|
||||||
|
@ -53,21 +63,21 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testTwoFieldsTwoFormats() throws IOException {
|
public void testTwoFieldsTwoFormats() throws IOException {
|
||||||
Analyzer analyzer = new MockAnalyzer(random());
|
|
||||||
|
|
||||||
try (Directory directory = newDirectory()) {
|
try (Directory directory = newDirectory()) {
|
||||||
// we don't use RandomIndexWriter because it might add more values than we expect !!!!1
|
// we don't use RandomIndexWriter because it might add more values than we expect !!!!1
|
||||||
IndexWriterConfig iwc = newIndexWriterConfig(analyzer);
|
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||||
VectorFormat defaultFormat = TestUtil.getDefaultVectorFormat();
|
WriteRecordingVectorFormat format1 =
|
||||||
VectorFormat emptyFormat = VectorFormat.EMPTY;
|
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
|
||||||
|
WriteRecordingVectorFormat format2 =
|
||||||
|
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
|
||||||
iwc.setCodec(
|
iwc.setCodec(
|
||||||
new AssertingCodec() {
|
new AssertingCodec() {
|
||||||
@Override
|
@Override
|
||||||
public VectorFormat getVectorFormatForField(String field) {
|
public VectorFormat getVectorFormatForField(String field) {
|
||||||
if ("empty".equals(field)) {
|
if ("field1".equals(field)) {
|
||||||
return emptyFormat;
|
return format1;
|
||||||
} else {
|
} else {
|
||||||
return defaultFormat;
|
return format2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -75,32 +85,107 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
|
||||||
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
|
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
doc.add(newTextField("id", "1", Field.Store.YES));
|
doc.add(newTextField("id", "1", Field.Store.YES));
|
||||||
doc.add(new VectorField("field", new float[] {1, 2, 3}));
|
doc.add(new VectorField("field1", new float[] {1, 2, 3}));
|
||||||
iwriter.addDocument(doc);
|
iwriter.addDocument(doc);
|
||||||
iwriter.commit();
|
|
||||||
|
|
||||||
// Check that we use the empty vector format, which doesn't support writes
|
|
||||||
doc.clear();
|
doc.clear();
|
||||||
doc.add(newTextField("id", "2", Field.Store.YES));
|
doc.add(newTextField("id", "2", Field.Store.YES));
|
||||||
doc.add(new VectorField("empty", new float[] {4, 5, 6}));
|
doc.add(new VectorField("field2", new float[] {4, 5, 6}));
|
||||||
expectThrows(
|
|
||||||
RuntimeException.class,
|
|
||||||
() -> {
|
|
||||||
iwriter.addDocument(doc);
|
iwriter.addDocument(doc);
|
||||||
iwriter.commit();
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now search for the field that was successfully indexed
|
// Check that each format was used to write the expected field
|
||||||
|
MatcherAssert.assertThat(format1.fieldsWritten, equalTo(Set.of("field1")));
|
||||||
|
MatcherAssert.assertThat(format2.fieldsWritten, equalTo(Set.of("field2")));
|
||||||
|
|
||||||
|
// Double-check the vectors were written
|
||||||
try (IndexReader ireader = DirectoryReader.open(directory)) {
|
try (IndexReader ireader = DirectoryReader.open(directory)) {
|
||||||
TopDocs hits1 =
|
TopDocs hits1 =
|
||||||
ireader
|
ireader
|
||||||
.leaves()
|
.leaves()
|
||||||
.get(0)
|
.get(0)
|
||||||
.reader()
|
.reader()
|
||||||
.searchNearestVectors("field", new float[] {1, 2, 3}, 10, 1);
|
.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, 1);
|
||||||
assertEquals(1, hits1.scoreDocs.length);
|
assertEquals(1, hits1.scoreDocs.length);
|
||||||
|
TopDocs hits2 =
|
||||||
|
ireader
|
||||||
|
.leaves()
|
||||||
|
.get(0)
|
||||||
|
.reader()
|
||||||
|
.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, 1);
|
||||||
|
assertEquals(1, hits2.scoreDocs.length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testMergeUsesNewFormat() throws IOException {
|
||||||
|
try (Directory directory = newDirectory()) {
|
||||||
|
IndexWriterConfig initialConfig = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||||
|
try (IndexWriter iw = new IndexWriter(directory, initialConfig)) {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
Document doc = new Document();
|
||||||
|
doc.add(newTextField("id", "1", Field.Store.YES));
|
||||||
|
doc.add(new VectorField("field", new float[] {1, 2, 3}));
|
||||||
|
iw.addDocument(doc);
|
||||||
|
iw.commit();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IndexWriterConfig newConfig = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||||
|
WriteRecordingVectorFormat newFormat =
|
||||||
|
new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat());
|
||||||
|
newConfig.setCodec(
|
||||||
|
new AssertingCodec() {
|
||||||
|
@Override
|
||||||
|
public VectorFormat getVectorFormatForField(String field) {
|
||||||
|
return newFormat;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
try (IndexWriter iw = new IndexWriter(directory, newConfig)) {
|
||||||
|
iw.forceMerge(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the new format was used while merging
|
||||||
|
MatcherAssert.assertThat(newFormat.fieldsWritten, equalTo(Set.of("field")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class WriteRecordingVectorFormat extends VectorFormat {
|
||||||
|
private final VectorFormat delegate;
|
||||||
|
private final Set<String> fieldsWritten;
|
||||||
|
|
||||||
|
public WriteRecordingVectorFormat(VectorFormat delegate) {
|
||||||
|
super(delegate.getName());
|
||||||
|
this.delegate = delegate;
|
||||||
|
this.fieldsWritten = new HashSet<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||||
|
VectorWriter writer = delegate.fieldsWriter(state);
|
||||||
|
return new VectorWriter() {
|
||||||
|
@Override
|
||||||
|
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
|
||||||
|
fieldsWritten.add(fieldInfo.name);
|
||||||
|
writer.writeField(fieldInfo, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void finish() throws IOException {
|
||||||
|
writer.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws IOException {
|
||||||
|
writer.close();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public VectorReader fieldsReader(SegmentReadState state) throws IOException {
|
||||||
|
return delegate.fieldsReader(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -147,7 +147,6 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
* Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
|
* Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
|
||||||
* and merging.
|
* and merging.
|
||||||
*/
|
*/
|
||||||
@AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/LUCENE-9905")
|
|
||||||
public void testMergeProducesSameGraph() throws Exception {
|
public void testMergeProducesSameGraph() throws Exception {
|
||||||
long seed = random().nextLong();
|
long seed = random().nextLong();
|
||||||
int numDoc = atLeast(100);
|
int numDoc = atLeast(100);
|
||||||
|
|
Loading…
Reference in New Issue