mirror of https://github.com/apache/lucene.git
Speed up concurrent multi-segment HNSW graph search (#12962)
Speed up concurrent multi-segment HNSW graph search by exchanging the global top candidates collected so far across segments. These global top candidates set the minimum threshold that new candidates need to pass to be considered. This allows earlier stopping for segments that don't have good candidates.
This commit is contained in:
parent
635d09001a
commit
d095ed02a2
|
@ -244,6 +244,9 @@ Optimizations
|
|||
|
||||
* GITHUB#13036 Optimize counts on two clause term disjunctions. (Adrien Grand, Johannes Fredén)
|
||||
|
||||
* GITHUB#12962: Speed up concurrent multi-segment HNSW graph search (Mayya Sharipova, Tom Veasey)
|
||||
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)
|
||||
|
|
|
@ -42,6 +42,7 @@ module org.apache.lucene.core {
|
|||
exports org.apache.lucene.search;
|
||||
exports org.apache.lucene.search.comparators;
|
||||
exports org.apache.lucene.search.similarities;
|
||||
exports org.apache.lucene.search.knn;
|
||||
exports org.apache.lucene.store;
|
||||
exports org.apache.lucene.util;
|
||||
exports org.apache.lucene.util.automaton;
|
||||
|
|
|
@ -23,7 +23,7 @@ package org.apache.lucene.search;
|
|||
*/
|
||||
public abstract class AbstractKnnCollector implements KnnCollector {
|
||||
|
||||
private long visitedCount;
|
||||
protected long visitedCount;
|
||||
private final long visitLimit;
|
||||
private final int k;
|
||||
|
||||
|
|
|
@ -29,6 +29,8 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.search.knn.TopKnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -79,11 +81,12 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
filterWeight = null;
|
||||
}
|
||||
|
||||
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
|
||||
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
|
||||
List<LeafReaderContext> leafReaderContexts = reader.leaves();
|
||||
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
|
||||
for (LeafReaderContext context : leafReaderContexts) {
|
||||
tasks.add(() -> searchLeaf(context, filterWeight));
|
||||
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
|
||||
}
|
||||
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
|
||||
|
||||
|
@ -95,8 +98,10 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
return createRewrittenQuery(reader, topK);
|
||||
}
|
||||
|
||||
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
|
||||
TopDocs results = getLeafResults(ctx, filterWeight);
|
||||
private TopDocs searchLeaf(
|
||||
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
|
@ -105,12 +110,14 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
return results;
|
||||
}
|
||||
|
||||
private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
|
||||
private TopDocs getLeafResults(
|
||||
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||
int maxDoc = ctx.reader().maxDoc();
|
||||
|
||||
if (filterWeight == null) {
|
||||
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
|
||||
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
|
||||
}
|
||||
|
||||
Scorer scorer = filterWeight.scorer(ctx);
|
||||
|
@ -128,7 +135,7 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
}
|
||||
|
||||
// Perform the approximate kNN search
|
||||
TopDocs results = approximateSearch(ctx, acceptDocs, cost);
|
||||
TopDocs results = approximateSearch(ctx, acceptDocs, cost, knnCollectorManager);
|
||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
||||
return results;
|
||||
} else {
|
||||
|
@ -155,8 +162,16 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
}
|
||||
}
|
||||
|
||||
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
|
||||
return new TopKnnCollectorManager(k, searcher);
|
||||
}
|
||||
|
||||
protected abstract TopDocs approximateSearch(
|
||||
LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException;
|
||||
|
||||
abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
|
||||
throws IOException;
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.document.KnnFloatVectorField;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -75,10 +76,23 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
|
||||
protected TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
TopDocs results =
|
||||
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
// The field does not exist or does not index vectors
|
||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) {
|
||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
|
||||
TopDocs results = knnCollector.topDocs();
|
||||
return results != null ? results : NO_RESULTS;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.document.KnnFloatVectorField;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
@ -76,10 +77,23 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
|
||||
protected TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
TopDocs results =
|
||||
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
// The field does not exist or does not index vectors
|
||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) {
|
||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
|
||||
TopDocs results = knnCollector.topDocs();
|
||||
return results != null ? results : NO_RESULTS;
|
||||
}
|
||||
|
||||
|
|
|
@ -25,9 +25,9 @@ import org.apache.lucene.util.hnsw.NeighborQueue;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class TopKnnCollector extends AbstractKnnCollector {
|
||||
public class TopKnnCollector extends AbstractKnnCollector {
|
||||
|
||||
private final NeighborQueue queue;
|
||||
protected final NeighborQueue queue;
|
||||
|
||||
/**
|
||||
* @param k the number of neighbors to collect
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search.knn;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
|
||||
/**
|
||||
* KnnCollectorManager responsible for creating {@link KnnCollector} instances. Useful to create
|
||||
* {@link KnnCollector} instances that share global state across leaves, such a global queue of
|
||||
* results collected so far.
|
||||
*/
|
||||
public interface KnnCollectorManager {
|
||||
|
||||
/**
|
||||
* Return a new {@link KnnCollector} instance.
|
||||
*
|
||||
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
||||
* @param context the leaf reader context
|
||||
*/
|
||||
KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException;
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search.knn;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopKnnCollector;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
|
||||
import org.apache.lucene.util.hnsw.FloatHeap;
|
||||
|
||||
/**
|
||||
* MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results
|
||||
* across segments through a shared global queue.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class MultiLeafTopKnnCollector extends TopKnnCollector {
|
||||
|
||||
// greediness of globally non-competitive search: (0,1]
|
||||
private static final float DEFAULT_GREEDINESS = 0.9f;
|
||||
// the global queue of the highest similarities collected so far across all segments
|
||||
private final BlockingFloatHeap globalSimilarityQueue;
|
||||
// the local queue of the highest similarities if we are not competitive globally
|
||||
// the size of this queue is defined by greediness
|
||||
private final FloatHeap nonCompetitiveQueue;
|
||||
private final float greediness;
|
||||
// the queue of the local similarities to periodically update with the global queue
|
||||
private final FloatHeap updatesQueue;
|
||||
// interval to synchronize the local and global queues, as a number of visited vectors
|
||||
private final int interval = 0xff; // 255
|
||||
private boolean kResultsCollected = false;
|
||||
private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY;
|
||||
|
||||
/**
|
||||
* @param k the number of neighbors to collect
|
||||
* @param visitLimit how many vector nodes the results are allowed to visit
|
||||
*/
|
||||
public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) {
|
||||
super(k, visitLimit);
|
||||
this.greediness = DEFAULT_GREEDINESS;
|
||||
this.globalSimilarityQueue = globalSimilarityQueue;
|
||||
this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k)));
|
||||
this.updatesQueue = new FloatHeap(k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean collect(int docId, float similarity) {
|
||||
boolean localSimUpdated = queue.insertWithOverflow(docId, similarity);
|
||||
boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k());
|
||||
if (firstKResultsCollected) {
|
||||
kResultsCollected = true;
|
||||
}
|
||||
updatesQueue.offer(similarity);
|
||||
boolean globalSimUpdated = nonCompetitiveQueue.offer(similarity);
|
||||
|
||||
if (kResultsCollected) {
|
||||
// as we've collected k results, we can start do periodic updates with the global queue
|
||||
if (firstKResultsCollected || (visitedCount & interval) == 0) {
|
||||
cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap());
|
||||
updatesQueue.clear();
|
||||
globalSimUpdated = true;
|
||||
}
|
||||
}
|
||||
return localSimUpdated || globalSimUpdated;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float minCompetitiveSimilarity() {
|
||||
if (kResultsCollected == false) {
|
||||
return Float.NEGATIVE_INFINITY;
|
||||
}
|
||||
return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs topDocs() {
|
||||
assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed";
|
||||
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
|
||||
for (int i = 1; i <= scoreDocs.length; i++) {
|
||||
scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore());
|
||||
queue.pop();
|
||||
}
|
||||
TotalHits.Relation relation =
|
||||
earlyTerminated()
|
||||
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|
||||
: TotalHits.Relation.EQUAL_TO;
|
||||
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search.knn;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.TopKnnCollector;
|
||||
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
|
||||
|
||||
/**
|
||||
* TopKnnCollectorManager responsible for creating {@link TopKnnCollector} instances. When
|
||||
* concurrency is supported, the {@link BlockingFloatHeap} is used to track the global top scores
|
||||
* collected across all leaves.
|
||||
*/
|
||||
public class TopKnnCollectorManager implements KnnCollectorManager {
|
||||
|
||||
// the number of docs to collect
|
||||
private final int k;
|
||||
// the global score queue used to track the top scores collected across all leaves
|
||||
private final BlockingFloatHeap globalScoreQueue;
|
||||
|
||||
public TopKnnCollectorManager(int k, IndexSearcher indexSearcher) {
|
||||
boolean isMultiSegments = indexSearcher.getIndexReader().leaves().size() > 1;
|
||||
this.k = k;
|
||||
this.globalScoreQueue = isMultiSegments ? new BlockingFloatHeap(k) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new {@link TopKnnCollector} instance.
|
||||
*
|
||||
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
||||
* @param context the leaf reader context
|
||||
*/
|
||||
@Override
|
||||
public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context)
|
||||
throws IOException {
|
||||
if (globalScoreQueue == null) {
|
||||
return new TopKnnCollector(k, visitedLimit);
|
||||
} else {
|
||||
return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/** Classes related to vector search: knn and vector fields. */
|
||||
package org.apache.lucene.search.knn;
|
|
@ -0,0 +1,192 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
/**
 * A blocking bounded min heap that stores floats. The top element is the lowest value of the heap.
 *
 * <p>A primitive priority queue that maintains a partial ordering of its elements such that the
 * least element can always be found in constant time. Implementation is based on {@link
 * org.apache.lucene.util.LongHeap}
 *
 * <p>Every public method acquires the same {@link ReentrantLock} for its whole duration, so each
 * individual call is atomic with respect to the others.
 *
 * @lucene.internal
 */
public final class BlockingFloatHeap {
  private final int maxSize;
  // 1-based binary heap: slot 0 is unused, slot 1 is the top (smallest value)
  private final float[] heap;
  private final ReentrantLock lock;
  private int size;

  /**
   * @param maxSize the maximum number of values the heap retains
   */
  public BlockingFloatHeap(int maxSize) {
    this.maxSize = maxSize;
    this.heap = new float[maxSize + 1];
    this.lock = new ReentrantLock();
    this.size = 0;
  }

  /**
   * Inserts a value into this heap.
   *
   * <p>If the number of values would exceed the heap's maxSize, the least value is discarded
   *
   * @param value the value to add
   * @return the new 'top' element in the queue.
   */
  public float offer(float value) {
    lock.lock();
    try {
      insertWithOverflow(value);
      return heap[1];
    } finally {
      lock.unlock();
    }
  }

  /**
   * Inserts array of values into this heap, under a single acquisition of the lock.
   *
   * <p>The array does not need to be sorted: every value is compared against the current top, so
   * a competitive value is never skipped because a smaller one happened to come before it. For
   * input sorted in ascending order the outcome is the same as with an early exit at the first
   * non-competitive value.
   *
   * @param values a set of values to insert
   * @return the new 'top' element in the queue.
   */
  public float offer(float[] values) {
    lock.lock();
    try {
      // walk from the end so that ascending input offers its largest values first
      for (int i = values.length - 1; i >= 0; i--) {
        insertWithOverflow(values[i]);
      }
      return heap[1];
    } finally {
      lock.unlock();
    }
  }

  /**
   * Removes and returns the head of the heap
   *
   * @return the head of the heap, the smallest value
   * @throws IllegalStateException if the heap is empty
   */
  public float poll() {
    lock.lock();
    try {
      // the emptiness check must happen under the lock, otherwise two concurrent callers could
      // both pass it while only one value is left
      if (size == 0) {
        throw new IllegalStateException("The heap is empty");
      }
      float result = heap[1]; // save first value
      heap[1] = heap[size]; // move last to first
      size--;
      downHeap(1); // adjust heap
      return result;
    } finally {
      lock.unlock();
    }
  }

  /**
   * Retrieves, but does not remove, the head of this heap.
   *
   * <p>No emptiness check is performed: the content of the top slot is returned as-is.
   *
   * @return the head of the heap, the smallest value
   */
  public float peek() {
    lock.lock();
    try {
      return heap[1];
    } finally {
      lock.unlock();
    }
  }

  /**
   * Returns the number of elements in this heap.
   *
   * @return the number of elements in this heap
   */
  public int size() {
    lock.lock();
    try {
      return size;
    } finally {
      lock.unlock();
    }
  }

  /**
   * Adds {@code value} while there is room; once full, replaces the top when {@code value} is
   * greater than or equal to it and ignores {@code value} otherwise. Caller must hold the lock.
   */
  private void insertWithOverflow(float value) {
    if (size < maxSize) {
      push(value);
    } else if (value >= heap[1]) {
      updateTop(value);
    }
  }

  /** Appends {@code element} as the last leaf and sifts it up. Caller must hold the lock. */
  private void push(float element) {
    size++;
    heap[size] = element;
    upHeap(size);
  }

  /** Overwrites the top with {@code value}, restores heap order and returns the new top. */
  private float updateTop(float value) {
    heap[1] = value;
    downHeap(1);
    return heap[1];
  }

  /** Sifts the value at index {@code i} down until neither child is smaller than it. */
  private void downHeap(int i) {
    float value = heap[i]; // save top value
    int j = i << 1; // find smaller child
    int k = j + 1;
    if (k <= size && heap[k] < heap[j]) {
      j = k;
    }
    while (j <= size && heap[j] < value) {
      heap[i] = heap[j]; // shift up child
      i = j;
      j = i << 1;
      k = j + 1;
      if (k <= size && heap[k] < heap[j]) {
        j = k;
      }
    }
    heap[i] = value; // install saved value
  }

  /** Sifts the value at index {@code origPos} up until its parent is not greater than it. */
  private void upHeap(int origPos) {
    int i = origPos;
    float value = heap[i]; // save bottom value
    int j = i >>> 1;
    while (j > 0 && value < heap[j]) {
      heap[i] = heap[j]; // shift parents down
      i = j;
      j = j >>> 1;
    }
    heap[i] = value; // install saved value
  }
}
|
|
@ -0,0 +1,150 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
/**
 * A bounded min heap that stores floats. The top element is the lowest value of the heap.
 *
 * <p>A primitive priority queue that maintains a partial ordering of its elements such that the
 * least element can always be found in constant time. Implementation is based on {@link
 * org.apache.lucene.util.LongHeap}
 *
 * <p>This class performs no locking of its own.
 *
 * @lucene.internal
 */
public final class FloatHeap {
  private final int maxSize;
  // 1-based binary heap: slot 0 is unused, slot 1 is the top (smallest value)
  private final float[] heap;
  private int size;

  /**
   * @param maxSize the maximum number of values the heap retains; {@code 0} creates a heap that
   *     never retains anything
   * @throws IllegalArgumentException if {@code maxSize} is negative
   */
  public FloatHeap(int maxSize) {
    if (maxSize < 0) {
      throw new IllegalArgumentException("maxSize must be >= 0; got: " + maxSize);
    }
    this.maxSize = maxSize;
    // always keep slot 1 addressable so that peek() cannot run off the array, even when the
    // capacity is zero
    this.heap = new float[Math.max(maxSize, 1) + 1];
    this.size = 0;
  }

  /**
   * Inserts a value into this heap.
   *
   * <p>If the number of values would exceed the heap's maxSize, the least value is discarded
   *
   * @param value the value to add
   * @return {@code true} if the value was retained; {@code false} if the heap is full and the
   *     value is less than the current top value, or if the heap has zero capacity
   */
  public boolean offer(float value) {
    if (maxSize == 0) {
      // nothing can be retained; without this guard the top slot would be overwritten while
      // size() kept reporting an empty heap
      return false;
    }
    if (size >= maxSize) {
      if (value < heap[1]) {
        return false;
      }
      updateTop(value);
      return true;
    }
    push(value);
    return true;
  }

  /**
   * Returns a copy of the values currently held.
   *
   * <p>The copy follows the internal heap layout: the first element is the smallest value, the
   * remaining elements are only partially ordered and are NOT sorted.
   *
   * @return a new array of length {@link #size()}
   */
  public float[] getHeap() {
    float[] result = new float[size];
    System.arraycopy(this.heap, 1, result, 0, size);
    return result;
  }

  /**
   * Removes and returns the head of the heap
   *
   * @return the head of the heap, the smallest value
   * @throws IllegalStateException if the heap is empty
   */
  public float poll() {
    if (size == 0) {
      throw new IllegalStateException("The heap is empty");
    }
    float result = heap[1]; // save first value
    heap[1] = heap[size]; // move last to first
    size--;
    downHeap(1); // adjust heap
    return result;
  }

  /**
   * Retrieves, but does not remove, the head of this heap.
   *
   * <p>No emptiness check is performed: the content of the top slot is returned as-is.
   *
   * @return the head of the heap, the smallest value
   */
  public float peek() {
    return heap[1];
  }

  /**
   * Returns the number of elements in this heap.
   *
   * @return the number of elements in this heap
   */
  public int size() {
    return size;
  }

  /** Empties the heap by resetting its size; the backing array is left untouched. */
  public void clear() {
    size = 0;
  }

  /** Appends {@code element} as the last leaf and sifts it up. */
  private void push(float element) {
    size++;
    heap[size] = element;
    upHeap(size);
  }

  /** Overwrites the top with {@code value}, restores heap order and returns the new top. */
  private float updateTop(float value) {
    heap[1] = value;
    downHeap(1);
    return heap[1];
  }

  /** Sifts the value at index {@code i} down until neither child is smaller than it. */
  private void downHeap(int i) {
    float value = heap[i]; // save top value
    int j = i << 1; // find smaller child
    int k = j + 1;
    if (k <= size && heap[k] < heap[j]) {
      j = k;
    }
    while (j <= size && heap[j] < value) {
      heap[i] = heap[j]; // shift up child
      i = j;
      j = i << 1;
      k = j + 1;
      if (k <= size && heap[k] < heap[j]) {
        j = k;
      }
    }
    heap[i] = value; // install saved value
  }

  /** Sifts the value at index {@code origPos} up until its parent is not greater than it. */
  private void upHeap(int origPos) {
    int i = origPos;
    float value = heap[i]; // save bottom value
    int j = i >>> 1;
    while (j > 0 && value < heap[j]) {
      heap[i] = heap[j]; // shift parents down
      i = j;
      j = j >>> 1;
    }
    heap[i] = value; // install saved value
  }
}
|
|
@ -547,6 +547,7 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
|
|||
5,
|
||||
leaf.getLiveDocs(),
|
||||
Integer.MAX_VALUE));
|
||||
|
||||
} else {
|
||||
DocIdSetIterator iter = leaf.getByteVectorValues("vector");
|
||||
scanAndRetrieve(leaf, iter);
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
|
||||
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.SuppressForbidden;
|
||||
|
||||
public class TestBlockingFloatHeap extends LuceneTestCase {
|
||||
|
||||
public void testBasicOperations() {
|
||||
BlockingFloatHeap heap = new BlockingFloatHeap(3);
|
||||
heap.offer(2);
|
||||
heap.offer(4);
|
||||
heap.offer(1);
|
||||
heap.offer(3);
|
||||
assertEquals(3, heap.size());
|
||||
assertEquals(2, heap.peek(), 0);
|
||||
|
||||
assertEquals(2, heap.poll(), 0);
|
||||
assertEquals(3, heap.poll(), 0);
|
||||
assertEquals(4, heap.poll(), 0);
|
||||
assertEquals(0, heap.size(), 0);
|
||||
}
|
||||
|
||||
public void testBasicOperations2() {
|
||||
int size = atLeast(10);
|
||||
BlockingFloatHeap heap = new BlockingFloatHeap(size);
|
||||
double sum = 0, sum2 = 0;
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
float next = random().nextFloat(100f);
|
||||
sum += next;
|
||||
heap.offer(next);
|
||||
}
|
||||
|
||||
float last = Float.NEGATIVE_INFINITY;
|
||||
for (long i = 0; i < size; i++) {
|
||||
float next = heap.poll();
|
||||
assertTrue(next >= last);
|
||||
last = next;
|
||||
sum2 += last;
|
||||
}
|
||||
assertEquals(sum, sum2, 0.01);
|
||||
}
|
||||
|
||||
@SuppressForbidden(reason = "Thread sleep")
|
||||
public void testMultipleThreads() throws Exception {
|
||||
Thread[] threads = new Thread[randomIntBetween(3, 20)];
|
||||
final CountDownLatch latch = new CountDownLatch(1);
|
||||
BlockingFloatHeap globalHeap = new BlockingFloatHeap(1);
|
||||
|
||||
for (int i = 0; i < threads.length; i++) {
|
||||
threads[i] =
|
||||
new Thread(
|
||||
() -> {
|
||||
try {
|
||||
latch.await();
|
||||
int numIterations = randomIntBetween(10, 100);
|
||||
float bottomValue = 0;
|
||||
|
||||
while (numIterations-- > 0) {
|
||||
bottomValue += randomIntBetween(0, 5);
|
||||
globalHeap.offer(bottomValue);
|
||||
Thread.sleep(randomIntBetween(0, 50));
|
||||
|
||||
float globalBottomValue = globalHeap.peek();
|
||||
assertTrue(globalBottomValue >= bottomValue);
|
||||
bottomValue = globalBottomValue;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
threads[i].start();
|
||||
}
|
||||
|
||||
latch.countDown();
|
||||
for (Thread t : threads) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
public class TestFloatHeap extends LuceneTestCase {
|
||||
|
||||
public void testBasicOperations() {
|
||||
FloatHeap heap = new FloatHeap(3);
|
||||
heap.offer(2);
|
||||
heap.offer(4);
|
||||
heap.offer(1);
|
||||
heap.offer(3);
|
||||
assertEquals(3, heap.size());
|
||||
assertEquals(2, heap.peek(), 0);
|
||||
|
||||
assertEquals(2, heap.poll(), 0);
|
||||
assertEquals(3, heap.poll(), 0);
|
||||
assertEquals(4, heap.poll(), 0);
|
||||
assertEquals(0, heap.size(), 0);
|
||||
}
|
||||
|
||||
public void testBasicOperations2() {
|
||||
int size = atLeast(10);
|
||||
FloatHeap heap = new FloatHeap(size);
|
||||
double sum = 0, sum2 = 0;
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
float next = random().nextFloat(100f);
|
||||
sum += next;
|
||||
heap.offer(next);
|
||||
}
|
||||
|
||||
float last = Float.NEGATIVE_INFINITY;
|
||||
for (long i = 0; i < size; i++) {
|
||||
float next = heap.poll();
|
||||
assertTrue(next >= last);
|
||||
last = next;
|
||||
sum2 += last;
|
||||
}
|
||||
assertEquals(sum, sum2, 0.01);
|
||||
}
|
||||
|
||||
public void testClear() {
|
||||
FloatHeap heap = new FloatHeap(3);
|
||||
heap.offer(20);
|
||||
heap.offer(40);
|
||||
heap.offer(30);
|
||||
assertEquals(3, heap.size());
|
||||
assertEquals(20, heap.peek(), 0);
|
||||
|
||||
heap.clear();
|
||||
assertEquals(0, heap.size(), 0);
|
||||
assertEquals(20, heap.peek(), 0);
|
||||
|
||||
heap.offer(15);
|
||||
heap.offer(35);
|
||||
assertEquals(2, heap.size());
|
||||
assertEquals(15, heap.peek(), 0);
|
||||
|
||||
assertEquals(15, heap.poll(), 0);
|
||||
assertEquals(35, heap.poll(), 0);
|
||||
assertEquals(0, heap.size(), 0);
|
||||
}
|
||||
}
|
|
@ -26,6 +26,7 @@ import org.apache.lucene.index.VectorEncoding;
|
|||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnByteVectorQuery;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.Query;
|
||||
|
@ -33,6 +34,7 @@ import org.apache.lucene.search.ScoreDoc;
|
|||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopDocsCollector;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -123,7 +125,16 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
|
||||
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
|
||||
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
BitSet parentBitSet = parentsFilter.getBitSet(context);
|
||||
if (parentBitSet == null) {
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.lucene.index.VectorEncoding;
|
|||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.KnnFloatVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
|
@ -33,6 +34,7 @@ import org.apache.lucene.search.ScoreDoc;
|
|||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopDocsCollector;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -123,14 +125,21 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
|
||||
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
|
||||
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
BitSet parentBitSet = parentsFilter.getBitSet(context);
|
||||
if (parentBitSet == null) {
|
||||
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
|
||||
if (collector == null) {
|
||||
return NO_RESULTS;
|
||||
}
|
||||
KnnCollector collector =
|
||||
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
|
||||
context.reader().searchNearestVectors(field, query, collector, acceptDocs);
|
||||
return collector.topDocs();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search.join;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
|
||||
/**
|
||||
* DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link
|
||||
* DiversifyingNearestChildrenKnnCollector} instances.
|
||||
*/
|
||||
public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollectorManager {
|
||||
|
||||
// the number of docs to collect
|
||||
private final int k;
|
||||
// filter identifying the parent documents.
|
||||
private final BitSetProducer parentsFilter;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param k - the number of top k vectors to collect
|
||||
* @param parentsFilter Filter identifying the parent documents.
|
||||
*/
|
||||
public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) {
|
||||
this.k = k;
|
||||
this.parentsFilter = parentsFilter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new {@link DiversifyingNearestChildrenKnnCollector} instance.
|
||||
*
|
||||
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
||||
* @param context the leaf reader context
|
||||
*/
|
||||
@Override
|
||||
public DiversifyingNearestChildrenKnnCollector newCollector(
|
||||
int visitedLimit, LeafReaderContext context) throws IOException {
|
||||
BitSet parentBitSet = parentsFilter.getBitSet(context);
|
||||
if (parentBitSet == null) {
|
||||
return null;
|
||||
}
|
||||
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue