mirror of
https://github.com/apache/lucene.git
synced 2025-03-06 16:29:30 +00:00
Test KNN query works seamlessly regardless of underlying format (#13225)
* Test Knn query on different vector formats
This commit is contained in:
parent
6cba773318
commit
42a5ff6ace
@ -25,6 +25,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.IntPoint;
|
import org.apache.lucene.document.IntPoint;
|
||||||
@ -42,6 +43,8 @@ import org.apache.lucene.index.StoredFields;
|
|||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.tests.codecs.asserting.AssertingCodec;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
@ -949,4 +952,58 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||||||
return 31 * classHash() + docs.hashCode();
|
return 31 * classHash() + docs.hashCode();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testSameFieldDifferentFormats() throws IOException {
|
||||||
|
try (Directory directory = newDirectory()) {
|
||||||
|
MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
|
||||||
|
IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
|
||||||
|
KnnVectorsFormat format1 = randomVectorFormat();
|
||||||
|
KnnVectorsFormat format2 = randomVectorFormat();
|
||||||
|
iwc.setCodec(
|
||||||
|
new AssertingCodec() {
|
||||||
|
@Override
|
||||||
|
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||||
|
return format1;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
|
||||||
|
Document doc = new Document();
|
||||||
|
doc.add(getKnnVectorField("field1", new float[] {1, 1, 1}));
|
||||||
|
iwriter.addDocument(doc);
|
||||||
|
|
||||||
|
doc.clear();
|
||||||
|
doc.add(getKnnVectorField("field1", new float[] {1, 2, 3}));
|
||||||
|
iwriter.addDocument(doc);
|
||||||
|
iwriter.commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
iwc = newIndexWriterConfig(mockAnalyzer);
|
||||||
|
iwc.setCodec(
|
||||||
|
new AssertingCodec() {
|
||||||
|
@Override
|
||||||
|
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||||
|
return format2;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
|
||||||
|
Document doc = new Document();
|
||||||
|
doc.clear();
|
||||||
|
doc.add(getKnnVectorField("field1", new float[] {1, 1, 2}));
|
||||||
|
iwriter.addDocument(doc);
|
||||||
|
|
||||||
|
doc.clear();
|
||||||
|
doc.add(getKnnVectorField("field1", new float[] {4, 5, 6}));
|
||||||
|
iwriter.addDocument(doc);
|
||||||
|
iwriter.commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
try (IndexReader ireader = DirectoryReader.open(directory)) {
|
||||||
|
AbstractKnnVectorQuery vectorQuery = getKnnVectorQuery("field1", new float[] {1, 2, 3}, 10);
|
||||||
|
TopDocs hits1 = new IndexSearcher(ireader).search(vectorQuery, 4);
|
||||||
|
assertEquals(4, hits1.scoreDocs.length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,8 @@
|
|||||||
/** Lucene test framework. */
|
/** Lucene test framework. */
|
||||||
@SuppressWarnings({"module", "requires-automatic", "requires-transitive-automatic"})
|
@SuppressWarnings({"module", "requires-automatic", "requires-transitive-automatic"})
|
||||||
module org.apache.lucene.test_framework {
|
module org.apache.lucene.test_framework {
|
||||||
|
uses org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
|
|
||||||
requires org.apache.lucene.core;
|
requires org.apache.lucene.core;
|
||||||
requires org.apache.lucene.codecs;
|
requires org.apache.lucene.codecs;
|
||||||
requires transitive junit;
|
requires transitive junit;
|
||||||
|
@ -89,6 +89,7 @@ import java.util.List;
|
|||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
import java.util.ServiceLoader;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.TimeZone;
|
import java.util.TimeZone;
|
||||||
import java.util.TreeSet;
|
import java.util.TreeSet;
|
||||||
@ -100,6 +101,7 @@ import java.util.concurrent.atomic.AtomicReference;
|
|||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
import junit.framework.AssertionFailedError;
|
import junit.framework.AssertionFailedError;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.Field.Store;
|
import org.apache.lucene.document.Field.Store;
|
||||||
@ -3213,4 +3215,13 @@ public abstract class LuceneTestCase extends Assert {
|
|||||||
|
|
||||||
return it;
|
return it;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected KnnVectorsFormat randomVectorFormat() {
|
||||||
|
ServiceLoader<KnnVectorsFormat> formats = java.util.ServiceLoader.load(KnnVectorsFormat.class);
|
||||||
|
List<KnnVectorsFormat> availableFormats = new ArrayList<>();
|
||||||
|
for (KnnVectorsFormat f : formats) {
|
||||||
|
availableFormats.add(f);
|
||||||
|
}
|
||||||
|
return RandomPicks.randomFrom(random(), availableFormats);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user