From 42a5ff6aceb66ac8c66e51319befbdaed5b273a6 Mon Sep 17 00:00:00 2001
From: Tommaso Teofili
Date: Fri, 29 Mar 2024 08:28:33 +0100
Subject: [PATCH] Test KNN query works seamlessly regardless of underlying format (#13225)

* Test Knn query on different vector formats
---
Notes (below the three-dash separator, so not part of the commit message):
  * BaseKnnVectorQueryTestCase#testSameFieldDifferentFormats writes "field1" in two
    IndexWriter sessions, each using an AssertingCodec that returns a randomly chosen
    KnnVectorsFormat, then expects a KNN query to return all 4 indexed documents.
  * module-info.java gains a `uses` clause for KnnVectorsFormat.
  * LuceneTestCase#randomVectorFormat picks one of the formats found by ServiceLoader.

 .../search/BaseKnnVectorQueryTestCase.java   | 57 +++++++++++++++++++
 .../test-framework/src/java/module-info.java |  2 +
 .../lucene/tests/util/LuceneTestCase.java    | 11 ++++
 3 files changed, 70 insertions(+)

diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index 5591a9059c9..4ac4935c39c 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -25,6 +25,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import java.io.IOException;
 import java.util.HashSet;
 import java.util.Set;
+import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.IntPoint;
@@ -42,6 +43,8 @@ import org.apache.lucene.index.StoredFields;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.analysis.MockAnalyzer;
+import org.apache.lucene.tests.codecs.asserting.AssertingCodec;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -949,4 +952,58 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       return 31 * classHash() + docs.hashCode();
     }
   }
+
+  public void testSameFieldDifferentFormats() throws IOException {
+    try (Directory directory = newDirectory()) {
+      MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
+      IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
+      KnnVectorsFormat format1 = randomVectorFormat();
+      KnnVectorsFormat format2 = randomVectorFormat();
+      iwc.setCodec(
+          new AssertingCodec() {
+            @Override
+            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+              return format1;
+            }
+          });
+
+      try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
+        Document doc = new Document();
+        doc.add(getKnnVectorField("field1", new float[] {1, 1, 1}));
+        iwriter.addDocument(doc);
+
+        doc.clear();
+        doc.add(getKnnVectorField("field1", new float[] {1, 2, 3}));
+        iwriter.addDocument(doc);
+        iwriter.commit();
+      }
+
+      iwc = newIndexWriterConfig(mockAnalyzer);
+      iwc.setCodec(
+          new AssertingCodec() {
+            @Override
+            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+              return format2;
+            }
+          });
+
+      try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
+        Document doc = new Document();
+        doc.clear();
+        doc.add(getKnnVectorField("field1", new float[] {1, 1, 2}));
+        iwriter.addDocument(doc);
+
+        doc.clear();
+        doc.add(getKnnVectorField("field1", new float[] {4, 5, 6}));
+        iwriter.addDocument(doc);
+        iwriter.commit();
+      }
+
+      try (IndexReader ireader = DirectoryReader.open(directory)) {
+        AbstractKnnVectorQuery vectorQuery = getKnnVectorQuery("field1", new float[] {1, 2, 3}, 10);
+        TopDocs hits1 = new IndexSearcher(ireader).search(vectorQuery, 4);
+        assertEquals(4, hits1.scoreDocs.length);
+      }
+    }
+  }
 }
diff --git a/lucene/test-framework/src/java/module-info.java b/lucene/test-framework/src/java/module-info.java
index f366d1f52b7..2af42e6b12d 100644
--- a/lucene/test-framework/src/java/module-info.java
+++ b/lucene/test-framework/src/java/module-info.java
@@ -18,6 +18,8 @@
 /** Lucene test framework. */
 @SuppressWarnings({"module", "requires-automatic", "requires-transitive-automatic"})
 module org.apache.lucene.test_framework {
+  uses org.apache.lucene.codecs.KnnVectorsFormat;
+
   requires org.apache.lucene.core;
   requires org.apache.lucene.codecs;
   requires transitive junit;
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java
index f7609392252..37fc92185d0 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java
@@ -89,6 +89,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Random;
+import java.util.ServiceLoader;
 import java.util.Set;
 import java.util.TimeZone;
 import java.util.TreeSet;
@@ -100,6 +101,7 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.regex.Pattern;
 import junit.framework.AssertionFailedError;
 import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.Field.Store;
@@ -3213,4 +3215,13 @@ public abstract class LuceneTestCase extends Assert {
 
     return it;
   }
+
+  protected KnnVectorsFormat randomVectorFormat() {
+    ServiceLoader<KnnVectorsFormat> formats = java.util.ServiceLoader.load(KnnVectorsFormat.class);
+    List<KnnVectorsFormat> availableFormats = new ArrayList<>();
+    for (KnnVectorsFormat f : formats) {
+      availableFormats.add(f);
+    }
+    return RandomPicks.randomFrom(random(), availableFormats);
+  }
 }