Speed up concurrent multi-segment HNSW graph search (#12962)

Speed up concurrent multi-segment HNSW graph search by exchanging
the global top candidates collected so far across segments. These global top
candidates set the minimum threshold that new candidates need to pass
to be considered. This allows earlier stopping for segments that don't have
good candidates.
This commit is contained in:
Mayya Sharipova 2024-02-06 09:16:06 -05:00 committed by GitHub
parent 635d09001a
commit d095ed02a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 901 additions and 22 deletions

View File

@ -244,6 +244,9 @@ Optimizations
* GITHUB#13036 Optimize counts on two clause term disjunctions. (Adrien Grand, Johannes Fredén)
* GITHUB#12962: Speed up concurrent multi-segment HNSW graph search (Mayya Sharipova, Tom Veasey)
Bug Fixes
---------------------
* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)

View File

@ -42,6 +42,7 @@ module org.apache.lucene.core {
exports org.apache.lucene.search;
exports org.apache.lucene.search.comparators;
exports org.apache.lucene.search.similarities;
exports org.apache.lucene.search.knn;
exports org.apache.lucene.store;
exports org.apache.lucene.util;
exports org.apache.lucene.util.automaton;

View File

@ -23,7 +23,7 @@ package org.apache.lucene.search;
*/
public abstract class AbstractKnnCollector implements KnnCollector {
private long visitedCount;
protected long visitedCount;
private final long visitLimit;
private final int k;

View File

@ -29,6 +29,8 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
@ -79,11 +81,12 @@ abstract class AbstractKnnVectorQuery extends Query {
filterWeight = null;
}
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> searchLeaf(context, filterWeight));
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
}
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
@ -95,8 +98,10 @@ abstract class AbstractKnnVectorQuery extends Query {
return createRewrittenQuery(reader, topK);
}
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight);
private TopDocs searchLeaf(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
@ -105,12 +110,14 @@ abstract class AbstractKnnVectorQuery extends Query {
return results;
}
private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
private TopDocs getLeafResults(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
}
Scorer scorer = filterWeight.scorer(ctx);
@ -128,7 +135,7 @@ abstract class AbstractKnnVectorQuery extends Query {
}
// Perform the approximate kNN search
TopDocs results = approximateSearch(ctx, acceptDocs, cost);
TopDocs results = approximateSearch(ctx, acceptDocs, cost, knnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
@ -155,8 +162,16 @@ abstract class AbstractKnnVectorQuery extends Query {
}
}
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new TopKnnCollectorManager(k, searcher);
}
protected abstract TopDocs approximateSearch(
LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException;
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;
abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
throws IOException;

View File

@ -24,6 +24,7 @@ import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
@ -75,10 +76,23 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
}
@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
}
if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
return results != null ? results : NO_RESULTS;
}

View File

@ -24,6 +24,7 @@ import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
@ -76,10 +77,23 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
}
@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
}
if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
return results != null ? results : NO_RESULTS;
}

View File

@ -25,9 +25,9 @@ import org.apache.lucene.util.hnsw.NeighborQueue;
*
* @lucene.experimental
*/
public final class TopKnnCollector extends AbstractKnnCollector {
public class TopKnnCollector extends AbstractKnnCollector {
private final NeighborQueue queue;
protected final NeighborQueue queue;
/**
* @param k the number of neighbors to collect

View File

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.knn;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.KnnCollector;
/**
* Factory for {@link KnnCollector} instances, one per searched leaf. Useful to create
* {@link KnnCollector} instances that share global state across leaves, such as a global queue of
* results collected so far.
*/
public interface KnnCollectorManager {
/**
* Return a new {@link KnnCollector} instance for the given leaf.
*
* <p>NOTE(review): at least one implementation added alongside this interface
* (DiversifyingNearestChildrenKnnCollectorManager) returns {@code null} when the leaf has no
* parent bit set - confirm whether every caller is expected to null-check the result.
*
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @param context the leaf reader context
* @return the collector to use for this leaf
* @throws IOException declared so that implementations may read per-leaf data while building
*     the collector
*/
KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException;
}

View File

@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.knn;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
import org.apache.lucene.util.hnsw.FloatHeap;
/**
 * MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results
 * across segments through a shared global queue.
 *
 * <p>Each instance keeps its own local result queue plus two small local heaps: one with the best
 * similarities used while the leaf is not globally competitive, and one buffering the similarities
 * seen since the last exchange with the shared {@link BlockingFloatHeap}.
 *
 * @lucene.experimental
 */
public final class MultiLeafTopKnnCollector extends TopKnnCollector {

  // greediness of globally non-competitive search: (0,1]
  private static final float DEFAULT_GREEDINESS = 0.9f;
  // interval to synchronize the local and global queues, as a number of visited vectors.
  // Used as a bit mask on visitedCount, so an exchange happens when the low 8 bits are all zero.
  private static final int SYNC_INTERVAL = 0xff; // 255

  // the global queue of the highest similarities collected so far across all segments
  private final BlockingFloatHeap globalSimilarityQueue;
  // the local queue of the highest similarities if we are not competitive globally
  // the size of this queue is defined by greediness
  private final FloatHeap nonCompetitiveQueue;
  private final float greediness;
  // the queue of the local similarities to periodically update with the global queue
  private final FloatHeap updatesQueue;
  // becomes true once the local queue has held k results for the first time
  private boolean kResultsCollected = false;
  // top (lowest) value of the global queue as observed at the last exchange
  private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY;

  /**
   * @param k the number of neighbors to collect
   * @param visitLimit how many vector nodes the results are allowed to visit
   * @param globalSimilarityQueue the heap shared by the collectors of all leaves, holding the
   *     highest similarities collected so far
   */
  public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) {
    super(k, visitLimit);
    this.greediness = DEFAULT_GREEDINESS;
    this.globalSimilarityQueue = globalSimilarityQueue;
    this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k)));
    this.updatesQueue = new FloatHeap(k);
  }

  /**
   * Collects a candidate into the local queue and the local helper heaps and, once k results have
   * been gathered, periodically publishes the buffered similarities to the global queue.
   *
   * @param docId the candidate document
   * @param similarity the candidate's similarity to the query
   * @return true if the local queue, the non-competitive heap, or the cached global minimum was
   *     updated by this call
   */
  @Override
  public boolean collect(int docId, float similarity) {
    boolean localSimUpdated = queue.insertWithOverflow(docId, similarity);
    boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k());
    if (firstKResultsCollected) {
      kResultsCollected = true;
    }
    updatesQueue.offer(similarity);
    boolean globalSimUpdated = nonCompetitiveQueue.offer(similarity);

    if (kResultsCollected) {
      // as we've collected k results, we can start do periodic updates with the global queue
      if (firstKResultsCollected || (visitedCount & SYNC_INTERVAL) == 0) {
        cachedGlobalMinSim = globalSimilarityQueue.offer(drainUpdatesAscending());
        globalSimUpdated = true;
      }
    }
    return localSimUpdated || globalSimUpdated;
  }

  /**
   * Empties {@link #updatesQueue} into a new array sorted in ascending order.
   *
   * <p>{@link BlockingFloatHeap#offer(float[])} requires ascending input and stops at the first
   * value that is not competitive. The raw backing array of a heap is only partially ordered, so
   * handing it over as-is could silently skip competitive similarities; polling the min heap
   * yields the values in the required order and leaves the queue empty for the next interval.
   */
  private float[] drainUpdatesAscending() {
    float[] ascending = new float[updatesQueue.size()];
    for (int i = 0; i < ascending.length; i++) {
      ascending[i] = updatesQueue.poll();
    }
    return ascending;
  }

  /**
   * Returns the similarity a new candidate must reach to be worth considering: negative infinity
   * until k results were collected, afterwards the larger of the local queue's top score and the
   * smaller of (non-competitive heap top, cached global minimum).
   */
  @Override
  public float minCompetitiveSimilarity() {
    if (kResultsCollected == false) {
      return Float.NEGATIVE_INFINITY;
    }
    return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
  }

  /**
   * Drains the local queue into a {@link TopDocs}. The queue is popped from its top while the
   * array is filled from the last slot backwards.
   */
  @Override
  public TopDocs topDocs() {
    assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed";
    ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
    for (int i = 1; i <= scoreDocs.length; i++) {
      scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore());
      queue.pop();
    }
    TotalHits.Relation relation =
        earlyTerminated()
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
    return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
  }

  @Override
  public String toString() {
    return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
  }
}

View File

@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.knn;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
/**
 * A {@link KnnCollectorManager} that hands out {@link TopKnnCollector} instances. When the
 * searched index has more than one leaf, all collectors created by one manager share a single
 * {@link BlockingFloatHeap} that tracks the top scores collected across the leaves.
 */
public class TopKnnCollectorManager implements KnnCollectorManager {

  // the number of docs to collect
  private final int k;
  // shared heap of the top scores collected across all leaves; null for a single-leaf index
  private final BlockingFloatHeap globalScoreQueue;

  /**
   * @param k the number of docs to collect
   * @param indexSearcher the searcher whose reader's leaf count decides whether a shared global
   *     score queue is created
   */
  public TopKnnCollectorManager(int k, IndexSearcher indexSearcher) {
    final int leafCount = indexSearcher.getIndexReader().leaves().size();
    this.k = k;
    if (leafCount > 1) {
      this.globalScoreQueue = new BlockingFloatHeap(k);
    } else {
      this.globalScoreQueue = null;
    }
  }

  /**
   * Creates the collector for one leaf: a plain {@link TopKnnCollector} when no global queue
   * exists, otherwise a {@link MultiLeafTopKnnCollector} bound to the shared queue.
   *
   * @param visitedLimit the maximum number of nodes that the search is allowed to visit
   * @param context the leaf reader context
   */
  @Override
  public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context)
      throws IOException {
    return globalScoreQueue != null
        ? new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue)
        : new TopKnnCollector(k, visitedLimit);
  }
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** Classes related to vector search: knn and vector fields. */
package org.apache.lucene.search.knn;

View File

@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.util.concurrent.locks.ReentrantLock;
/**
 * A blocking bounded min heap that stores floats. The top element is the lowest value of the heap.
 *
 * <p>A primitive priority queue that maintains a partial ordering of its elements such that the
 * least element can always be found in constant time. Implementation is based on {@link
 * org.apache.lucene.util.LongHeap}
 *
 * <p>Every public method runs entirely under a single {@link ReentrantLock}, including the
 * emptiness check of {@link #poll()}.
 *
 * @lucene.internal
 */
public final class BlockingFloatHeap {
  private final int maxSize;
  // 1-based binary heap: slot 0 is unused, the top lives in slot 1
  private final float[] heap;
  private final ReentrantLock lock;
  private int size;

  /**
   * @param maxSize the maximum number of values the heap retains
   */
  public BlockingFloatHeap(int maxSize) {
    this.maxSize = maxSize;
    this.heap = new float[maxSize + 1];
    this.lock = new ReentrantLock();
    this.size = 0;
  }

  /**
   * Inserts a value into this heap.
   *
   * <p>If the number of values would exceed the heap's maxSize, the least value is discarded
   *
   * @param value the value to add
   * @return the new 'top' element in the queue.
   */
  public float offer(float value) {
    lock.lock();
    try {
      if (size < maxSize) {
        push(value);
      } else if (value >= heap[1]) {
        updateTop(value);
      }
      return heap[1];
    } finally {
      lock.unlock();
    }
  }

  /**
   * Inserts array of values into this heap.
   *
   * <p>Values must be sorted in ascending order. The array is walked from its last (largest)
   * element backwards and, once the heap is full, processing stops at the first value that is
   * smaller than the current top, so unsorted input may leave competitive values uninserted.
   *
   * @param values a set of values to insert, must be sorted in ascending order
   * @return the new 'top' element in the queue.
   */
  public float offer(float[] values) {
    lock.lock();
    try {
      for (int i = values.length - 1; i >= 0; i--) {
        if (size < maxSize) {
          push(values[i]);
        } else {
          if (values[i] >= heap[1]) {
            updateTop(values[i]);
          } else {
            break;
          }
        }
      }
      return heap[1];
    } finally {
      lock.unlock();
    }
  }

  /**
   * Removes and returns the head of the heap
   *
   * @return the head of the heap, the smallest value
   * @throws IllegalStateException if the heap is empty
   */
  public float poll() {
    lock.lock();
    try {
      // The emptiness check must happen under the lock: checking size first and locking
      // afterwards lets two threads both pass the check for the last remaining element.
      if (size <= 0) {
        throw new IllegalStateException("The heap is empty");
      }
      float result = heap[1]; // save first value
      heap[1] = heap[size]; // move last to first
      size--;
      downHeap(1); // adjust heap
      return result;
    } finally {
      lock.unlock();
    }
  }

  /**
   * Retrieves, but does not remove, the head of this heap.
   *
   * <p>No emptiness check is made: on an empty heap this returns whatever is currently stored in
   * the top slot.
   *
   * @return the head of the heap, the smallest value
   */
  public float peek() {
    lock.lock();
    try {
      return heap[1];
    } finally {
      lock.unlock();
    }
  }

  /**
   * Returns the number of elements in this heap.
   *
   * @return the number of elements in this heap
   */
  public int size() {
    lock.lock();
    try {
      return size;
    } finally {
      lock.unlock();
    }
  }

  // Appends the element as the last leaf and sifts it up. Callers must hold the lock.
  private void push(float element) {
    size++;
    heap[size] = element;
    upHeap(size);
  }

  // Overwrites the top with the given value and restores heap order. Callers must hold the lock.
  private float updateTop(float value) {
    heap[1] = value;
    downHeap(1);
    return heap[1];
  }

  // Sifts the value at position i down until neither child is smaller than it.
  private void downHeap(int i) {
    float value = heap[i]; // save top value
    int j = i << 1; // find smaller child
    int k = j + 1;
    if (k <= size && heap[k] < heap[j]) {
      j = k;
    }
    while (j <= size && heap[j] < value) {
      heap[i] = heap[j]; // shift up child
      i = j;
      j = i << 1;
      k = j + 1;
      if (k <= size && heap[k] < heap[j]) {
        j = k;
      }
    }
    heap[i] = value; // install saved value
  }

  // Sifts the value at origPos up until its parent is not larger than it.
  private void upHeap(int origPos) {
    int i = origPos;
    float value = heap[i]; // save bottom value
    int j = i >>> 1;
    while (j > 0 && value < heap[j]) {
      heap[i] = heap[j]; // shift parents down
      i = j;
      j = j >>> 1;
    }
    heap[i] = value; // install saved value
  }
}

View File

@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
/**
 * A bounded min heap of primitive floats whose top element is always the lowest stored value.
 *
 * <p>The heap keeps only a partial ordering of its elements, which is enough to find the least
 * element in constant time. Implementation is based on {@link org.apache.lucene.util.LongHeap}
 *
 * <p>This class performs no synchronization of its own.
 *
 * @lucene.internal
 */
public final class FloatHeap {
  private final int maxSize;
  // 1-based storage: slot 0 is never used, the least value sits in slot 1
  private final float[] heap;
  private int size;

  /**
   * @param maxSize the maximum number of values the heap retains
   */
  public FloatHeap(int maxSize) {
    this.maxSize = maxSize;
    this.heap = new float[maxSize + 1];
    this.size = 0;
  }

  /**
   * Offers a value to this heap.
   *
   * <p>While the heap has room the value is simply added. Once it is full, a value lower than the
   * current top is rejected; any other value replaces the top, i.e. the least value is discarded.
   *
   * @param value the value to add
   * @return true if the value was stored, false if the heap was full and the value was lower than
   *     the top value
   */
  public boolean offer(float value) {
    if (size < maxSize) {
      // room left: append as the last leaf and sift up
      size++;
      heap[size] = value;
      upHeap(size);
      return true;
    }
    if (value < heap[1]) {
      return false;
    }
    // full: overwrite the least value and sift the newcomer down
    heap[1] = value;
    downHeap(1);
    return true;
  }

  /**
   * Returns a copy of the stored values in their internal heap layout: the first element is the
   * least value, the remaining elements are only partially ordered.
   */
  public float[] getHeap() {
    final int count = size;
    final float[] snapshot = new float[count];
    System.arraycopy(heap, 1, snapshot, 0, count);
    return snapshot;
  }

  /**
   * Removes and returns the head of the heap
   *
   * @return the head of the heap, the smallest value
   * @throws IllegalStateException if the heap is empty
   */
  public float poll() {
    if (size <= 0) {
      throw new IllegalStateException("The heap is empty");
    }
    final float least = heap[1];
    heap[1] = heap[size]; // the last leaf takes over the root slot
    size--;
    downHeap(1);
    return least;
  }

  /**
   * Retrieves, but does not remove, the head of this heap.
   *
   * <p>No emptiness check is made: the content of the top slot is returned as-is.
   *
   * @return the head of the heap, the smallest value
   */
  public float peek() {
    return heap[1];
  }

  /**
   * Returns the number of elements in this heap.
   *
   * @return the number of elements in this heap
   */
  public int size() {
    return size;
  }

  /** Logically empties the heap by resetting its size; the backing array is not wiped. */
  public void clear() {
    size = 0;
  }

  // Moves the value found at position i down until no child is smaller than it.
  private void downHeap(int i) {
    final float moving = heap[i];
    int hole = i;
    while (true) {
      int child = hole << 1;
      final int sibling = child + 1;
      if (sibling <= size && heap[sibling] < heap[child]) {
        child = sibling; // the right child is the smaller one
      }
      if (child > size || !(heap[child] < moving)) {
        break;
      }
      heap[hole] = heap[child]; // pull the smaller child up
      hole = child;
    }
    heap[hole] = moving;
  }

  // Moves the value found at origPos up while it is smaller than its parent.
  private void upHeap(int origPos) {
    final float moving = heap[origPos];
    int hole = origPos;
    for (int parent = hole >>> 1; parent > 0 && moving < heap[parent]; parent >>>= 1) {
      heap[hole] = heap[parent]; // push the parent down
      hole = parent;
    }
    heap[hole] = moving;
  }
}

View File

@ -547,6 +547,7 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getByteVectorValues("vector");
scanAndRetrieve(leaf, iter);

View File

@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import java.util.concurrent.CountDownLatch;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.SuppressForbidden;
/** Tests for {@link BlockingFloatHeap}: single-threaded heap semantics and concurrent offers. */
public class TestBlockingFloatHeap extends LuceneTestCase {
/**
* Offers four values to a heap bounded at three and checks that the smallest one was discarded
* and that the remaining values are polled in ascending order.
*/
public void testBasicOperations() {
BlockingFloatHeap heap = new BlockingFloatHeap(3);
heap.offer(2);
heap.offer(4);
heap.offer(1);
heap.offer(3);
assertEquals(3, heap.size());
assertEquals(2, heap.peek(), 0);
assertEquals(2, heap.poll(), 0);
assertEquals(3, heap.poll(), 0);
assertEquals(4, heap.poll(), 0);
assertEquals(0, heap.size(), 0);
}
/**
* Fills a heap to exactly its capacity with random floats, then checks that polling yields a
* non-decreasing sequence whose sum matches the sum of the offered values.
*/
public void testBasicOperations2() {
int size = atLeast(10);
BlockingFloatHeap heap = new BlockingFloatHeap(size);
double sum = 0, sum2 = 0;
for (int i = 0; i < size; i++) {
float next = random().nextFloat(100f);
sum += next;
heap.offer(next);
}
float last = Float.NEGATIVE_INFINITY;
for (long i = 0; i < size; i++) {
float next = heap.poll();
assertTrue(next >= last);
last = next;
sum2 += last;
}
assertEquals(sum, sum2, 0.01);
}
/**
* Several threads share a heap of capacity one. Each thread offers a value that never decreases
* and afterwards checks that the value it peeks from the shared heap is at least the value it
* offered itself.
*/
@SuppressForbidden(reason = "Thread sleep")
public void testMultipleThreads() throws Exception {
Thread[] threads = new Thread[randomIntBetween(3, 20)];
// all workers wait on this latch so that they start offering at the same time
final CountDownLatch latch = new CountDownLatch(1);
BlockingFloatHeap globalHeap = new BlockingFloatHeap(1);
for (int i = 0; i < threads.length; i++) {
threads[i] =
new Thread(
() -> {
try {
latch.await();
int numIterations = randomIntBetween(10, 100);
float bottomValue = 0;
while (numIterations-- > 0) {
bottomValue += randomIntBetween(0, 5);
globalHeap.offer(bottomValue);
// give the other threads a chance to offer before peeking
Thread.sleep(randomIntBetween(0, 50));
float globalBottomValue = globalHeap.peek();
assertTrue(globalBottomValue >= bottomValue);
// continue from the shared value so this thread's next offer is not below it
bottomValue = globalBottomValue;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
});
threads[i].start();
}
// release all workers at once, then wait for them to finish
latch.countDown();
for (Thread t : threads) {
t.join();
}
}
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import org.apache.lucene.tests.util.LuceneTestCase;
/** Tests for {@link FloatHeap}: bounded insertion, ascending polling and clearing. */
public class TestFloatHeap extends LuceneTestCase {

  /**
   * Offers four values to a heap bounded at three: the smallest one must be discarded and the
   * rest must be polled in ascending order.
   */
  public void testBasicOperations() {
    FloatHeap floats = new FloatHeap(3);
    for (float offered : new float[] {2, 4, 1, 3}) {
      floats.offer(offered);
    }

    assertEquals(3, floats.size());
    assertEquals(2, floats.peek(), 0);

    for (float expected : new float[] {2, 3, 4}) {
      assertEquals(expected, floats.poll(), 0);
    }
    assertEquals(0, floats.size(), 0);
  }

  /**
   * Fills a heap to exactly its capacity with random floats and verifies that polling returns a
   * non-decreasing sequence whose sum equals the sum of what was offered.
   */
  public void testBasicOperations2() {
    final int capacity = atLeast(10);
    FloatHeap floats = new FloatHeap(capacity);

    double offeredTotal = 0;
    for (int i = 0; i < capacity; i++) {
      float value = random().nextFloat(100f);
      offeredTotal += value;
      floats.offer(value);
    }

    double polledTotal = 0;
    float previous = Float.NEGATIVE_INFINITY;
    for (int remaining = capacity; remaining > 0; remaining--) {
      float current = floats.poll();
      assertTrue(current >= previous);
      polledTotal += current;
      previous = current;
    }
    assertEquals(offeredTotal, polledTotal, 0.01);
  }

  /**
   * Clearing resets the size but leaves the old top readable through peek; the heap must be
   * fully usable again afterwards.
   */
  public void testClear() {
    FloatHeap floats = new FloatHeap(3);
    for (float offered : new float[] {20, 40, 30}) {
      floats.offer(offered);
    }
    assertEquals(3, floats.size());
    assertEquals(20, floats.peek(), 0);

    floats.clear();
    assertEquals(0, floats.size(), 0);
    // the backing array is untouched by clear(), so the stale top is still visible
    assertEquals(20, floats.peek(), 0);

    floats.offer(15);
    floats.offer(35);
    assertEquals(2, floats.size());
    assertEquals(15, floats.peek(), 0);

    for (float expected : new float[] {15, 35}) {
      assertEquals(expected, floats.poll(), 0);
    }
    assertEquals(0, floats.size(), 0);
  }
}

View File

@ -26,6 +26,7 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Query;
@ -33,6 +34,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
@ -123,7 +125,16 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
}
@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
}
@Override
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {

View File

@ -26,6 +26,7 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
@ -33,6 +34,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
@ -123,14 +125,21 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
}
@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
}
@Override
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (collector == null) {
return NO_RESULTS;
}
KnnCollector collector =
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
context.reader().searchNearestVectors(field, query, collector, acceptDocs);
return collector.topDocs();
}

View File

@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.join;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
/**
 * A {@link KnnCollectorManager} that produces {@link DiversifyingNearestChildrenKnnCollector}
 * instances, one per leaf, using the parent bit set of that leaf.
 */
public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollectorManager {

  // the number of docs to collect
  private final int k;
  // filter identifying the parent documents.
  private final BitSetProducer parentsFilter;

  /**
   * Constructor
   *
   * @param k - the number of top k vectors to collect
   * @param parentsFilter Filter identifying the parent documents.
   */
  public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) {
    this.k = k;
    this.parentsFilter = parentsFilter;
  }

  /**
   * Return a new {@link DiversifyingNearestChildrenKnnCollector} instance, or {@code null} when
   * the parents filter yields no bit set for the given leaf.
   *
   * @param visitedLimit the maximum number of nodes that the search is allowed to visit
   * @param context the leaf reader context
   */
  @Override
  public DiversifyingNearestChildrenKnnCollector newCollector(
      int visitedLimit, LeafReaderContext context) throws IOException {
    final BitSet parents = parentsFilter.getBitSet(context);
    return parents == null
        ? null
        : new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parents);
  }
}