This commit is contained in:
Benjamin Trent 2024-12-20 11:34:00 -05:00
parent 3cfe678649
commit cb6ab5484e
18 changed files with 790 additions and 394 deletions

View File

@ -21,7 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@ -49,10 +48,6 @@ import org.apache.lucene.util.Bits;
* <li>Otherwise run a kNN search subject to the filter
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*
* <p>When a seed query is provided, this query is executed first to seed the kNN search (subject to
* the same rules provided by the filter). If the seed query fails to identify any documents, it
* falls back on the strategy above.
*/
abstract class AbstractKnnVectorQuery extends Query {
@ -60,21 +55,15 @@ abstract class AbstractKnnVectorQuery extends Query {
protected final String field;
protected final int k;
private final Query filter;
private final Query seed;
protected final Query filter;
public AbstractKnnVectorQuery(String field, int k, Query filter) {
this(field, k, filter, null);
}
public AbstractKnnVectorQuery(String field, int k, Query filter, Query seed) {
this.field = Objects.requireNonNull(field, "field");
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
}
this.filter = filter;
this.seed = seed;
}
@Override
@ -94,21 +83,6 @@ abstract class AbstractKnnVectorQuery extends Query {
filterWeight = null;
}
final Weight seedWeight;
if (seed != null) {
BooleanQuery.Builder booleanSeedQueryBuilder =
new BooleanQuery.Builder()
.add(seed, BooleanClause.Occur.MUST)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
if (filter != null) {
booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
}
Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
} else {
seedWeight = null;
}
TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
@ -116,7 +90,7 @@ abstract class AbstractKnnVectorQuery extends Query {
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> searchLeaf(context, filterWeight, seedWeight, knnCollectorManager));
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
}
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
@ -131,11 +105,9 @@ abstract class AbstractKnnVectorQuery extends Query {
private TopDocs searchLeaf(
LeafReaderContext ctx,
Weight filterWeight,
Weight seedWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
TopDocs results =
getLeafResults(ctx, filterWeight, seedWeight, timeLimitingKnnCollectorManager);
TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
@ -147,19 +119,13 @@ abstract class AbstractKnnVectorQuery extends Query {
private TopDocs getLeafResults(
LeafReaderContext ctx,
Weight filterWeight,
Weight seedWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
final LeafReader reader = ctx.reader();
final Bits liveDocs = reader.getLiveDocs();
if (filterWeight == null) {
return approximateSearch(
ctx,
liveDocs,
executeSeedQuery(ctx, seedWeight),
Integer.MAX_VALUE,
timeLimitingKnnCollectorManager);
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
}
Scorer scorer = filterWeight.scorer(ctx);
@ -179,13 +145,7 @@ abstract class AbstractKnnVectorQuery extends Query {
// Perform the approximate kNN search
// We pass cost + 1 here to account for the edge case when we explore exactly cost vectors
TopDocs results =
approximateSearch(
ctx,
acceptDocs,
executeSeedQuery(ctx, seedWeight),
cost + 1,
timeLimitingKnnCollectorManager);
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
if (results.totalHits.relation() == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
@ -196,49 +156,6 @@ abstract class AbstractKnnVectorQuery extends Query {
}
}
private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeight)
throws IOException {
if (seedWeight == null) return null;
// Execute the seed query
TopScoreDocCollector seedCollector =
new TopScoreDocCollectorManager(
k /* numHits */,
null /* after */,
Integer.MAX_VALUE /* totalHitsThreshold */,
false /* supportsConcurrency */)
.newCollector();
final LeafReader leafReader = ctx.reader();
final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx);
if (leafCollector != null) {
try {
BulkScorer scorer = seedWeight.bulkScorer(ctx);
if (scorer != null) {
scorer.score(
leafCollector,
leafReader.getLiveDocs(),
0 /* min */,
DocIdSetIterator.NO_MORE_DOCS /* max */);
}
leafCollector.finish();
} catch (
@SuppressWarnings("unused")
CollectionTerminatedException e) {
}
}
TopDocs seedTopDocs = seedCollector.topDocs();
return convertDocIdsToVectorOrdinals(leafReader, new TopDocsDISI(seedTopDocs));
}
/**
* Returns a new iterator that maps the provided docIds to the vector ordinals.
*
* @lucene.internal
* @lucene.experimental
*/
protected abstract DocIdSetIterator convertDocIdsToVectorOrdinals(
LeafReader reader, DocIdSetIterator docIds) throws IOException;
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
@ -264,7 +181,6 @@ abstract class AbstractKnnVectorQuery extends Query {
protected abstract TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;
@ -386,15 +302,12 @@ abstract class AbstractKnnVectorQuery extends Query {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o;
return k == that.k
&& Objects.equals(field, that.field)
&& Objects.equals(filter, that.filter)
&& Objects.equals(seed, that.seed);
return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter);
}
@Override
public int hashCode() {
return Objects.hash(field, k, filter, seed);
return Objects.hash(field, k, filter);
}
/**
@ -419,13 +332,6 @@ abstract class AbstractKnnVectorQuery extends Query {
return filter;
}
/**
* @return the query that seeds the kNN search.
*/
public Query getSeed() {
return seed;
}
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
@ -585,44 +491,4 @@ abstract class AbstractKnnVectorQuery extends Query {
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
}
}
private static class TopDocsDISI extends DocIdSetIterator {
private final List<Integer> sortedDocIdList;
private int idx = -1;
public TopDocsDISI(TopDocs topDocs) {
sortedDocIdList = new ArrayList<Integer>(topDocs.scoreDocs.length);
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
sortedDocIdList.add(topDocs.scoreDocs[i].doc);
}
Collections.sort(sortedDocIdList);
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
@Override
public long cost() {
return sortedDocIdList.size();
}
@Override
public int docID() {
if (idx == -1) {
return -1;
} else if (idx >= sortedDocIdList.size()) {
return DocIdSetIterator.NO_MORE_DOCS;
} else {
return sortedDocIdList.get(idx);
}
}
@Override
public int nextDoc() throws IOException {
idx += 1;
return docID();
}
}
}

View File

@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
private final byte[] target;
protected final byte[] target;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@ -72,22 +72,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
this(field, target, k, filter, null);
}
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
* given field. <code>target</code> vector.
*
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @param filter a filter applied before the vector search
* @param seed a query that is executed to seed the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
super(field, k, filter, seed);
super(field, k, filter);
this.target = Objects.requireNonNull(target, "target");
}
@ -95,14 +80,10 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
if (seedDocs != null) {
knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs);
}
LeafReader reader = context.reader();
ByteVectorValues byteVectorValues = reader.getByteVectorValues(field);
if (byteVectorValues == null) {
@ -159,18 +140,4 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
public byte[] getTargetCopy() {
return ArrayUtil.copyArray(target);
}
/**
* Returns a new iterator that maps the provided docIds to the vector ordinals.
*
* <p>This method assumes that all docIds have corresponding ordinals.
*
* @lucene.internal
* @lucene.experimental
*/
@Override
protected DocIdSetIterator convertDocIdsToVectorOrdinals(
LeafReader reader, DocIdSetIterator docIds) throws IOException {
return reader.getByteVectorValues(field).convertDocIdsToVectorOrdinals(docIds);
}
}

View File

@ -86,26 +86,13 @@ public interface KnnCollector {
*/
TopDocs topDocs();
/**
* This method returns a {@link DocIdSetIterator} over entry points that seed the KNN search, or
* {@code null} (default) to perform a full KNN search (without seeds).
*
* <p>Note that the entry points should represent ordinals, rather than true document IDs.
*
* @return the seed entry points or {@code null}.
* @lucene.experimental
*/
default DocIdSetIterator getSeedEntryPoints() {
return null;
}
/**
* KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend
* the object with new behaviors.
*
* @lucene.experimental
*/
public abstract static class Decorator implements KnnCollector {
abstract class Decorator implements KnnCollector {
private KnnCollector collector;
public Decorator(KnnCollector collector) {
@ -151,29 +138,5 @@ public interface KnnCollector {
public TopDocs topDocs() {
return collector.topDocs();
}
@Override
public DocIdSetIterator getSeedEntryPoints() {
return collector.getSeedEntryPoints();
}
}
/**
* KnnCollector.Seeded is a KnnCollector decorator that replaces the seedEntryPoints.
*
* @lucene.experimental
*/
public static class Seeded extends Decorator {
private DocIdSetIterator seedEntryPoints;
public Seeded(KnnCollector collector, DocIdSetIterator seedEntryPoints) {
super(collector);
this.seedEntryPoints = seedEntryPoints;
}
@Override
public DocIdSetIterator getSeedEntryPoints() {
return seedEntryPoints;
}
}
}

View File

@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
private final float[] target;
protected final float[] target;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@ -73,22 +73,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
this(field, target, k, filter, null);
}
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
* given field. <code>target</code> vector.
*
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @param filter a filter applied before the vector search
* @param seed a query that is executed to seed the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
super(field, k, filter, seed);
super(field, k, filter);
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}
@ -96,14 +81,10 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
if (seedDocs != null) {
knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs);
}
LeafReader reader = context.reader();
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
if (floatVectorValues == null) {
@ -162,18 +143,4 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
public float[] getTargetCopy() {
return ArrayUtil.copyArray(target);
}
/**
* Returns a new iterator that maps the provided docIds to the vector ordinals.
*
* <p>This method assumes that all docIds have corresponding ordinals.
*
* @lucene.internal
* @lucene.experimental
*/
@Override
protected DocIdSetIterator convertDocIdsToVectorOrdinals(
LeafReader reader, DocIdSetIterator docIds) throws IOException {
return reader.getFloatVectorValues(field).convertDocIdsToVectorOrdinals(docIds);
}
}

View File

@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.SeededKnnCollectorManager;
/**
* This is a version of knn byte vector query that provides a query seed to initiate the vector
* search. NOTE: The underlying format is free to ignore the provided seed
*
* @lucene.experimental
*/
public class SeededKnnByteVectorQuery extends KnnByteVectorQuery {
private final Query seed;
private final Weight seedWeight;
/**
* Construct a new SeededKnnFloatVectorQuery instance
*
* @param field knn byte vector field to query
* @param target the query vector
* @param k number of neighbors to return
* @param filter a filter on the neighbors to return
* @param seed a query seed to initiate the vector format search
*/
public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
super(field, target, k, filter);
this.seed = Objects.requireNonNull(seed);
this.seedWeight = null;
}
SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) {
super(field, target, k, filter);
this.seed = null;
this.seedWeight = Objects.requireNonNull(seedWeight);
}
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (seedWeight != null) {
return super.rewrite(indexSearcher);
}
BooleanQuery.Builder booleanSeedQueryBuilder =
new BooleanQuery.Builder()
.add(seed, BooleanClause.Occur.MUST)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
if (filter != null) {
booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
}
Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
SeededKnnByteVectorQuery rewritten =
new SeededKnnByteVectorQuery(field, target, k, filter, seedWeight);
return rewritten.rewrite(indexSearcher);
}
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
if (seedWeight == null) {
throw new UnsupportedOperationException("must be rewritten before constructing manager");
}
return new SeededKnnCollectorManager(
super.getKnnCollectorManager(k, searcher),
seedWeight,
k,
leaf -> leaf.getFloatVectorValues(field));
}
}

View File

@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.SeededKnnCollectorManager;
/**
* This is a version of knn float vector query that provides a query seed to initiate the vector
* search. NOTE: The underlying format is free to ignore the provided seed
*
* @lucene.experimental
*/
public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery {
private final Query seed;
private final Weight seedWeight;
/**
* Construct a new SeededKnnFloatVectorQuery instance
*
* @param field knn float vector field to query
* @param target the query vector
* @param k number of neighbors to return
* @param filter a filter on the neighbors to return
* @param seed a query seed to initiate the vector format search
*/
public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
super(field, target, k, filter);
this.seed = Objects.requireNonNull(seed);
this.seedWeight = null;
}
SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) {
super(field, target, k, filter);
this.seed = null;
this.seedWeight = Objects.requireNonNull(seedWeight);
}
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (seedWeight != null) {
return super.rewrite(indexSearcher);
}
BooleanQuery.Builder booleanSeedQueryBuilder =
new BooleanQuery.Builder()
.add(seed, BooleanClause.Occur.MUST)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
if (filter != null) {
booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
}
Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
SeededKnnFloatVectorQuery rewritten =
new SeededKnnFloatVectorQuery(field, target, k, filter, seedWeight);
return rewritten.rewrite(indexSearcher);
}
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
if (seedWeight == null) {
throw new UnsupportedOperationException("must be rewritten before constructing manager");
}
return new SeededKnnCollectorManager(
super.getKnnCollectorManager(k, searcher),
seedWeight,
k,
leaf -> leaf.getFloatVectorValues(field));
}
}

View File

@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.knn;
import org.apache.lucene.search.DocIdSetIterator;
/** Provides entry points for the kNN search */
public interface EntryPointProvider {
/** Iterator of valid entry points for the kNN search */
DocIdSetIterator entryPoints();
}

View File

@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.knn;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
public class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider {
private final DocIdSetIterator entryPoints;
public SeededKnnCollector(KnnCollector collector, DocIdSetIterator entryPoints) {
super(collector);
this.entryPoints = entryPoints;
}
@Override
public DocIdSetIterator entryPoints() {
return entryPoints;
}
}

View File

@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.knn;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.IOFunction;
/** A {@link KnnCollectorManager} that collects results with a timeout. */
public class SeededKnnCollectorManager implements KnnCollectorManager {
private final KnnCollectorManager delegate;
private final Weight seedWeight;
private final int k;
private final IOFunction<LeafReader, KnnVectorValues> vectorValuesSupplier;
public SeededKnnCollectorManager(
KnnCollectorManager delegate,
Weight seedWeight,
int k,
IOFunction<LeafReader, KnnVectorValues> vectorValuesSupplier) {
this.delegate = delegate;
this.seedWeight = seedWeight;
this.k = k;
this.vectorValuesSupplier = vectorValuesSupplier;
}
@Override
public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException {
// Execute the seed query
TopScoreDocCollector seedCollector =
new TopScoreDocCollectorManager(
k /* numHits */, null /* after */, Integer.MAX_VALUE /* totalHitsThreshold */)
.newCollector();
final LeafReader leafReader = ctx.reader();
final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx);
if (leafCollector != null) {
try {
BulkScorer scorer = seedWeight.bulkScorer(ctx);
if (scorer != null) {
scorer.score(
leafCollector,
leafReader.getLiveDocs(),
0 /* min */,
DocIdSetIterator.NO_MORE_DOCS /* max */);
}
leafCollector.finish();
} catch (
@SuppressWarnings("unused")
CollectionTerminatedException e) {
}
}
TopDocs seedTopDocs = seedCollector.topDocs();
KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader);
if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) {
return delegate.newCollector(visitedLimit, ctx);
}
KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator();
DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs));
return new SeededKnnCollector(delegate.newCollector(visitedLimit, ctx), seedDocs);
}
public static class MappedDISI extends DocIdSetIterator {
KnnVectorValues.DocIndexIterator indexedDISI;
DocIdSetIterator sourceDISI;
public MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) {
this.indexedDISI = indexedDISI;
this.sourceDISI = sourceDISI;
}
/**
* Advances the source iterator to the first document number that is greater than or equal to
* the provided target and returns the corresponding index.
*/
@Override
public int advance(int target) throws IOException {
int newTarget = sourceDISI.advance(target);
if (newTarget != NO_MORE_DOCS) {
indexedDISI.advance(newTarget);
}
return docID();
}
@Override
public long cost() {
return this.sourceDISI.cost();
}
@Override
public int docID() {
if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) {
return NO_MORE_DOCS;
}
return indexedDISI.index();
}
/** Advances to the next document in the source iterator and returns the corresponding index. */
@Override
public int nextDoc() throws IOException {
int newTarget = sourceDISI.nextDoc();
if (newTarget != NO_MORE_DOCS) {
indexedDISI.advance(newTarget);
}
return docID();
}
}
private static class TopDocsDISI extends DocIdSetIterator {
private final int[] sortedDocIds;
private int idx = -1;
private TopDocsDISI(TopDocs topDocs) {
sortedDocIds = new int[topDocs.scoreDocs.length];
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
sortedDocIds[i] = topDocs.scoreDocs[i].doc;
}
Arrays.sort(sortedDocIds);
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
@Override
public long cost() {
return sortedDocIds.length;
}
@Override
public int docID() {
if (idx == -1) {
return -1;
} else if (idx >= sortedDocIds.length) {
return DocIdSetIterator.NO_MORE_DOCS;
} else {
return sortedDocIds[idx];
}
}
@Override
public int nextDoc() {
idx += 1;
return docID();
}
}
}

View File

@ -20,10 +20,10 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.knn.EntryPointProvider;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
@ -67,25 +67,23 @@ public class HnswGraphSearcher {
public static void search(
RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds)
throws IOException {
ArrayList<Integer> entryPointOrdInts = null;
DocIdSetIterator entryPoints = knnCollector.getSeedEntryPoints();
if (entryPoints != null) {
entryPointOrdInts = new ArrayList<Integer>();
int entryPointOrdInt;
while ((entryPointOrdInt = entryPoints.nextDoc()) != NO_MORE_DOCS) {
entryPointOrdInts.add(entryPointOrdInt);
}
}
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
if (entryPointOrdInts == null || entryPointOrdInts.isEmpty()) {
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
} else {
int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray();
final int[] entryPoints;
if (knnCollector instanceof EntryPointProvider epp) {
DocIdSetIterator eps = epp.entryPoints();
entryPoints = new int[(int) eps.cost()];
int idx = 0;
int entryPointOrdInt;
while ((entryPointOrdInt = eps.nextDoc()) != NO_MORE_DOCS) {
entryPoints[idx++] = entryPointOrdInt;
}
// We use provided entry point ordinals to search the complete graph (level 0)
graphSearcher.searchLevel(
knnCollector, scorer, 0 /* level */, entryPointOrdIntsArr, graph, acceptOrds);
knnCollector, scorer, 0 /* level */, entryPoints, graph, acceptOrds);
} else {
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
}
}

View File

@ -28,6 +28,7 @@ import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SeededKnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
@ -108,7 +109,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
String s = "";
for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
@ -131,7 +132,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
String s = "";
for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
@ -154,7 +155,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
String s = "";
@ -177,7 +178,7 @@ public class TestManyKnnDocs extends LuceneTestCase {
vector[1] = 1;
TopDocs docs =
searcher.search(
new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
assertEquals(5, docs.scoreDocs.length);
Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
String s = "";

View File

@ -66,15 +66,9 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
abstract AbstractKnnVectorQuery getKnnVectorQuery(
String field, float[] query, int k, Query queryFilter);
abstract AbstractKnnVectorQuery getKnnVectorQuery(
String field, float[] query, int k, Query queryFilter, Query seedQuery);
abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery(
String field, float[] query, int k, Query queryFilter);
abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery(
String field, float[] query, int k, Query queryFilter, Query seedQuery);
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) {
return getKnnVectorQuery(field, query, k, null);
}
@ -613,91 +607,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
}
/** Tests with random vectors and a random seed. Uses RandomIndexWriter. */
public void testRandomWithSeed() throws IOException {
int numDocs = 1000;
int dimension = atLeast(5);
int numIters = atLeast(10);
int numDocsWithVector = 0;
try (Directory d = newDirectoryForTest()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
// format
// implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
if (random().nextBoolean()) {
// Randomly skip some vectors to test the mapping from docid to ordinals
doc.add(getKnnVectorField("field", randomVector(dimension)));
numDocsWithVector += 1;
}
doc.add(new NumericDocValuesField("tag", i));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.forceMerge(1);
w.close();
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
int k = random().nextInt(80) + 1;
int n = random().nextInt(100) + 1;
// we may get fewer results than requested if there are deletions, but this test doesn't
// check that
assert reader.hasDeletions() == false;
// All documents as seeds
Query seed1 = new MatchAllDocsQuery();
Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
AbstractKnnVectorQuery query =
getKnnVectorQuery("field", randomVector(dimension), k, filter, seed1);
TopDocs results = searcher.search(query, n);
int expected = Math.min(Math.min(n, k), numDocsWithVector);
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
float last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
// Restrictive seed query -- 6 documents
Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed2);
results = searcher.search(query, n);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
// No seed documents -- falls back on full approx search
Query seed3 = new MatchNoDocsQuery();
query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed3);
results = searcher.search(query, n);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
}
}
}
}
/** Tests filtering when all vectors have the same score. */
public void testFilterWithSameScore() throws IOException {
int numDocs = 100;

View File

@ -34,21 +34,9 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
}
@Override
AbstractKnnVectorQuery getKnnVectorQuery(
String field, float[] query, int k, Query queryFilter, Query seedQuery) {
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, seedQuery);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, null);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(
String field, float[] vec, int k, Query query, Query seedQuery) {
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, seedQuery);
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
}
@Override
@ -73,7 +61,7 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}
private static byte[] floatToBytes(float[] query) {
static byte[] floatToBytes(float[] query) {
byte[] bytes = new byte[query.length];
for (int i = 0; i < query.length; i++) {
assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0
@ -121,10 +109,10 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
}
}
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
super(field, target, k, filter, seed);
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, target, k, filter);
}
@Override

View File

@ -50,21 +50,9 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnFloatVectorQuery(field, query, k, queryFilter);
}
@Override
KnnFloatVectorQuery getKnnVectorQuery(
String field, float[] query, int k, Query queryFilter, Query seedQuery) {
return new KnnFloatVectorQuery(field, query, k, queryFilter, seedQuery);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, vec, k, query, null);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(
String field, float[] vec, int k, Query query, Query seedQuery) {
return new ThrowingKnnVectorQuery(field, vec, k, query, seedQuery);
return new ThrowingKnnVectorQuery(field, vec, k, query);
}
@Override
@ -271,10 +259,10 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
}
}
private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {
static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {
public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
super(field, target, k, filter, seed);
public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
super(field, target, k, filter);
}
@Override

View File

@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.TestVectorUtil;
public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
private static final Query MATCH_NONE = new MatchNoDocsQuery();
@Override
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
return new SeededKnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, MATCH_NONE);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, MATCH_NONE);
}
@Override
float[] randomVector(int dim) {
byte[] b = TestVectorUtil.randomVectorBytes(dim);
float[] v = new float[b.length];
int vi = 0;
for (int i = 0; i < v.length; i++) {
v[vi++] = b[i];
}
return v;
}
@Override
Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction);
}
@Override
Field getKnnVectorField(String name, float[] vector) {
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}
/** Tests with random vectors and a random seed. Uses RandomIndexWriter. */
public void testRandomWithSeed() throws IOException {
int numDocs = 1000;
int dimension = atLeast(5);
int numIters = atLeast(10);
int numDocsWithVector = 0;
try (Directory d = newDirectoryForTest()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
// format
// implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
if (random().nextBoolean()) {
// Randomly skip some vectors to test the mapping from docid to ordinals
doc.add(getKnnVectorField("field", randomVector(dimension)));
numDocsWithVector += 1;
}
doc.add(new NumericDocValuesField("tag", i));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.forceMerge(1);
w.close();
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
int k = random().nextInt(80) + 1;
int n = random().nextInt(100) + 1;
// we may get fewer results than requested if there are deletions, but this test doesn't
// check that
assert reader.hasDeletions() == false;
// All documents as seeds
Query seed1 = new MatchAllDocsQuery();
Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
SeededKnnByteVectorQuery query =
new SeededKnnByteVectorQuery(
"field", floatToBytes(randomVector(dimension)), k, filter, seed1);
TopDocs results = searcher.search(query, n);
int expected = Math.min(Math.min(n, k), numDocsWithVector);
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
float last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
// Restrictive seed query -- 6 documents
Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
query =
new SeededKnnByteVectorQuery(
"field", floatToBytes(randomVector(dimension)), k, null, seed2);
results = searcher.search(query, n);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
// No seed documents -- falls back on full approx search
Query seed3 = new MatchNoDocsQuery();
query =
new SeededKnnByteVectorQuery(
"field", floatToBytes(randomVector(dimension)), k, null, seed3);
results = searcher.search(query, n);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
}
}
}
}
private static class ThrowingKnnVectorQuery extends SeededKnnByteVectorQuery {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
super(field, target, k, filter, seed);
}
@Override
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}
@Override
public String toString(String field) {
return null;
}
}
}

View File

@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.TestVectorUtil;
public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
private static final Query MATCH_NONE = new MatchNoDocsQuery();
@Override
KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
return new SeededKnnFloatVectorQuery(field, query, k, queryFilter, MATCH_NONE);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, vec, k, query, MATCH_NONE);
}
@Override
float[] randomVector(int dim) {
return TestVectorUtil.randomVector(dim);
}
@Override
Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnFloatVectorField(name, vector, similarityFunction);
}
@Override
Field getKnnVectorField(String name, float[] vector) {
return new KnnFloatVectorField(name, vector);
}
/** Tests with random vectors and a random seed. Uses RandomIndexWriter. */
public void testRandomWithSeed() throws IOException {
int numDocs = 1000;
int dimension = atLeast(5);
int numIters = atLeast(10);
int numDocsWithVector = 0;
try (Directory d = newDirectoryForTest()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
// format
// implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
if (random().nextBoolean()) {
// Randomly skip some vectors to test the mapping from docid to ordinals
doc.add(getKnnVectorField("field", randomVector(dimension)));
numDocsWithVector += 1;
}
doc.add(new NumericDocValuesField("tag", i));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.forceMerge(1);
w.close();
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
int k = random().nextInt(80) + 1;
int n = random().nextInt(100) + 1;
// we may get fewer results than requested if there are deletions, but this test doesn't
// check that
assert reader.hasDeletions() == false;
// All documents as seeds
Query seed1 = new MatchAllDocsQuery();
Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
AbstractKnnVectorQuery query =
new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, filter, seed1);
TopDocs results = searcher.search(query, n);
int expected = Math.min(Math.min(n, k), numDocsWithVector);
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
float last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
// Restrictive seed query -- 6 documents
Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed2);
results = searcher.search(query, n);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
// No seed documents -- falls back on full approx search
Query seed3 = new MatchNoDocsQuery();
query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed3);
results = searcher.search(query, n);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
// verify the results are in descending score order
last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
}
}
}
}
private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery {
public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
super(field, target, k, filter, seed);
}
@Override
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}
@Override
public String toString(String field) {
return null;
}
}
}

View File

@ -140,7 +140,6 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {

View File

@ -139,7 +139,6 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
DocIdSetIterator seedDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {