Test KNN query works seamlessly regardless of underlying format (#13225)

* Test Knn query on different vector formats
This commit is contained in:
Tommaso Teofili 2024-03-29 08:28:33 +01:00 committed by GitHub
parent 6cba773318
commit 42a5ff6ace
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 0 deletions

View File

@ -25,6 +25,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
@ -42,6 +43,8 @@ import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.codecs.asserting.AssertingCodec;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@ -949,4 +952,58 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
return 31 * classHash() + docs.hashCode();
}
}
public void testSameFieldDifferentFormats() throws IOException {
try (Directory directory = newDirectory()) {
MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
KnnVectorsFormat format1 = randomVectorFormat();
KnnVectorsFormat format2 = randomVectorFormat();
iwc.setCodec(
new AssertingCodec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format1;
}
});
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
Document doc = new Document();
doc.add(getKnnVectorField("field1", new float[] {1, 1, 1}));
iwriter.addDocument(doc);
doc.clear();
doc.add(getKnnVectorField("field1", new float[] {1, 2, 3}));
iwriter.addDocument(doc);
iwriter.commit();
}
iwc = newIndexWriterConfig(mockAnalyzer);
iwc.setCodec(
new AssertingCodec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format2;
}
});
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
Document doc = new Document();
doc.clear();
doc.add(getKnnVectorField("field1", new float[] {1, 1, 2}));
iwriter.addDocument(doc);
doc.clear();
doc.add(getKnnVectorField("field1", new float[] {4, 5, 6}));
iwriter.addDocument(doc);
iwriter.commit();
}
try (IndexReader ireader = DirectoryReader.open(directory)) {
AbstractKnnVectorQuery vectorQuery = getKnnVectorQuery("field1", new float[] {1, 2, 3}, 10);
TopDocs hits1 = new IndexSearcher(ireader).search(vectorQuery, 4);
assertEquals(4, hits1.scoreDocs.length);
}
}
}
}

View File

@ -18,6 +18,8 @@
/** Lucene test framework. */
@SuppressWarnings({"module", "requires-automatic", "requires-transitive-automatic"})
module org.apache.lucene.test_framework {
uses org.apache.lucene.codecs.KnnVectorsFormat;
requires org.apache.lucene.core;
requires org.apache.lucene.codecs;
requires transitive junit;

View File

@ -89,6 +89,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.TimeZone;
import java.util.TreeSet;
@ -100,6 +101,7 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;
import junit.framework.AssertionFailedError;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
@ -3213,4 +3215,13 @@ public abstract class LuceneTestCase extends Assert {
return it;
}
protected KnnVectorsFormat randomVectorFormat() {
ServiceLoader<KnnVectorsFormat> formats = java.util.ServiceLoader.load(KnnVectorsFormat.class);
List<KnnVectorsFormat> availableFormats = new ArrayList<>();
for (KnnVectorsFormat f : formats) {
availableFormats.add(f);
}
return RandomPicks.randomFrom(random(), availableFormats);
}
}