diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 3822f5f1e92..e9246a8b575 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -21,7 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -49,10 +48,6 @@ import org.apache.lucene.util.Bits; *
  • Otherwise run a kNN search subject to the filter *
  • If the kNN search visits too many vectors without completing, stop and run an exact search * - * - *

    When a seed query is provided, this query is executed first to seed the kNN search (subject to - * the same rules provided by the filter). If the seed query fails to identify any documents, it - * falls back on the strategy above. */ abstract class AbstractKnnVectorQuery extends Query { @@ -60,21 +55,15 @@ abstract class AbstractKnnVectorQuery extends Query { protected final String field; protected final int k; - private final Query filter; - private final Query seed; + protected final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { - this(field, k, filter, null); - } - - public AbstractKnnVectorQuery(String field, int k, Query filter, Query seed) { this.field = Objects.requireNonNull(field, "field"); this.k = k; if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); } this.filter = filter; - this.seed = seed; } @Override @@ -94,21 +83,6 @@ abstract class AbstractKnnVectorQuery extends Query { filterWeight = null; } - final Weight seedWeight; - if (seed != null) { - BooleanQuery.Builder booleanSeedQueryBuilder = - new BooleanQuery.Builder() - .add(seed, BooleanClause.Occur.MUST) - .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); - if (filter != null) { - booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); - } - Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); - seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); - } else { - seedWeight = null; - } - TimeLimitingKnnCollectorManager knnCollectorManager = new TimeLimitingKnnCollectorManager( getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout()); @@ -116,7 +90,7 @@ abstract class AbstractKnnVectorQuery extends Query { List leafReaderContexts = reader.leaves(); List> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext context : leafReaderContexts) { - tasks.add(() -> searchLeaf(context, filterWeight, seedWeight, 
knnCollectorManager)); + tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager)); } TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); @@ -131,11 +105,9 @@ abstract class AbstractKnnVectorQuery extends Query { private TopDocs searchLeaf( LeafReaderContext ctx, Weight filterWeight, - Weight seedWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { - TopDocs results = - getLeafResults(ctx, filterWeight, seedWeight, timeLimitingKnnCollectorManager); + TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager); if (ctx.docBase > 0) { for (ScoreDoc scoreDoc : results.scoreDocs) { scoreDoc.doc += ctx.docBase; @@ -147,19 +119,13 @@ abstract class AbstractKnnVectorQuery extends Query { private TopDocs getLeafResults( LeafReaderContext ctx, Weight filterWeight, - Weight seedWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { final LeafReader reader = ctx.reader(); final Bits liveDocs = reader.getLiveDocs(); if (filterWeight == null) { - return approximateSearch( - ctx, - liveDocs, - executeSeedQuery(ctx, seedWeight), - Integer.MAX_VALUE, - timeLimitingKnnCollectorManager); + return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); } Scorer scorer = filterWeight.scorer(ctx); @@ -179,13 +145,7 @@ abstract class AbstractKnnVectorQuery extends Query { // Perform the approximate kNN search // We pass cost + 1 here to account for the edge case when we explore exactly cost vectors - TopDocs results = - approximateSearch( - ctx, - acceptDocs, - executeSeedQuery(ctx, seedWeight), - cost + 1, - timeLimitingKnnCollectorManager); + TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager); if (results.totalHits.relation() == TotalHits.Relation.EQUAL_TO // Return partial results only when timeout is met || (queryTimeout != null && queryTimeout.shouldExit())) { 
@@ -196,49 +156,6 @@ abstract class AbstractKnnVectorQuery extends Query { } } - private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeight) - throws IOException { - if (seedWeight == null) return null; - // Execute the seed query - TopScoreDocCollector seedCollector = - new TopScoreDocCollectorManager( - k /* numHits */, - null /* after */, - Integer.MAX_VALUE /* totalHitsThreshold */, - false /* supportsConcurrency */) - .newCollector(); - final LeafReader leafReader = ctx.reader(); - final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); - if (leafCollector != null) { - try { - BulkScorer scorer = seedWeight.bulkScorer(ctx); - if (scorer != null) { - scorer.score( - leafCollector, - leafReader.getLiveDocs(), - 0 /* min */, - DocIdSetIterator.NO_MORE_DOCS /* max */); - } - leafCollector.finish(); - } catch ( - @SuppressWarnings("unused") - CollectionTerminatedException e) { - } - } - - TopDocs seedTopDocs = seedCollector.topDocs(); - return convertDocIdsToVectorOrdinals(leafReader, new TopDocsDISI(seedTopDocs)); - } - - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. 
- * - * @lucene.internal - * @lucene.experimental - */ - protected abstract DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReader reader, DocIdSetIterator docIds) throws IOException; - private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException { if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { @@ -264,7 +181,6 @@ abstract class AbstractKnnVectorQuery extends Query { protected abstract TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException; @@ -386,15 +302,12 @@ abstract class AbstractKnnVectorQuery extends Query { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o; - return k == that.k - && Objects.equals(field, that.field) - && Objects.equals(filter, that.filter) - && Objects.equals(seed, that.seed); + return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter); } @Override public int hashCode() { - return Objects.hash(field, k, filter, seed); + return Objects.hash(field, k, filter); } /** @@ -419,13 +332,6 @@ abstract class AbstractKnnVectorQuery extends Query { return filter; } - /** - * @return the query that seeds the kNN search. 
- */ - public Query getSeed() { - return seed; - } - /** Caches the results of a KnnVector search: a list of docs and their scores */ static class DocAndScoreQuery extends Query { @@ -585,44 +491,4 @@ abstract class AbstractKnnVectorQuery extends Query { classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); } } - - private static class TopDocsDISI extends DocIdSetIterator { - private final List sortedDocIdList; - private int idx = -1; - - public TopDocsDISI(TopDocs topDocs) { - sortedDocIdList = new ArrayList(topDocs.scoreDocs.length); - for (int i = 0; i < topDocs.scoreDocs.length; i++) { - sortedDocIdList.add(topDocs.scoreDocs[i].doc); - } - Collections.sort(sortedDocIdList); - } - - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); - } - - @Override - public long cost() { - return sortedDocIdList.size(); - } - - @Override - public int docID() { - if (idx == -1) { - return -1; - } else if (idx >= sortedDocIdList.size()) { - return DocIdSetIterator.NO_MORE_DOCS; - } else { - return sortedDocIdList.get(idx); - } - } - - @Override - public int nextDoc() throws IOException { - idx += 1; - return docID(); - } - } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index eb637da0d75..05157ab65cb 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final byte[] target; + protected final byte[] target; /** * Find the k nearest documents to the target vector according to the vectors in the @@ -72,22 +72,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { * @throws IllegalArgumentException if k is less than 1 */ 
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { - this(field, target, k, filter, null); - } - - /** - * Find the k nearest documents to the target vector according to the vectors in the - * given field. target vector. - * - * @param field a field that has been indexed as a {@link KnnByteVectorField}. - * @param target the target of the search - * @param k the number of documents to find - * @param filter a filter applied before the vector search - * @param seed a query that is executed to seed the vector search - * @throws IllegalArgumentException if k is less than 1 - */ - public KnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { - super(field, k, filter, seed); + super(field, k, filter); this.target = Objects.requireNonNull(target, "target"); } @@ -95,14 +80,10 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); - if (seedDocs != null) { - knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs); - } LeafReader reader = context.reader(); ByteVectorValues byteVectorValues = reader.getByteVectorValues(field); if (byteVectorValues == null) { @@ -159,18 +140,4 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { public byte[] getTargetCopy() { return ArrayUtil.copyArray(target); } - - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. - * - *

    This method assumes that all docIds have corresponding ordinals. - * - * @lucene.internal - * @lucene.experimental - */ - @Override - protected DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReader reader, DocIdSetIterator docIds) throws IOException { - return reader.getByteVectorValues(field).convertDocIdsToVectorOrdinals(docIds); - } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index 43338fec72e..a05ca674771 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -86,26 +86,13 @@ public interface KnnCollector { */ TopDocs topDocs(); - /** - * This method returns a {@link DocIdSetIterator} over entry points that seed the KNN search, or - * {@code null} (default) to perform a full KNN search (without seeds). - * - *

    Note that the entry points should represent ordinals, rather than true document IDs. - * - * @return the seed entry points or {@code null}. - * @lucene.experimental - */ - default DocIdSetIterator getSeedEntryPoints() { - return null; - } - /** * KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend * the object with new behaviors. * * @lucene.experimental */ - public abstract static class Decorator implements KnnCollector { + abstract class Decorator implements KnnCollector { private KnnCollector collector; public Decorator(KnnCollector collector) { @@ -151,29 +138,5 @@ public interface KnnCollector { public TopDocs topDocs() { return collector.topDocs(); } - - @Override - public DocIdSetIterator getSeedEntryPoints() { - return collector.getSeedEntryPoints(); - } - } - - /** - * KnnCollector.Seeded is a KnnCollector decorator that replaces the seedEntryPoints. - * - * @lucene.experimental - */ - public static class Seeded extends Decorator { - private DocIdSetIterator seedEntryPoints; - - public Seeded(KnnCollector collector, DocIdSetIterator seedEntryPoints) { - super(collector); - this.seedEntryPoints = seedEntryPoints; - } - - @Override - public DocIdSetIterator getSeedEntryPoints() { - return seedEntryPoints; - } } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 28660756453..c7d6fdb3608 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final float[] target; + protected final float[] target; /** * Find the k nearest documents to the target vector according to the vectors in the @@ -73,22 +73,7 @@ public class KnnFloatVectorQuery 
extends AbstractKnnVectorQuery { * @throws IllegalArgumentException if k is less than 1 */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { - this(field, target, k, filter, null); - } - - /** - * Find the k nearest documents to the target vector according to the vectors in the - * given field. target vector. - * - * @param field a field that has been indexed as a {@link KnnFloatVectorField}. - * @param target the target of the search - * @param k the number of documents to find - * @param filter a filter applied before the vector search - * @param seed a query that is executed to seed the vector search - * @throws IllegalArgumentException if k is less than 1 - */ - public KnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) { - super(field, k, filter, seed); + super(field, k, filter); this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); } @@ -96,14 +81,10 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); - if (seedDocs != null) { - knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs); - } LeafReader reader = context.reader(); FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field); if (floatVectorValues == null) { @@ -162,18 +143,4 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { public float[] getTargetCopy() { return ArrayUtil.copyArray(target); } - - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. - * - *

    This method assumes that all docIds have corresponding ordinals. - * - * @lucene.internal - * @lucene.experimental - */ - @Override - protected DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReader reader, DocIdSetIterator docIds) throws IOException { - return reader.getFloatVectorValues(field).convertDocIdsToVectorOrdinals(docIds); - } } diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java new file mode 100644 index 00000000000..93286948b0a --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn byte vector query that provides a query seed to initiate the vector + * search. 
NOTE: The underlying format is free to ignore the provided seed + * + * @lucene.experimental + */ +public class SeededKnnByteVectorQuery extends KnnByteVectorQuery { + private final Query seed; + private final Weight seedWeight; + + /** + * Construct a new SeededKnnByteVectorQuery instance + * + * @param field knn byte vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnByteVectorQuery rewritten = + new SeededKnnByteVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new SeededKnnCollectorManager( +
super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> leaf.getByteVectorValues(field)); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java new file mode 100644 index 00000000000..f64e0b29bc6 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn float vector query that provides a query seed to initiate the vector + * search.
NOTE: The underlying format is free to ignore the provided seed + * + * @lucene.experimental + */ +public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery { + private final Query seed; + private final Weight seedWeight; + + /** + * Construct a new SeededKnnFloatVectorQuery instance + * + * @param field knn float vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnFloatVectorQuery rewritten = + new SeededKnnFloatVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new 
SeededKnnCollectorManager( + super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> leaf.getFloatVectorValues(field)); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java new file mode 100644 index 00000000000..40eda94c654 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; + +/** Provides entry points for the kNN search */ +public interface EntryPointProvider { + /** Iterator of valid entry points for the kNN search */ + DocIdSetIterator entryPoints(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java new file mode 100644 index 00000000000..ac0c643eac5 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; + +public class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider { + private final DocIdSetIterator entryPoints; + + public SeededKnnCollector(KnnCollector collector, DocIdSetIterator entryPoints) { + super(collector); + this.entryPoints = entryPoints; + } + + @Override + public DocIdSetIterator entryPoints() { + return entryPoints; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java new file mode 100644 index 00000000000..1cea53c4179 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.TopScoreDocCollectorManager; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.IOFunction; + +/** A {@link KnnCollectorManager} that seeds the kNN search with the top hits of a seed query.
*/ +public class SeededKnnCollectorManager implements KnnCollectorManager { + private final KnnCollectorManager delegate; + private final Weight seedWeight; + private final int k; + private final IOFunction vectorValuesSupplier; + + public SeededKnnCollectorManager( + KnnCollectorManager delegate, + Weight seedWeight, + int k, + IOFunction vectorValuesSupplier) { + this.delegate = delegate; + this.seedWeight = seedWeight; + this.k = k; + this.vectorValuesSupplier = vectorValuesSupplier; + } + + @Override + public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException { + // Execute the seed query + TopScoreDocCollector seedCollector = + new TopScoreDocCollectorManager( + k /* numHits */, null /* after */, Integer.MAX_VALUE /* totalHitsThreshold */) + .newCollector(); + final LeafReader leafReader = ctx.reader(); + final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); + if (leafCollector != null) { + try { + BulkScorer scorer = seedWeight.bulkScorer(ctx); + if (scorer != null) { + scorer.score( + leafCollector, + leafReader.getLiveDocs(), + 0 /* min */, + DocIdSetIterator.NO_MORE_DOCS /* max */); + } + leafCollector.finish(); + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { + } + } + + TopDocs seedTopDocs = seedCollector.topDocs(); + KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader); + if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) { + return delegate.newCollector(visitedLimit, ctx); + } + KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator(); + DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs)); + return new SeededKnnCollector(delegate.newCollector(visitedLimit, ctx), seedDocs); + } + + public static class MappedDISI extends DocIdSetIterator { + KnnVectorValues.DocIndexIterator indexedDISI; + DocIdSetIterator sourceDISI; + + public MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, 
DocIdSetIterator sourceDISI) { + this.indexedDISI = indexedDISI; + this.sourceDISI = sourceDISI; + } + + /** + * Advances the source iterator to the first document number that is greater than or equal to + * the provided target and returns the corresponding index. + */ + @Override + public int advance(int target) throws IOException { + int newTarget = sourceDISI.advance(target); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + + @Override + public long cost() { + return this.sourceDISI.cost(); + } + + @Override + public int docID() { + if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return indexedDISI.index(); + } + + /** Advances to the next document in the source iterator and returns the corresponding index. */ + @Override + public int nextDoc() throws IOException { + int newTarget = sourceDISI.nextDoc(); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + } + + private static class TopDocsDISI extends DocIdSetIterator { + private final int[] sortedDocIds; + private int idx = -1; + + private TopDocsDISI(TopDocs topDocs) { + sortedDocIds = new int[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedDocIds[i] = topDocs.scoreDocs[i].doc; + } + Arrays.sort(sortedDocIds); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return sortedDocIds.length; + } + + @Override + public int docID() { + if (idx == -1) { + return -1; + } else if (idx >= sortedDocIds.length) { + return DocIdSetIterator.NO_MORE_DOCS; + } else { + return sortedDocIds[idx]; + } + } + + @Override + public int nextDoc() { + idx += 1; + return docID(); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java 
index b3f400372b9..d08e1165cf5 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,10 +20,10 @@ package org.apache.lucene.util.hnsw; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import java.util.ArrayList; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.search.knn.EntryPointProvider; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -67,25 +67,23 @@ public class HnswGraphSearcher { public static void search( RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds) throws IOException { - ArrayList entryPointOrdInts = null; - DocIdSetIterator entryPoints = knnCollector.getSeedEntryPoints(); - if (entryPoints != null) { - entryPointOrdInts = new ArrayList(); - int entryPointOrdInt; - while ((entryPointOrdInt = entryPoints.nextDoc()) != NO_MORE_DOCS) { - entryPointOrdInts.add(entryPointOrdInt); - } - } HnswGraphSearcher graphSearcher = new HnswGraphSearcher( new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); - if (entryPointOrdInts == null || entryPointOrdInts.isEmpty()) { - search(scorer, knnCollector, graph, graphSearcher, acceptOrds); - } else { - int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray(); + final int[] entryPoints; + if (knnCollector instanceof EntryPointProvider epp) { + DocIdSetIterator eps = epp.entryPoints(); + entryPoints = new int[(int) eps.cost()]; + int idx = 0; + int entryPointOrdInt; + while ((entryPointOrdInt = eps.nextDoc()) != NO_MORE_DOCS) { + entryPoints[idx++] = entryPointOrdInt; + } // We use provided entry point ordinals to search the complete graph (level 0) 
graphSearcher.searchLevel( - knnCollector, scorer, 0 /* level */, entryPointOrdIntsArr, graph, acceptOrds); + knnCollector, scorer, 0 /* level */, entryPoints, graph, acceptOrds); + } else { + search(scorer, knnCollector, graph, graphSearcher, acceptOrds); } } diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 92d6562dd3f..1e485515a62 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -28,6 +28,7 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.SeededKnnFloatVectorQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; @@ -108,7 +109,7 @@ public class TestManyKnnDocs extends LuceneTestCase { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); String s = ""; for (int j = 0; j < docs.scoreDocs.length - 1; j++) { @@ -131,7 +132,7 @@ public class TestManyKnnDocs extends LuceneTestCase { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); String s = ""; for (int j = 0; j < docs.scoreDocs.length - 1; j++) { @@ -154,7 +155,7 @@ public class TestManyKnnDocs extends LuceneTestCase { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new 
SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); String s = ""; @@ -177,7 +178,7 @@ public class TestManyKnnDocs extends LuceneTestCase { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); String s = ""; diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index fb75ef9b50e..8a0d3b65aea 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -66,15 +66,9 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { abstract AbstractKnnVectorQuery getKnnVectorQuery( String field, float[] query, int k, Query queryFilter); - abstract AbstractKnnVectorQuery getKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery); - abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery( String field, float[] query, int k, Query queryFilter); - abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery); - AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) { return getKnnVectorQuery(field, query, k, null); } @@ -613,91 +607,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { } } - /** Tests with random vectors and a random seed. Uses RandomIndexWriter. 
*/ - public void testRandomWithSeed() throws IOException { - int numDocs = 1000; - int dimension = atLeast(5); - int numIters = atLeast(10); - int numDocsWithVector = 0; - try (Directory d = newDirectoryForTest()) { - // Always use the default kNN format to have predictable behavior around when it hits - // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN - // format - // implementation. - IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); - RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); - for (int i = 0; i < numDocs; i++) { - Document doc = new Document(); - if (random().nextBoolean()) { - // Randomly skip some vectors to test the mapping from docid to ordinals - doc.add(getKnnVectorField("field", randomVector(dimension))); - numDocsWithVector += 1; - } - doc.add(new NumericDocValuesField("tag", i)); - doc.add(new IntPoint("tag", i)); - w.addDocument(doc); - } - w.forceMerge(1); - w.close(); - - try (IndexReader reader = DirectoryReader.open(d)) { - IndexSearcher searcher = newSearcher(reader); - for (int i = 0; i < numIters; i++) { - int k = random().nextInt(80) + 1; - int n = random().nextInt(100) + 1; - // we may get fewer results than requested if there are deletions, but this test doesn't - // check that - assert reader.hasDeletions() == false; - - // All documents as seeds - Query seed1 = new MatchAllDocsQuery(); - Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); - AbstractKnnVectorQuery query = - getKnnVectorQuery("field", randomVector(dimension), k, filter, seed1); - TopDocs results = searcher.search(query, n); - int expected = Math.min(Math.min(n, k), numDocsWithVector); - - assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value() >= results.scoreDocs.length); - // verify the results are in descending score order - float last = Float.MAX_VALUE; - for (ScoreDoc scoreDoc : results.scoreDocs) { - assertTrue(scoreDoc.score <= last); - last = scoreDoc.score; - } - - // Restrictive seed query -- 6 documents - Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); - query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed2); - results = searcher.search(query, n); - expected = Math.min(Math.min(n, k), reader.numDocs()); - assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value() >= results.scoreDocs.length); - // verify the results are in descending score order - last = Float.MAX_VALUE; - for (ScoreDoc scoreDoc : results.scoreDocs) { - assertTrue(scoreDoc.score <= last); - last = scoreDoc.score; - } - - // No seed documents -- falls back on full approx search - Query seed3 = new MatchNoDocsQuery(); - query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed3); - results = searcher.search(query, n); - expected = Math.min(Math.min(n, k), reader.numDocs()); - assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value() >= results.scoreDocs.length); - // verify the results are in descending score order - last = Float.MAX_VALUE; - for (ScoreDoc scoreDoc : results.scoreDocs) { - assertTrue(scoreDoc.score <= last); - last = scoreDoc.score; - } - } - } - } - } - /** Tests filtering when all vectors have the same score. 
*/ public void testFilterWithSameScore() throws IOException { int numDocs = 100; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 651a1219f07..21219e0e1d9 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -34,21 +34,9 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter); } - @Override - AbstractKnnVectorQuery getKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery) { - return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, seedQuery); - } - @Override AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { - return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, null); - } - - @Override - AbstractKnnVectorQuery getThrowingKnnVectorQuery( - String field, float[] vec, int k, Query query, Query seedQuery) { - return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, seedQuery); + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query); } @Override @@ -73,7 +61,7 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); } - private static byte[] floatToBytes(float[] query) { + static byte[] floatToBytes(float[] query) { byte[] bytes = new byte[query.length]; for (int i = 0; i < query.length; i++) { assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0 @@ -121,10 +109,10 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { } } - private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { + static class ThrowingKnnVectorQuery 
extends KnnByteVectorQuery { - public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { - super(field, target, k, filter, seed); + public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { + super(field, target, k, filter); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 5a4b49b8e2d..ece2b385654 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -50,21 +50,9 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { return new KnnFloatVectorQuery(field, query, k, queryFilter); } - @Override - KnnFloatVectorQuery getKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery) { - return new KnnFloatVectorQuery(field, query, k, queryFilter, seedQuery); - } - @Override AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { - return new ThrowingKnnVectorQuery(field, vec, k, query, null); - } - - @Override - AbstractKnnVectorQuery getThrowingKnnVectorQuery( - String field, float[] vec, int k, Query query, Query seedQuery) { - return new ThrowingKnnVectorQuery(field, vec, k, query, seedQuery); + return new ThrowingKnnVectorQuery(field, vec, k, query); } @Override @@ -271,10 +259,10 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { } } - private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { + static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { - public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { - super(field, target, k, filter, seed); + public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { + super(field, target, k, filter); } @Override diff 
--git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java new file mode 100644 index 00000000000..c4ce074b57f --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { + + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + byte[] b = TestVectorUtil.randomVectorBytes(dim); + float[] v = new float[b.length]; + int vi = 0; + for (int i = 0; i < v.length; i++) { + v[vi++] = b[i]; + } + return v; + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnByteVectorField(name, 
floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); + SeededKnnByteVectorQuery query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnByteVectorQuery { + + public ThrowingKnnVectorQuery(String field, byte[] 
target, int k, Query filter, Query seed) { + super(field, target, k, filter, seed); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java new file mode 100644 index 00000000000..268e7d80637 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnFloatVectorQuery(field, query, k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, vec, k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + return TestVectorUtil.randomVector(dim); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnFloatVectorField(name, vector); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. 
*/ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); + AbstractKnnVectorQuery query = + new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery { + + public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + 
super(field, target, k, filter, seed); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index ab140be7113..45cb8b9c88f 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -140,7 +140,6 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index ab2e1462c4c..9c44a2f7856 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -139,7 +139,6 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException {