diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index 3822f5f1e92..e9246a8b575 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -21,7 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@@ -49,10 +48,6 @@ import org.apache.lucene.util.Bits;
*
If the kNN search visits too many vectors without completing, stop and run an exact search
*
- *
- * When a seed query is provided, this query is executed first to seed the kNN search (subject to
- * the same rules provided by the filter). If the seed query fails to identify any documents, it
- * falls back on the strategy above.
*/
abstract class AbstractKnnVectorQuery extends Query {
@@ -60,21 +55,15 @@ abstract class AbstractKnnVectorQuery extends Query {
protected final String field;
protected final int k;
- private final Query filter;
- private final Query seed;
+ protected final Query filter;
public AbstractKnnVectorQuery(String field, int k, Query filter) {
- this(field, k, filter, null);
- }
-
- public AbstractKnnVectorQuery(String field, int k, Query filter, Query seed) {
this.field = Objects.requireNonNull(field, "field");
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
}
this.filter = filter;
- this.seed = seed;
}
@Override
@@ -94,21 +83,6 @@ abstract class AbstractKnnVectorQuery extends Query {
filterWeight = null;
}
- final Weight seedWeight;
- if (seed != null) {
- BooleanQuery.Builder booleanSeedQueryBuilder =
- new BooleanQuery.Builder()
- .add(seed, BooleanClause.Occur.MUST)
- .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
- if (filter != null) {
- booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
- }
- Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
- seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
- } else {
- seedWeight = null;
- }
-
TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
@@ -116,7 +90,7 @@ abstract class AbstractKnnVectorQuery extends Query {
List leafReaderContexts = reader.leaves();
List> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
- tasks.add(() -> searchLeaf(context, filterWeight, seedWeight, knnCollectorManager));
+ tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
}
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
@@ -131,11 +105,9 @@ abstract class AbstractKnnVectorQuery extends Query {
private TopDocs searchLeaf(
LeafReaderContext ctx,
Weight filterWeight,
- Weight seedWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
- TopDocs results =
- getLeafResults(ctx, filterWeight, seedWeight, timeLimitingKnnCollectorManager);
+ TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
@@ -147,19 +119,13 @@ abstract class AbstractKnnVectorQuery extends Query {
private TopDocs getLeafResults(
LeafReaderContext ctx,
Weight filterWeight,
- Weight seedWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
final LeafReader reader = ctx.reader();
final Bits liveDocs = reader.getLiveDocs();
if (filterWeight == null) {
- return approximateSearch(
- ctx,
- liveDocs,
- executeSeedQuery(ctx, seedWeight),
- Integer.MAX_VALUE,
- timeLimitingKnnCollectorManager);
+ return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
}
Scorer scorer = filterWeight.scorer(ctx);
@@ -179,13 +145,7 @@ abstract class AbstractKnnVectorQuery extends Query {
// Perform the approximate kNN search
// We pass cost + 1 here to account for the edge case when we explore exactly cost vectors
- TopDocs results =
- approximateSearch(
- ctx,
- acceptDocs,
- executeSeedQuery(ctx, seedWeight),
- cost + 1,
- timeLimitingKnnCollectorManager);
+ TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
if (results.totalHits.relation() == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
@@ -196,49 +156,6 @@ abstract class AbstractKnnVectorQuery extends Query {
}
}
- private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeight)
- throws IOException {
- if (seedWeight == null) return null;
- // Execute the seed query
- TopScoreDocCollector seedCollector =
- new TopScoreDocCollectorManager(
- k /* numHits */,
- null /* after */,
- Integer.MAX_VALUE /* totalHitsThreshold */,
- false /* supportsConcurrency */)
- .newCollector();
- final LeafReader leafReader = ctx.reader();
- final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx);
- if (leafCollector != null) {
- try {
- BulkScorer scorer = seedWeight.bulkScorer(ctx);
- if (scorer != null) {
- scorer.score(
- leafCollector,
- leafReader.getLiveDocs(),
- 0 /* min */,
- DocIdSetIterator.NO_MORE_DOCS /* max */);
- }
- leafCollector.finish();
- } catch (
- @SuppressWarnings("unused")
- CollectionTerminatedException e) {
- }
- }
-
- TopDocs seedTopDocs = seedCollector.topDocs();
- return convertDocIdsToVectorOrdinals(leafReader, new TopDocsDISI(seedTopDocs));
- }
-
- /**
- * Returns a new iterator that maps the provided docIds to the vector ordinals.
- *
- * @lucene.internal
- * @lucene.experimental
- */
- protected abstract DocIdSetIterator convertDocIdsToVectorOrdinals(
- LeafReader reader, DocIdSetIterator docIds) throws IOException;
-
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
@@ -264,7 +181,6 @@ abstract class AbstractKnnVectorQuery extends Query {
protected abstract TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
- DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;
@@ -386,15 +302,12 @@ abstract class AbstractKnnVectorQuery extends Query {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o;
- return k == that.k
- && Objects.equals(field, that.field)
- && Objects.equals(filter, that.filter)
- && Objects.equals(seed, that.seed);
+ return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter);
}
@Override
public int hashCode() {
- return Objects.hash(field, k, filter, seed);
+ return Objects.hash(field, k, filter);
}
/**
@@ -419,13 +332,6 @@ abstract class AbstractKnnVectorQuery extends Query {
return filter;
}
- /**
- * @return the query that seeds the kNN search.
- */
- public Query getSeed() {
- return seed;
- }
-
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
@@ -585,44 +491,4 @@ abstract class AbstractKnnVectorQuery extends Query {
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
}
}
-
- private static class TopDocsDISI extends DocIdSetIterator {
- private final List sortedDocIdList;
- private int idx = -1;
-
- public TopDocsDISI(TopDocs topDocs) {
- sortedDocIdList = new ArrayList(topDocs.scoreDocs.length);
- for (int i = 0; i < topDocs.scoreDocs.length; i++) {
- sortedDocIdList.add(topDocs.scoreDocs[i].doc);
- }
- Collections.sort(sortedDocIdList);
- }
-
- @Override
- public int advance(int target) throws IOException {
- return slowAdvance(target);
- }
-
- @Override
- public long cost() {
- return sortedDocIdList.size();
- }
-
- @Override
- public int docID() {
- if (idx == -1) {
- return -1;
- } else if (idx >= sortedDocIdList.size()) {
- return DocIdSetIterator.NO_MORE_DOCS;
- } else {
- return sortedDocIdList.get(idx);
- }
- }
-
- @Override
- public int nextDoc() throws IOException {
- idx += 1;
- return docID();
- }
- }
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
index eb637da0d75..05157ab65cb 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
@@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
- private final byte[] target;
+ protected final byte[] target;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@@ -72,22 +72,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
- this(field, target, k, filter, null);
- }
-
- /**
- * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
- * given field. <code>target</code> vector.
- *
- * @param field a field that has been indexed as a {@link KnnByteVectorField}.
- * @param target the target of the search
- * @param k the number of documents to find
- * @param filter a filter applied before the vector search
- * @param seed a query that is executed to seed the vector search
- * @throws IllegalArgumentException if <code>k</code> is less than 1
- */
- public KnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
- super(field, k, filter, seed);
+ super(field, k, filter);
this.target = Objects.requireNonNull(target, "target");
}
@@ -95,14 +80,10 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
- DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
- if (seedDocs != null) {
- knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs);
- }
LeafReader reader = context.reader();
ByteVectorValues byteVectorValues = reader.getByteVectorValues(field);
if (byteVectorValues == null) {
@@ -159,18 +140,4 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
public byte[] getTargetCopy() {
return ArrayUtil.copyArray(target);
}
-
- /**
- * Returns a new iterator that maps the provided docIds to the vector ordinals.
- *
- * This method assumes that all docIds have corresponding ordinals.
- *
- * @lucene.internal
- * @lucene.experimental
- */
- @Override
- protected DocIdSetIterator convertDocIdsToVectorOrdinals(
- LeafReader reader, DocIdSetIterator docIds) throws IOException {
- return reader.getByteVectorValues(field).convertDocIdsToVectorOrdinals(docIds);
- }
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java
index 43338fec72e..a05ca674771 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java
@@ -86,26 +86,13 @@ public interface KnnCollector {
*/
TopDocs topDocs();
- /**
- * This method returns a {@link DocIdSetIterator} over entry points that seed the KNN search, or
- * {@code null} (default) to perform a full KNN search (without seeds).
- *
- * <p>Note that the entry points should represent ordinals, rather than true document IDs.
- *
- * @return the seed entry points or {@code null}.
- * @lucene.experimental
- */
- default DocIdSetIterator getSeedEntryPoints() {
- return null;
- }
-
/**
* KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend
* the object with new behaviors.
*
* @lucene.experimental
*/
- public abstract static class Decorator implements KnnCollector {
+ abstract class Decorator implements KnnCollector {
private KnnCollector collector;
public Decorator(KnnCollector collector) {
@@ -151,29 +138,5 @@ public interface KnnCollector {
public TopDocs topDocs() {
return collector.topDocs();
}
-
- @Override
- public DocIdSetIterator getSeedEntryPoints() {
- return collector.getSeedEntryPoints();
- }
- }
-
- /**
- * KnnCollector.Seeded is a KnnCollector decorator that replaces the seedEntryPoints.
- *
- * @lucene.experimental
- */
- public static class Seeded extends Decorator {
- private DocIdSetIterator seedEntryPoints;
-
- public Seeded(KnnCollector collector, DocIdSetIterator seedEntryPoints) {
- super(collector);
- this.seedEntryPoints = seedEntryPoints;
- }
-
- @Override
- public DocIdSetIterator getSeedEntryPoints() {
- return seedEntryPoints;
- }
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
index 28660756453..c7d6fdb3608 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
@@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
- private final float[] target;
+ protected final float[] target;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@@ -73,22 +73,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
- this(field, target, k, filter, null);
- }
-
- /**
- * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
- * given field. <code>target</code> vector.
- *
- * @param field a field that has been indexed as a {@link KnnFloatVectorField}.
- * @param target the target of the search
- * @param k the number of documents to find
- * @param filter a filter applied before the vector search
- * @param seed a query that is executed to seed the vector search
- * @throws IllegalArgumentException if <code>k</code> is less than 1
- */
- public KnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
- super(field, k, filter, seed);
+ super(field, k, filter);
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}
@@ -96,14 +81,10 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
- DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
- if (seedDocs != null) {
- knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs);
- }
LeafReader reader = context.reader();
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
if (floatVectorValues == null) {
@@ -162,18 +143,4 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
public float[] getTargetCopy() {
return ArrayUtil.copyArray(target);
}
-
- /**
- * Returns a new iterator that maps the provided docIds to the vector ordinals.
- *
- * <p>This method assumes that all docIds have corresponding ordinals.
- *
- * @lucene.internal
- * @lucene.experimental
- */
- @Override
- protected DocIdSetIterator convertDocIdsToVectorOrdinals(
- LeafReader reader, DocIdSetIterator docIds) throws IOException {
- return reader.getFloatVectorValues(field).convertDocIdsToVectorOrdinals(docIds);
- }
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java
new file mode 100644
index 00000000000..93286948b0a
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.search.knn.KnnCollectorManager;
+import org.apache.lucene.search.knn.SeededKnnCollectorManager;
+
+/**
+ * This is a version of knn byte vector query that provides a query seed to initiate the vector
+ * search. NOTE: The underlying format is free to ignore the provided seed
+ *
+ * @lucene.experimental
+ */
+public class SeededKnnByteVectorQuery extends KnnByteVectorQuery {
+  private final Query seed;
+  private final Weight seedWeight;
+
+  /**
+   * Construct a new SeededKnnByteVectorQuery instance
+   *
+   * @param field knn byte vector field to query
+   * @param target the query vector
+   * @param k number of neighbors to return
+   * @param filter a filter on the neighbors to return
+   * @param seed a query seed to initiate the vector format search
+   */
+  public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
+    super(field, target, k, filter);
+    this.seed = Objects.requireNonNull(seed);
+    this.seedWeight = null;
+  }
+
+  /** Internal form produced by {@link #rewrite}: holds the prepared seed weight, not the query. */
+  SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) {
+    super(field, target, k, filter);
+    this.seed = null;
+    this.seedWeight = Objects.requireNonNull(seedWeight);
+  }
+
+  /** Builds the seed weight on the first pass, then rewrites through the parent class. */
+  @Override
+  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+    if (seedWeight != null) {
+      return super.rewrite(indexSearcher);
+    }
+    // Seeds must have a vector in this field and must satisfy the filter, if any.
+    BooleanQuery.Builder booleanSeedQueryBuilder =
+        new BooleanQuery.Builder()
+            .add(seed, BooleanClause.Occur.MUST)
+            .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
+    if (filter != null) {
+      booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
+    }
+    Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
+    Weight weight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
+    return new SeededKnnByteVectorQuery(field, target, k, filter, weight).rewrite(indexSearcher);
+  }
+
+  /** Wraps the parent's manager so each leaf search is seeded via this field's byte vectors. */
+  @Override
+  protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
+    if (seedWeight == null) {
+      throw new UnsupportedOperationException("must be rewritten before constructing manager");
+    }
+    KnnCollectorManager delegate = super.getKnnCollectorManager(k, searcher);
+    return new SeededKnnCollectorManager(
+        delegate, seedWeight, k, leaf -> leaf.getByteVectorValues(field));
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java
new file mode 100644
index 00000000000..f64e0b29bc6
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.search.knn.KnnCollectorManager;
+import org.apache.lucene.search.knn.SeededKnnCollectorManager;
+
+/**
+ * A float vector kNN query whose search is started from the hits of a seed query. NOTE: the
+ * underlying vector format is free to ignore the provided seed.
+ *
+ * @lucene.experimental
+ */
+public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery {
+  private final Query seed;
+  private final Weight seedWeight;
+
+  /**
+   * Creates a seeded float vector query.
+   *
+   * @param field knn float vector field to query
+   * @param target the query vector
+   * @param k number of neighbors to return
+   * @param filter a filter on the neighbors to return
+   * @param seed a query whose hits are used to initiate the vector search
+   */
+  public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
+    super(field, target, k, filter);
+    this.seed = Objects.requireNonNull(seed);
+    this.seedWeight = null;
+  }
+
+  /** Internal form produced by {@link #rewrite}: holds the prepared seed weight, not the query. */
+  SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) {
+    super(field, target, k, filter);
+    this.seed = null;
+    this.seedWeight = Objects.requireNonNull(seedWeight);
+  }
+
+  /** Builds the seed weight on the first pass, then rewrites through the parent class. */
+  @Override
+  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+    if (seedWeight == null) {
+      // Seeds must have a vector in this field and must satisfy the filter, if any.
+      BooleanQuery.Builder seedBuilder = new BooleanQuery.Builder();
+      seedBuilder.add(seed, BooleanClause.Occur.MUST);
+      seedBuilder.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
+      if (filter != null) {
+        seedBuilder.add(filter, BooleanClause.Occur.FILTER);
+      }
+      Query rewrittenSeed = indexSearcher.rewrite(seedBuilder.build());
+      Weight weight = indexSearcher.createWeight(rewrittenSeed, ScoreMode.TOP_SCORES, 1f);
+      Query withWeight = new SeededKnnFloatVectorQuery(field, target, k, filter, weight);
+      return withWeight.rewrite(indexSearcher);
+    }
+    return super.rewrite(indexSearcher);
+  }
+
+  /** Wraps the parent's manager so that each leaf search starts from the seed query's hits. */
+  @Override
+  protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
+    if (seedWeight != null) {
+      KnnCollectorManager delegate = super.getKnnCollectorManager(k, searcher);
+      return new SeededKnnCollectorManager(
+          delegate, seedWeight, k, leaf -> leaf.getFloatVectorValues(field));
+    }
+    throw new UnsupportedOperationException("must be rewritten before constructing manager");
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java
new file mode 100644
index 00000000000..40eda94c654
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.knn;
+
+import org.apache.lucene.search.DocIdSetIterator;
+
+/** Implemented by kNN collectors that can hand explicit entry points to the vector search. */
+public interface EntryPointProvider {
+ /** Returns an iterator over the entry points; HnswGraphSearcher reads its values as ordinals. */
+ DocIdSetIterator entryPoints();
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java
new file mode 100644
index 00000000000..ac0c643eac5
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.knn;
+
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.KnnCollector;
+
+/** A {@link KnnCollector} decorator that also carries entry points to start the search from. */
+public class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider {
+ private final DocIdSetIterator entryPoints; // returned as-is (same instance) by entryPoints()
+ public SeededKnnCollector(KnnCollector collector, DocIdSetIterator entryPoints) {
+ super(collector);
+ this.entryPoints = entryPoints;
+ }
+
+ @Override
+ public DocIdSetIterator entryPoints() {
+ return entryPoints;
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java
new file mode 100644
index 00000000000..1cea53c4179
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.knn;
+
+import java.io.IOException;
+import java.util.Arrays;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopScoreDocCollector;
+import org.apache.lucene.search.TopScoreDocCollectorManager;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.util.IOFunction;
+
+/**
+ * A {@link KnnCollectorManager} that runs a seed query on each leaf and, when it has hits, hands
+ * the top {@code k} of them to the search as entry points via a {@link SeededKnnCollector}.
+ */
+public class SeededKnnCollectorManager implements KnnCollectorManager {
+  private final KnnCollectorManager delegate;
+  private final Weight seedWeight;
+  private final int k;
+  private final IOFunction<LeafReader, KnnVectorValues> vectorValuesSupplier; // may return null
+
+  /** Creates a manager whose collectors are seeded with the top {@code k} seed query hits. */
+  public SeededKnnCollectorManager(
+      KnnCollectorManager delegate,
+      Weight seedWeight,
+      int k,
+      IOFunction<LeafReader, KnnVectorValues> vectorValuesSupplier) {
+    this.delegate = delegate;
+    this.seedWeight = seedWeight;
+    this.k = k;
+    this.vectorValuesSupplier = vectorValuesSupplier;
+  }
+
+  /** Returns the delegate's collector, seeded with entry points if the seed query has hits. */
+  @Override
+  public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException {
+    // Execute the seed query on this leaf, keeping its k best-scoring documents
+    TopScoreDocCollector seedCollector =
+        new TopScoreDocCollectorManager(
+                k /* numHits */, null /* after */, Integer.MAX_VALUE /* totalHitsThreshold */)
+            .newCollector();
+    final LeafReader leafReader = ctx.reader();
+    final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx);
+    if (leafCollector != null) {
+      try {
+        BulkScorer scorer = seedWeight.bulkScorer(ctx);
+        if (scorer != null) {
+          // min = 0 and max = NO_MORE_DOCS: score every live document of the leaf
+          scorer.score(leafCollector, leafReader.getLiveDocs(), 0, DocIdSetIterator.NO_MORE_DOCS);
+        }
+        leafCollector.finish();
+      } catch (
+          @SuppressWarnings("unused")
+          CollectionTerminatedException e) {
+        // Signals an early stop of collection, not a failure; hits gathered so far are the seeds
+      }
+    }
+
+    TopDocs seedTopDocs = seedCollector.topDocs();
+    KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader);
+    if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) {
+      // Nothing to seed with: fall back to the delegate's unseeded collector
+      return delegate.newCollector(visitedLimit, ctx);
+    }
+    KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator();
+    DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs));
+    return new SeededKnnCollector(delegate.newCollector(visitedLimit, ctx), seedDocs);
+  }
+
+  /** Walks {@code sourceDISI}'s doc ids while reporting {@code indexedDISI}'s index for each. */
+  public static class MappedDISI extends DocIdSetIterator {
+    KnnVectorValues.DocIndexIterator indexedDISI;
+    DocIdSetIterator sourceDISI;
+
+    public MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) {
+      this.indexedDISI = indexedDISI;
+      this.sourceDISI = sourceDISI;
+    }
+
+    /** Advances the source to the first doc at or after target; returns the matching index. */
+    @Override
+    public int advance(int target) throws IOException {
+      int newTarget = sourceDISI.advance(target);
+      if (newTarget != NO_MORE_DOCS) {
+        indexedDISI.advance(newTarget);
+      }
+      return docID();
+    }
+
+    @Override
+    public long cost() {
+      return this.sourceDISI.cost();
+    }
+
+    @Override
+    public int docID() {
+      if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) {
+        return NO_MORE_DOCS;
+      }
+      return indexedDISI.index();
+    }
+
+    /** Advances to the next document in the source iterator and returns the corresponding index. */
+    @Override
+    public int nextDoc() throws IOException {
+      int newTarget = sourceDISI.nextDoc();
+      if (newTarget != NO_MORE_DOCS) {
+        indexedDISI.advance(newTarget);
+      }
+      return docID();
+    }
+  }
+
+  /** Iterates the doc ids of a {@link TopDocs} in increasing order. */
+  private static class TopDocsDISI extends DocIdSetIterator {
+    private final int[] sortedDocIds;
+    private int idx = -1;
+
+    private TopDocsDISI(TopDocs topDocs) {
+      sortedDocIds = new int[topDocs.scoreDocs.length];
+      for (int i = 0; i < topDocs.scoreDocs.length; i++) {
+        sortedDocIds[i] = topDocs.scoreDocs[i].doc;
+      }
+      Arrays.sort(sortedDocIds);
+    }
+
+    @Override
+    public int advance(int target) throws IOException {
+      return slowAdvance(target);
+    }
+
+    @Override
+    public long cost() {
+      return sortedDocIds.length;
+    }
+
+    @Override
+    public int docID() {
+      if (idx == -1) {
+        return -1;
+      }
+      return idx < sortedDocIds.length ? sortedDocIds[idx] : NO_MORE_DOCS;
+    }
+
+    @Override
+    public int nextDoc() {
+      idx += 1;
+      return docID();
+    }
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index b3f400372b9..d08e1165cf5 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -20,10 +20,10 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
-import java.util.ArrayList;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
+import org.apache.lucene.search.knn.EntryPointProvider;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
@@ -67,25 +67,34 @@ public class HnswGraphSearcher {
public static void search(
RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds)
throws IOException {
- ArrayList entryPointOrdInts = null;
- DocIdSetIterator entryPoints = knnCollector.getSeedEntryPoints();
- if (entryPoints != null) {
- entryPointOrdInts = new ArrayList();
- int entryPointOrdInt;
- while ((entryPointOrdInt = entryPoints.nextDoc()) != NO_MORE_DOCS) {
- entryPointOrdInts.add(entryPointOrdInt);
- }
- }
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
- if (entryPointOrdInts == null || entryPointOrdInts.isEmpty()) {
- search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
- } else {
- int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray();
+ int[] entryPoints = null;
+ if (knnCollector instanceof EntryPointProvider epp) {
+   DocIdSetIterator eps = epp.entryPoints();
+   // cost() is only an estimate of how many ordinals the iterator yields, so it is used as
+   // an initial capacity: the buffer grows when the estimate is too small and is trimmed
+   // when it is too large. Sizing the array from cost() directly would either overflow it
+   // or leave trailing zeros that get searched as if ordinal 0 were an entry point.
+   int[] buffer = new int[(int) Math.max(1L, Math.min(eps.cost(), getGraphSize(graph)))];
+   int count = 0;
+   int entryPointOrdInt;
+   while ((entryPointOrdInt = eps.nextDoc()) != NO_MORE_DOCS) {
+     if (count == buffer.length) {
+       buffer = java.util.Arrays.copyOf(buffer, 2 * buffer.length);
+     }
+     buffer[count++] = entryPointOrdInt;
+   }
+   entryPoints = java.util.Arrays.copyOf(buffer, count);
+ }
+ if (entryPoints == null || entryPoints.length == 0) {
+   // No usable entry points were provided: fall back to the regular graph search
+   search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
+ } else {
// We use provided entry point ordinals to search the complete graph (level 0)
graphSearcher.searchLevel(
- knnCollector, scorer, 0 /* level */, entryPointOrdIntsArr, graph, acceptOrds);
+ knnCollector, scorer, 0 /* level */, entryPoints, graph, acceptOrds);
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java
index 92d6562dd3f..1e485515a62 100644
--- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java
+++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java
@@ -28,6 +28,7 @@ import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
+import org.apache.lucene.search.SeededKnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
@@ -108,7 +109,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
- new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
String s = "";
for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
@@ -131,7 +132,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
- new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
String s = "";
for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
@@ -154,7 +155,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
- new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
String s = "";
@@ -177,7 +178,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
- new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
String s = "";
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index fb75ef9b50e..8a0d3b65aea 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -66,15 +66,9 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
abstract AbstractKnnVectorQuery getKnnVectorQuery(
String field, float[] query, int k, Query queryFilter);
- abstract AbstractKnnVectorQuery getKnnVectorQuery(
- String field, float[] query, int k, Query queryFilter, Query seedQuery);
-
abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery(
String field, float[] query, int k, Query queryFilter);
- abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery(
- String field, float[] query, int k, Query queryFilter, Query seedQuery);
-
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) {
return getKnnVectorQuery(field, query, k, null);
}
@@ -613,91 +607,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
}
- /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */
- public void testRandomWithSeed() throws IOException {
- int numDocs = 1000;
- int dimension = atLeast(5);
- int numIters = atLeast(10);
- int numDocsWithVector = 0;
- try (Directory d = newDirectoryForTest()) {
- // Always use the default kNN format to have predictable behavior around when it hits
- // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
- // format
- // implementation.
- IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
- RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
- for (int i = 0; i < numDocs; i++) {
- Document doc = new Document();
- if (random().nextBoolean()) {
- // Randomly skip some vectors to test the mapping from docid to ordinals
- doc.add(getKnnVectorField("field", randomVector(dimension)));
- numDocsWithVector += 1;
- }
- doc.add(new NumericDocValuesField("tag", i));
- doc.add(new IntPoint("tag", i));
- w.addDocument(doc);
- }
- w.forceMerge(1);
- w.close();
-
- try (IndexReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = newSearcher(reader);
- for (int i = 0; i < numIters; i++) {
- int k = random().nextInt(80) + 1;
- int n = random().nextInt(100) + 1;
- // we may get fewer results than requested if there are deletions, but this test doesn't
- // check that
- assert reader.hasDeletions() == false;
-
- // All documents as seeds
- Query seed1 = new MatchAllDocsQuery();
- Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
- AbstractKnnVectorQuery query =
- getKnnVectorQuery("field", randomVector(dimension), k, filter, seed1);
- TopDocs results = searcher.search(query, n);
- int expected = Math.min(Math.min(n, k), numDocsWithVector);
-
- assertEquals(expected, results.scoreDocs.length);
- assertTrue(results.totalHits.value() >= results.scoreDocs.length);
- // verify the results are in descending score order
- float last = Float.MAX_VALUE;
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- assertTrue(scoreDoc.score <= last);
- last = scoreDoc.score;
- }
-
- // Restrictive seed query -- 6 documents
- Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
- query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed2);
- results = searcher.search(query, n);
- expected = Math.min(Math.min(n, k), reader.numDocs());
- assertEquals(expected, results.scoreDocs.length);
- assertTrue(results.totalHits.value() >= results.scoreDocs.length);
- // verify the results are in descending score order
- last = Float.MAX_VALUE;
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- assertTrue(scoreDoc.score <= last);
- last = scoreDoc.score;
- }
-
- // No seed documents -- falls back on full approx search
- Query seed3 = new MatchNoDocsQuery();
- query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed3);
- results = searcher.search(query, n);
- expected = Math.min(Math.min(n, k), reader.numDocs());
- assertEquals(expected, results.scoreDocs.length);
- assertTrue(results.totalHits.value() >= results.scoreDocs.length);
- // verify the results are in descending score order
- last = Float.MAX_VALUE;
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- assertTrue(scoreDoc.score <= last);
- last = scoreDoc.score;
- }
- }
- }
- }
- }
-
/** Tests filtering when all vectors have the same score. */
public void testFilterWithSameScore() throws IOException {
int numDocs = 100;
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
index 651a1219f07..21219e0e1d9 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
@@ -34,21 +34,9 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
}
- @Override
- AbstractKnnVectorQuery getKnnVectorQuery(
- String field, float[] query, int k, Query queryFilter, Query seedQuery) {
- return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, seedQuery);
- }
-
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
- return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, null);
- }
-
- @Override
- AbstractKnnVectorQuery getThrowingKnnVectorQuery(
- String field, float[] vec, int k, Query query, Query seedQuery) {
- return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, seedQuery);
+ return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
}
@Override
@@ -73,7 +61,7 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}
- private static byte[] floatToBytes(float[] query) {
+ static byte[] floatToBytes(float[] query) {
byte[] bytes = new byte[query.length];
for (int i = 0; i < query.length; i++) {
assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0
@@ -121,10 +109,10 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
}
}
- private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
+ static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
- public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
- super(field, target, k, filter, seed);
+ public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
+ super(field, target, k, filter);
}
@Override
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
index 5a4b49b8e2d..ece2b385654 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
@@ -50,21 +50,9 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnFloatVectorQuery(field, query, k, queryFilter);
}
- @Override
- KnnFloatVectorQuery getKnnVectorQuery(
- String field, float[] query, int k, Query queryFilter, Query seedQuery) {
- return new KnnFloatVectorQuery(field, query, k, queryFilter, seedQuery);
- }
-
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
- return new ThrowingKnnVectorQuery(field, vec, k, query, null);
- }
-
- @Override
- AbstractKnnVectorQuery getThrowingKnnVectorQuery(
- String field, float[] vec, int k, Query query, Query seedQuery) {
- return new ThrowingKnnVectorQuery(field, vec, k, query, seedQuery);
+ return new ThrowingKnnVectorQuery(field, vec, k, query);
}
@Override
@@ -271,10 +259,10 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
}
}
- private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {
+ static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {
- public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
- super(field, target, k, filter, seed);
+ public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
+ super(field, target, k, filter);
}
@Override
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java
new file mode 100644
index 00000000000..c4ce074b57f
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes;
+
+import java.io.IOException;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntPoint;
+import org.apache.lucene.document.KnnByteVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.QueryTimeout;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.TestVectorUtil;
+
+/**
+ * Runs the {@link BaseKnnVectorQueryTestCase} tests against {@link SeededKnnByteVectorQuery} and
+ * adds a randomized test of searches seeded by a query.
+ */
+public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
+
+  /** Seed that matches no document; passed to every query built for the inherited tests. */
+  private static final Query MATCH_NONE = new MatchNoDocsQuery();
+
+  @Override
+  AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
+    return new SeededKnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, MATCH_NONE);
+  }
+
+  @Override
+  AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
+    return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, MATCH_NONE);
+  }
+
+  /** Returns a random vector whose components are byte values widened to float. */
+  @Override
+  float[] randomVector(int dim) {
+    byte[] b = TestVectorUtil.randomVectorBytes(dim);
+    float[] v = new float[b.length];
+    for (int i = 0; i < v.length; i++) {
+      v[i] = b[i];
+    }
+    return v;
+  }
+
+  @Override
+  Field getKnnVectorField(
+      String name, float[] vector, VectorSimilarityFunction similarityFunction) {
+    return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction);
+  }
+
+  @Override
+  Field getKnnVectorField(String name, float[] vector) {
+    return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
+  }
+
+  /**
+   * Tests seeded searches over random vectors with seed queries that match all documents, a few
+   * documents and no documents. Uses RandomIndexWriter.
+   */
+  public void testRandomWithSeed() throws IOException {
+    int numDocs = 1000;
+    int dimension = atLeast(5);
+    int numIters = atLeast(10);
+    int numDocsWithVector = 0;
+    try (Directory d = newDirectoryForTest()) {
+      // Always use the default kNN format to have predictable behavior around when it hits
+      // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the
+      // kNN format implementation.
+      IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
+      try (RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc)) {
+        for (int i = 0; i < numDocs; i++) {
+          Document doc = new Document();
+          if (random().nextBoolean()) {
+            // Randomly skip some vectors to test the mapping from docid to ordinals
+            doc.add(getKnnVectorField("field", randomVector(dimension)));
+            numDocsWithVector += 1;
+          }
+          doc.add(new NumericDocValuesField("tag", i));
+          doc.add(new IntPoint("tag", i));
+          w.addDocument(doc);
+        }
+        w.forceMerge(1);
+      }
+
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        IndexSearcher searcher = newSearcher(reader);
+        for (int i = 0; i < numIters; i++) {
+          int k = random().nextInt(80) + 1;
+          int n = random().nextInt(100) + 1;
+          // we may get fewer results than requested if there are deletions, but this test doesn't
+          // check that
+          assert reader.hasDeletions() == false;
+          // Only documents that carry a vector can be returned, whatever the seed matches
+          int expected = Math.min(Math.min(n, k), numDocsWithVector);
+
+          // All documents as seeds
+          Query seed1 = new MatchAllDocsQuery();
+          Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
+          SeededKnnByteVectorQuery query =
+              new SeededKnnByteVectorQuery(
+                  "field", floatToBytes(randomVector(dimension)), k, filter, seed1);
+          assertHits(expected, searcher.search(query, n));
+
+          // Restrictive seed query -- 6 documents
+          Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
+          query =
+              new SeededKnnByteVectorQuery(
+                  "field", floatToBytes(randomVector(dimension)), k, null, seed2);
+          assertHits(expected, searcher.search(query, n));
+
+          // No seed documents -- falls back on full approx search
+          Query seed3 = new MatchNoDocsQuery();
+          query =
+              new SeededKnnByteVectorQuery(
+                  "field", floatToBytes(randomVector(dimension)), k, null, seed3);
+          assertHits(expected, searcher.search(query, n));
+        }
+      }
+    }
+  }
+
+  /**
+   * Asserts that {@code results} contains exactly {@code expected} hits, that the total hit count
+   * is at least that number and that the hits are ordered by descending score.
+   */
+  private static void assertHits(int expected, TopDocs results) {
+    assertEquals(expected, results.scoreDocs.length);
+    assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+    float last = Float.MAX_VALUE;
+    for (ScoreDoc scoreDoc : results.scoreDocs) {
+      assertTrue(scoreDoc.score <= last);
+      last = scoreDoc.score;
+    }
+  }
+
+  /** Seeded query whose exact search always throws {@link UnsupportedOperationException}. */
+  private static class ThrowingKnnVectorQuery extends SeededKnnByteVectorQuery {
+
+    public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
+      super(field, target, k, filter, seed);
+    }
+
+    @Override
+    protected TopDocs exactSearch(
+        LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
+      throw new UnsupportedOperationException("exact search is not supported");
+    }
+
+    @Override
+    public String toString(String field) {
+      return null;
+    }
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java
new file mode 100644
index 00000000000..268e7d80637
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntPoint;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.QueryTimeout;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.TestVectorUtil;
+
+/**
+ * Runs the {@link BaseKnnVectorQueryTestCase} tests against {@link SeededKnnFloatVectorQuery} and
+ * adds a randomized test of searches seeded by a query.
+ */
+public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
+  /** Seed that matches no document; passed to every query built for the inherited tests. */
+  private static final Query MATCH_NONE = new MatchNoDocsQuery();
+
+  @Override
+  KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
+    return new SeededKnnFloatVectorQuery(field, query, k, queryFilter, MATCH_NONE);
+  }
+
+  @Override
+  AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
+    return new ThrowingKnnVectorQuery(field, vec, k, query, MATCH_NONE);
+  }
+
+  @Override
+  float[] randomVector(int dim) {
+    return TestVectorUtil.randomVector(dim);
+  }
+
+  @Override
+  Field getKnnVectorField(
+      String name, float[] vector, VectorSimilarityFunction similarityFunction) {
+    return new KnnFloatVectorField(name, vector, similarityFunction);
+  }
+
+  @Override
+  Field getKnnVectorField(String name, float[] vector) {
+    return new KnnFloatVectorField(name, vector);
+  }
+
+  /**
+   * Tests seeded searches over random vectors with seed queries that match all documents, a few
+   * documents and no documents. Uses RandomIndexWriter.
+   */
+  public void testRandomWithSeed() throws IOException {
+    int numDocs = 1000;
+    int dimension = atLeast(5);
+    int numIters = atLeast(10);
+    int numDocsWithVector = 0;
+    try (Directory d = newDirectoryForTest()) {
+      // Always use the default kNN format to have predictable behavior around when it hits
+      // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the
+      // kNN format implementation.
+      IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
+      try (RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc)) {
+        for (int i = 0; i < numDocs; i++) {
+          Document doc = new Document();
+          if (random().nextBoolean()) {
+            // Randomly skip some vectors to test the mapping from docid to ordinals
+            doc.add(getKnnVectorField("field", randomVector(dimension)));
+            numDocsWithVector += 1;
+          }
+          doc.add(new NumericDocValuesField("tag", i));
+          doc.add(new IntPoint("tag", i));
+          w.addDocument(doc);
+        }
+        w.forceMerge(1);
+      }
+
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        IndexSearcher searcher = newSearcher(reader);
+        for (int i = 0; i < numIters; i++) {
+          int k = random().nextInt(80) + 1;
+          int n = random().nextInt(100) + 1;
+          // we may get fewer results than requested if there are deletions, but this test doesn't
+          // check that
+          assert reader.hasDeletions() == false;
+          // Only documents that carry a vector can be returned, whatever the seed matches
+          int expected = Math.min(Math.min(n, k), numDocsWithVector);
+
+          // All documents as seeds
+          Query seed1 = new MatchAllDocsQuery();
+          Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
+          AbstractKnnVectorQuery query =
+              new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, filter, seed1);
+          assertHits(expected, searcher.search(query, n));
+
+          // Restrictive seed query -- 6 documents
+          Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
+          query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed2);
+          assertHits(expected, searcher.search(query, n));
+
+          // No seed documents -- falls back on full approx search
+          Query seed3 = new MatchNoDocsQuery();
+          query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed3);
+          assertHits(expected, searcher.search(query, n));
+        }
+      }
+    }
+  }
+
+  /**
+   * Asserts that {@code results} contains exactly {@code expected} hits, that the total hit count
+   * is at least that number and that the hits are ordered by descending score.
+   */
+  private static void assertHits(int expected, TopDocs results) {
+    assertEquals(expected, results.scoreDocs.length);
+    assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+    float last = Float.MAX_VALUE;
+    for (ScoreDoc scoreDoc : results.scoreDocs) {
+      assertTrue(scoreDoc.score <= last);
+      last = scoreDoc.score;
+    }
+  }
+
+  /** Seeded query whose exact search always throws {@link UnsupportedOperationException}. */
+  private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery {
+
+    public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
+      super(field, target, k, filter, seed);
+    }
+
+    @Override
+    protected TopDocs exactSearch(
+        LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
+      throw new UnsupportedOperationException("exact search is not supported");
+    }
+
+    @Override
+    public String toString(String field) {
+      return null;
+    }
+  }
+}
diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java
index ab140be7113..45cb8b9c88f 100644
--- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java
+++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java
@@ -140,7 +140,6 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
- DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java
index ab2e1462c4c..9c44a2f7856 100644
--- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java
+++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java
@@ -139,7 +139,6 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
- DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {