From e9339253f5ebcd88282297bdadcbe1705e15f91b Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Tue, 8 Jun 2021 08:07:35 -0700 Subject: [PATCH] LUCENE-9905: Make sure to use configured vector format when merging (#176) Before when creating a VectorWriter for merging, we would always load the default implementation. So if the format was configured with parameters, they were ignored. This issue was caught by `TestKnnGraph#testMergeProducesSameGraph`. --- .../codecs/perfield/PerFieldVectorFormat.java | 9 +- .../perfield/TestPerFieldVectorFormat.java | 127 +++++++++++++++--- .../org/apache/lucene/index/TestKnnGraph.java | 1 - 3 files changed, 107 insertions(+), 30 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldVectorFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldVectorFormat.java index e8347224ec6..f020c5ebd30 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldVectorFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldVectorFormat.java @@ -114,14 +114,7 @@ public abstract class PerFieldVectorFormat extends VectorFormat { } private VectorWriter getInstance(FieldInfo field) throws IOException { - VectorFormat format = null; - String fieldFormatName = field.getAttribute(PER_FIELD_FORMAT_KEY); - if (fieldFormatName != null) { - format = VectorFormat.forName(fieldFormatName); - } - if (format == null) { - format = getVectorFormatForField(field.name); - } + VectorFormat format = getVectorFormatForField(field.name); if (format == null) { throw new IllegalStateException( "invalid null VectorFormat for field=\"" + field.name + "\""); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldVectorFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldVectorFormat.java index ab0f68ff53f..63210b04227 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldVectorFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldVectorFormat.java @@ -16,26 +16,36 @@ */ package org.apache.lucene.codecs.perfield; +import static org.hamcrest.Matchers.equalTo; + import java.io.IOException; import java.util.Collections; +import java.util.HashSet; import java.util.Random; -import org.apache.lucene.analysis.Analyzer; +import java.util.Set; import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.VectorFormat; +import org.apache.lucene.codecs.VectorReader; +import org.apache.lucene.codecs.VectorWriter; import org.apache.lucene.codecs.asserting.AssertingCodec; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.VectorField; import org.apache.lucene.index.BaseVectorFormatTestCase; import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.RandomCodec; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.util.TestUtil; +import org.hamcrest.MatcherAssert; /** Basic tests of PerFieldDocValuesFormat */ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase { @@ -53,21 +63,21 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase { } public void testTwoFieldsTwoFormats() throws IOException { - Analyzer analyzer = new MockAnalyzer(random()); - try (Directory directory = newDirectory()) { // we don't use RandomIndexWriter because it might add more values than we expect !!!!1 - IndexWriterConfig iwc = newIndexWriterConfig(analyzer); - VectorFormat defaultFormat = TestUtil.getDefaultVectorFormat(); - VectorFormat emptyFormat = VectorFormat.EMPTY; + IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random())); + WriteRecordingVectorFormat format1 = + new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat()); + WriteRecordingVectorFormat format2 = + new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat()); iwc.setCodec( new AssertingCodec() { @Override public VectorFormat getVectorFormatForField(String field) { - if ("empty".equals(field)) { - return emptyFormat; + if ("field1".equals(field)) { + return format1; } else { - return defaultFormat; + return format2; } } }); @@ -75,32 +85,107 @@ public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase { try (IndexWriter iwriter = new IndexWriter(directory, iwc)) { Document doc = new Document(); doc.add(newTextField("id", "1", Field.Store.YES)); - doc.add(new VectorField("field", new float[] {1, 2, 3})); + doc.add(new VectorField("field1", new float[] {1, 2, 3})); iwriter.addDocument(doc); - iwriter.commit(); - // Check that we use the empty vector format, which doesn't support writes doc.clear(); doc.add(newTextField("id", "2", Field.Store.YES)); - doc.add(new VectorField("empty", new float[] {4, 5, 6})); - expectThrows( - RuntimeException.class, - () -> { - iwriter.addDocument(doc); - iwriter.commit(); - }); + doc.add(new VectorField("field2", new float[] {4, 5, 6})); + iwriter.addDocument(doc); } - // Now search for the field that was successfully indexed + // Check that each format was used to write the expected field + MatcherAssert.assertThat(format1.fieldsWritten, equalTo(Set.of("field1"))); + MatcherAssert.assertThat(format2.fieldsWritten, equalTo(Set.of("field2"))); + + // Double-check the vectors were written try (IndexReader ireader = DirectoryReader.open(directory)) { TopDocs hits1 = ireader .leaves() .get(0) .reader() - .searchNearestVectors("field", new float[] {1, 2, 3}, 10, 1); + .searchNearestVectors("field1", new float[] {1, 2, 3}, 10, 1); assertEquals(1, hits1.scoreDocs.length); + TopDocs hits2 = + ireader + .leaves() + .get(0) + .reader() + .searchNearestVectors("field2", new float[] {1, 2, 3}, 10, 1); + assertEquals(1, hits2.scoreDocs.length); } } } + + public void testMergeUsesNewFormat() throws IOException { + try (Directory directory = newDirectory()) { + IndexWriterConfig initialConfig = newIndexWriterConfig(new MockAnalyzer(random())); + try (IndexWriter iw = new IndexWriter(directory, initialConfig)) { + for (int i = 0; i < 3; i++) { + Document doc = new Document(); + doc.add(newTextField("id", "1", Field.Store.YES)); + doc.add(new VectorField("field", new float[] {1, 2, 3})); + iw.addDocument(doc); + iw.commit(); + } + } + + IndexWriterConfig newConfig = newIndexWriterConfig(new MockAnalyzer(random())); + WriteRecordingVectorFormat newFormat = + new WriteRecordingVectorFormat(TestUtil.getDefaultVectorFormat()); + newConfig.setCodec( + new AssertingCodec() { + @Override + public VectorFormat getVectorFormatForField(String field) { + return newFormat; + } + }); + + try (IndexWriter iw = new IndexWriter(directory, newConfig)) { + iw.forceMerge(1); + } + + // Check that the new format was used while merging + MatcherAssert.assertThat(newFormat.fieldsWritten, equalTo(Set.of("field"))); + } + } + + private static class WriteRecordingVectorFormat extends VectorFormat { + private final VectorFormat delegate; + private final Set fieldsWritten; + + public WriteRecordingVectorFormat(VectorFormat delegate) { + super(delegate.getName()); + this.delegate = delegate; + this.fieldsWritten = new HashSet<>(); + } + + @Override + public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException { + VectorWriter writer = delegate.fieldsWriter(state); + return new VectorWriter() { + @Override + public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException { + fieldsWritten.add(fieldInfo.name); + writer.writeField(fieldInfo, values); + } + + @Override + public void finish() throws IOException { + writer.finish(); + } + + @Override + public void close() throws IOException { + writer.close(); + } + }; + } + + @Override + public VectorReader fieldsReader(SegmentReadState state) throws IOException { + return delegate.fieldsReader(state); + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 06f46af4458..1baa9ae2d86 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -147,7 +147,6 @@ public class TestKnnGraph extends LuceneTestCase { * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments * and merging. */ - @AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/LUCENE-9905") public void testMergeProducesSameGraph() throws Exception { long seed = random().nextLong(); int numDoc = atLeast(100);