mirror of
https://github.com/apache/lucene.git
synced 2025-03-06 16:29:30 +00:00
Test KNN query works seamlessly regardless of underlying format (#13225)
* Test Knn query on different vector formats
This commit is contained in:
parent
6cba773318
commit
42a5ff6ace
@ -25,6 +25,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import java.io.IOException;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.IntPoint;
|
||||
@ -42,6 +43,8 @@ import org.apache.lucene.index.StoredFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.tests.codecs.asserting.AssertingCodec;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
@ -949,4 +952,58 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
||||
return 31 * classHash() + docs.hashCode();
|
||||
}
|
||||
}
|
||||
|
||||
public void testSameFieldDifferentFormats() throws IOException {
|
||||
try (Directory directory = newDirectory()) {
|
||||
MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
|
||||
IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
|
||||
KnnVectorsFormat format1 = randomVectorFormat();
|
||||
KnnVectorsFormat format2 = randomVectorFormat();
|
||||
iwc.setCodec(
|
||||
new AssertingCodec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return format1;
|
||||
}
|
||||
});
|
||||
|
||||
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
|
||||
Document doc = new Document();
|
||||
doc.add(getKnnVectorField("field1", new float[] {1, 1, 1}));
|
||||
iwriter.addDocument(doc);
|
||||
|
||||
doc.clear();
|
||||
doc.add(getKnnVectorField("field1", new float[] {1, 2, 3}));
|
||||
iwriter.addDocument(doc);
|
||||
iwriter.commit();
|
||||
}
|
||||
|
||||
iwc = newIndexWriterConfig(mockAnalyzer);
|
||||
iwc.setCodec(
|
||||
new AssertingCodec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return format2;
|
||||
}
|
||||
});
|
||||
|
||||
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
|
||||
Document doc = new Document();
|
||||
doc.clear();
|
||||
doc.add(getKnnVectorField("field1", new float[] {1, 1, 2}));
|
||||
iwriter.addDocument(doc);
|
||||
|
||||
doc.clear();
|
||||
doc.add(getKnnVectorField("field1", new float[] {4, 5, 6}));
|
||||
iwriter.addDocument(doc);
|
||||
iwriter.commit();
|
||||
}
|
||||
|
||||
try (IndexReader ireader = DirectoryReader.open(directory)) {
|
||||
AbstractKnnVectorQuery vectorQuery = getKnnVectorQuery("field1", new float[] {1, 2, 3}, 10);
|
||||
TopDocs hits1 = new IndexSearcher(ireader).search(vectorQuery, 4);
|
||||
assertEquals(4, hits1.scoreDocs.length);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -18,6 +18,8 @@
|
||||
/** Lucene test framework. */
|
||||
@SuppressWarnings({"module", "requires-automatic", "requires-transitive-automatic"})
|
||||
module org.apache.lucene.test_framework {
|
||||
uses org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
|
||||
requires org.apache.lucene.core;
|
||||
requires org.apache.lucene.codecs;
|
||||
requires transitive junit;
|
||||
|
@ -89,6 +89,7 @@ import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.ServiceLoader;
|
||||
import java.util.Set;
|
||||
import java.util.TimeZone;
|
||||
import java.util.TreeSet;
|
||||
@ -100,6 +101,7 @@ import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.regex.Pattern;
|
||||
import junit.framework.AssertionFailedError;
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.Field.Store;
|
||||
@ -3213,4 +3215,13 @@ public abstract class LuceneTestCase extends Assert {
|
||||
|
||||
return it;
|
||||
}
|
||||
|
||||
protected KnnVectorsFormat randomVectorFormat() {
|
||||
ServiceLoader<KnnVectorsFormat> formats = java.util.ServiceLoader.load(KnnVectorsFormat.class);
|
||||
List<KnnVectorsFormat> availableFormats = new ArrayList<>();
|
||||
for (KnnVectorsFormat f : formats) {
|
||||
availableFormats.add(f);
|
||||
}
|
||||
return RandomPicks.randomFrom(random(), availableFormats);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user