LUCENE-9798: Fix looping bug and make full KNN calculation parallelizable (#55)

This commit is contained in:
nitirajrathore 2021-04-12 22:08:29 +05:30 committed by GitHub
parent a7b0aadcfc
commit e7de06eb51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 456 additions and 65 deletions

View File

@ -17,6 +17,8 @@
package org.apache.lucene.util;
import java.util.Random;
/** Utilities for computations with numeric arrays */
public final class VectorUtil {
@ -112,6 +114,15 @@ public final class VectorUtil {
return squareSum;
}
/**
 * Creates a random vector of the given dimension and normalizes it to unit length.
 *
 * <p>Each component is first drawn with {@link Random#nextFloat()}, then the whole vector is
 * passed to {@link #l2normalize(float[])}, which throws IllegalArgumentException for zero
 * vectors.
 *
 * @param random source of randomness; consulted exactly {@code dim} times
 * @param dim number of components of the returned vector
 * @return a newly allocated, l2-normalized vector of length {@code dim}
 */
public static float[] randomVector(Random random, int dim) {
  final float[] result = new float[dim];
  int pos = 0;
  while (pos < dim) {
    result[pos] = random.nextFloat();
    pos++;
  }
  VectorUtil.l2normalize(result);
  return result;
}
/**
* Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
* thrown for zero vectors.

View File

@ -56,6 +56,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FullKnn;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;
@ -476,7 +477,9 @@ public class KnnGraphTester {
if (Files.exists(nnPath)) {
return readNN(nnPath);
} else {
int[][] nn = computeNN(docPath, queryPath);
int[][] nn =
new FullKnn(dim, topK, SEARCH_STRATEGY, quiet)
.computeNN(docPath, queryPath, Runtime.getRuntime().availableProcessors());
writeNN(nn, nnPath);
return nn;
}
@ -511,59 +514,6 @@ public class KnnGraphTester {
}
}
/**
 * Computes, by brute force, the topK nearest documents for each of the first {@code numIters}
 * query vectors.
 *
 * <p>Both files are read as little-endian 32-bit floats, {@code dim} floats per vector. The
 * document file is memory-mapped in blocks that each hold a whole number of vectors and fit in a
 * single mapping; one neighbor queue per query accumulates results across all blocks.
 *
 * @param docPath file holding the {@code numDocs} document vectors
 * @param queryPath file holding at least {@code numIters} query vectors
 * @return for every query, the ids of its topK nearest documents, filled from the last slot to
 *     the first as entries are popped off the queue
 * @throws IllegalArgumentException if there are fewer documents than topK
 * @throws IOException if either file cannot be opened or mapped
 */
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
  if (numDocs < topK) {
    throw new IllegalArgumentException(
        "topK (" + topK + ") cannot be greater than number of docs (" + numDocs + ")");
  }
  int[][] result = new int[numIters][];
  if (quiet == false) {
    System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
  }
  final int vectorBytes = dim * Float.BYTES;
  // largest mapping size that is a whole number of vectors and still fits in an int
  final int maxBlockBytes = (Integer.MAX_VALUE / vectorBytes) * vectorBytes;
  final long totalBytes = (long) numDocs * vectorBytes;
  try (FileChannel in = FileChannel.open(docPath);
      FileChannel qIn = FileChannel.open(queryPath)) {
    // widen before multiplying so a large query set cannot silently overflow int
    FloatBuffer queries =
        qIn.map(FileChannel.MapMode.READ_ONLY, 0, (long) numIters * vectorBytes)
            .order(ByteOrder.LITTLE_ENDIAN)
            .asFloatBuffer();
    float[] vector = new float[dim];
    float[] query = new float[dim];
    for (int i = 0; i < numIters; i++) {
      queries.get(query);
      // one queue per query: it must survive across all mapped blocks of the doc file
      NeighborQueue queue = new NeighborQueue(topK, SEARCH_STRATEGY.reversed);
      long offset = 0;
      int j = 0;
      while (j < numDocs) {
        // the last block is usually shorter; never map past the end of the data
        int blockSize = (int) Math.min(totalBytes - offset, maxBlockBytes);
        FloatBuffer vectors =
            in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
                .order(ByteOrder.LITTLE_ENDIAN)
                .asFloatBuffer();
        offset += blockSize;
        for (; j < numDocs && vectors.hasRemaining(); j++) {
          vectors.get(vector);
          float d = SEARCH_STRATEGY.compare(query, vector);
          queue.insertWithOverflow(j, d);
        }
      }
      // drain the queue only after every document has been scored
      result[i] = new int[topK];
      for (int k = topK - 1; k >= 0; k--) {
        result[i][k] = queue.topNode();
        queue.pop();
      }
      if (quiet == false && (i + 1) % 10 == 0) {
        System.out.print(" " + (i + 1));
        System.out.flush();
      }
    }
  }
  return result;
}
private int createIndex(Path docsPath, Path indexPath) throws IOException {
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);

View File

@ -235,7 +235,7 @@ public class TestHnsw extends LuceneTestCase {
HnswGraph hnsw = builder.build(vectors);
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
float[] query = VectorUtil.randomVector(random(), dim);
NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random());
NeighborQueue expected = new NeighborQueue(topK, vectors.searchStrategy.reversed);
for (int j = 0; j < size; j++) {
@ -415,18 +415,9 @@ public class TestHnsw extends LuceneTestCase {
/**
 * Builds a sparse array of random unit vectors.
 *
 * <p>Slots are filled starting at index 0 and then advancing by a random step of 1 to 3, so some
 * entries of the returned array are left {@code null}.
 *
 * @param size length of the returned array
 * @param dimension number of components of each generated vector
 * @param random source of randomness for both the step size and the vector contents
 * @return an array of length {@code size} whose populated slots hold vectors produced by {@link
 *     VectorUtil#randomVector(Random, int)}
 */
private static float[][] createRandomVectors(int size, int dimension, Random random) {
  float[][] vectors = new float[size][];
  for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
    // single assignment: the earlier duplicate store was overwritten immediately and only
    // wasted a random vector
    vectors[offset] = VectorUtil.randomVector(random, dimension);
  }
  return vectors;
}
}
/**
 * Creates a random vector of the given dimension, normalized to unit length.
 *
 * <p>Delegates to {@link VectorUtil#randomVector(Random, int)}, which performs exactly the same
 * steps, so that the logic lives in one place.
 *
 * @param random source of randomness
 * @param dim number of components of the returned vector
 * @return a newly allocated, l2-normalized random vector
 */
private static float[] randomVector(Random random, int dim) {
  return VectorUtil.randomVector(random, dim);
}
}

View File

@ -0,0 +1,254 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.lucene.index.VectorValues;
/**
* A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
* vectors.
*/
/**
 * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
 * vectors.
 *
 * <p>Queries are processed in batches of at most {@code numThreads}. For each batch the document
 * vectors are read block by block, and every query of the batch scans the current block on its
 * own pool thread.
 */
public class FullKnn {

  private final int dim;
  private final int topK;
  private final VectorValues.SearchStrategy searchStrategy;
  private final boolean quiet;

  /**
   * @param dim number of float components in every query and document vector
   * @param topK number of nearest documents to report per query
   * @param searchStrategy strategy used to score a query against a document
   * @param quiet when {@code true}, nothing is printed to standard output
   */
  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
    this.dim = dim;
    this.topK = topK;
    this.searchStrategy = searchStrategy;
    this.quiet = quiet;
  }

  /** internal object to track KNN calculation for one query */
  private static class KnnJob {
    final int queryIndex;
    final float[] queryVector;
    // scratch array reused for every document read from the buffer
    final float[] currDocVector;
    final VectorValues.SearchStrategy searchStrategy;
    final LongHeap queue;
    // id of the next document that execute() will read; set by the caller before each block
    int currDocIndex;
    // this job's private view of the current block of document vectors
    FloatBuffer docVectors;

    KnnJob(
        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
      this.queryIndex = queryIndex;
      this.queryVector = queryVector;
      this.currDocVector = new float[queryVector.length];
      this.searchStrategy = searchStrategy;
      // NOTE(review): presumably "reversed" means lower scores are better, so the largest score
      // has to sit on top of the heap to be evicted first -- confirm against SearchStrategy
      LongHeap.Order order = searchStrategy.reversed ? LongHeap.Order.MAX : LongHeap.Order.MIN;
      this.queue = LongHeap.create(order, topK);
    }

    /** Scores every vector remaining in {@link #docVectors} and offers it to the heap. */
    void execute() {
      while (docVectors.hasRemaining()) {
        docVectors.get(currDocVector);
        float score = searchStrategy.compare(queryVector, currDocVector);
        queue.insertWithOverflow(encodeNodeIdAndScore(currDocIndex, score));
        currDocIndex++;
      }
    }
  }

  /**
   * computes the exact KNN match for each query vector in queryPath for all the document vectors in
   * docPath
   *
   * @param docPath : path to the file containing the 32-bit floating point document vectors in
   *     little-endian byte order
   * @param queryPath : path to the file containing the 32-bit floating point query vectors in
   *     little-endian byte order
   * @param numThreads : create numThreads to parallelize work
   * @return : returns an int 2D array ( int matches[][]) of size 'numQueries x topK'. matches[i]
   *     is an array containing the indexes of the topK most similar document vectors to the ith
   *     query vector, and is sorted by similarity, with the most similar vector first. Similarity
   *     is defined by the searchStrategy used to construct this FullKnn.
   * @throws IllegalArgumentException : if topK is greater than number of documents in docPath
   *     file, or if numThreads is not positive
   * @throws IOException : In case of IO exception while reading files.
   */
  public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
    final int vectorBytes = dim * Float.BYTES;
    final int numDocs = Math.toIntExact(Files.size(docPath) / vectorBytes);
    // the number of queries comes from the query file, not from the document file
    final int numQueries = Math.toIntExact(Files.size(queryPath) / vectorBytes);

    if (!quiet) {
      System.out.println(
          "computing true nearest neighbors of "
              + numQueries
              + " target vectors using "
              + numThreads
              + " threads.");
    }

    try (FileChannel docInput = FileChannel.open(docPath);
        FileChannel queryInput = FileChannel.open(queryPath)) {
      return doFullKnn(
          numDocs,
          numQueries,
          numThreads,
          new FileChannelBufferProvider(docInput),
          new FileChannelBufferProvider(queryInput));
    }
  }

  /**
   * Runs the exact KNN computation against arbitrary buffer sources.
   *
   * @param numDocs number of document vectors available from {@code docInput}
   * @param numQueries number of query vectors available from {@code queryInput}
   * @param numThreads size of the thread pool and of each batch of queries
   * @param docInput source of the document vectors
   * @param queryInput source of the query vectors; all queries are fetched as one buffer
   * @return one row of topK document ids per query
   * @throws IllegalArgumentException if {@code numDocs < topK} or {@code numThreads <= 0}
   * @throws ArithmeticException if the query vectors do not fit into a single buffer
   * @throws IOException if a buffer cannot be obtained
   */
  int[][] doFullKnn(
      int numDocs,
      int numQueries,
      int numThreads,
      BufferProvider docInput,
      BufferProvider queryInput)
      throws IOException {
    if (numDocs < topK) {
      throw new IllegalArgumentException(
          String.format(
              Locale.ROOT,
              "topK (%d) cannot be greater than number of docs in docPath (%d)",
              topK,
              numDocs));
    }
    if (numThreads <= 0) {
      throw new IllegalArgumentException("numThreads must be positive, got " + numThreads);
    }

    final int[][] result = new int[numQueries][];
    // fail loudly instead of wrapping around when the query set exceeds a single buffer
    final int queryBytes = Math.toIntExact((long) numQueries * dim * Float.BYTES);
    final FloatBuffer queries = queryInput.getBuffer(0, queryBytes).asFloatBuffer();

    final ExecutorService executorService =
        Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor"));
    boolean success = false;
    try {
      final List<KnnJob> jobList = new ArrayList<>(numThreads);
      int nextQuery = 0;
      while (nextQuery < numQueries) {
        // one job per query, at most numThreads jobs per batch
        for (int j = 0; j < numThreads && nextQuery < numQueries; j++, nextQuery++) {
          float[] query = new float[dim];
          queries.get(query);
          jobList.add(new KnnJob(nextQuery, query, topK, searchStrategy));
        }
        scanAllDocs(jobList, numDocs, docInput, executorService);
        for (KnnJob job : jobList) {
          result[job.queryIndex] = drainHeap(job.queue);
          if (!quiet && (job.queryIndex + 1) % 10 == 0) {
            System.out.print(" " + (job.queryIndex + 1));
            System.out.flush();
          }
        }
        jobList.clear();
      }
      success = true;
    } finally {
      if (!success) {
        // do not leave pool threads behind when a job or a read failed
        executorService.shutdownNow();
      }
    }

    executorService.shutdown();
    try {
      if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) {
        executorService.shutdownNow();
      }
    } catch (InterruptedException e) {
      executorService.shutdownNow();
      // keep the interrupt visible to callers further up the stack
      Thread.currentThread().interrupt();
      throw new RuntimeException(
          "Exception occured while waiting for executor service to finish.", e);
    }
    return result;
  }

  /** Feeds every document vector, block by block, to all jobs of the current batch. */
  private void scanAllDocs(
      List<KnnJob> jobs, int numDocs, BufferProvider docInput, ExecutorService executor)
      throws IOException {
    final int vectorBytes = dim * Float.BYTES;
    // largest block that holds a whole number of vectors and still fits in an int
    final int maxBlockBytes = (Integer.MAX_VALUE / vectorBytes) * vectorBytes;
    long offset = 0;
    int docsDone = 0;
    while (docsDone < numDocs) {
      final long bytesLeft = (long) (numDocs - docsDone) * vectorBytes;
      final int blockSize = (int) Math.min(bytesLeft, maxBlockBytes);
      final FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer();
      for (KnnJob job : jobs) {
        // every job reads through its own position/limit over the shared block
        job.docVectors = docVectors.duplicate();
        // document ids continue where the previous block stopped
        job.currDocIndex = docsDone;
      }
      final List<CompletableFuture<Void>> futures =
          jobs.stream()
              .map(job -> CompletableFuture.runAsync(job::execute, executor))
              .collect(Collectors.toList());
      CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0])).join();
      offset += blockSize;
      docsDone += blockSize / vectorBytes;
    }
  }

  /** Pops topK entries off the heap, filling the returned array from the last slot backwards. */
  private int[] drainHeap(LongHeap queue) {
    final int[] ids = new int[topK];
    for (int k = topK - 1; k >= 0; k--) {
      ids[k] = popNodeId(queue);
    }
    return ids;
  }

  /**
   * pops the queue and returns the last 4 bytes of long where nodeId is stored.
   *
   * @param queue queue from which to pop the node id
   * @return the node id
   */
  private static int popNodeId(LongHeap queue) {
    return (int) queue.pop();
  }

  /**
   * encodes the score and nodeId into a single long with score in first 4 bytes, to make it
   * sortable by score.
   *
   * @param node node id
   * @param score float score of the node wrt incoming node/query
   * @return a score sortable long that can be used in heap
   */
  private static long encodeNodeIdAndScore(int node, float score) {
    // mask the id so sign extension can never spill into the score half
    return (((long) NumericUtils.floatToSortableInt(score)) << 32) | (node & 0xFFFFFFFFL);
  }

  /** Source of byte buffers holding vectors, addressed by byte offset and length. */
  interface BufferProvider {
    /**
     * @param offset byte offset at which the returned buffer starts
     * @param blockSize number of bytes the returned buffer exposes
     */
    ByteBuffer getBuffer(int offset, int blockSize) throws IOException;

    /**
     * Variant taking a {@code long} offset so sources larger than 2GB can be addressed. The
     * default implementation forwards to the {@code int} variant and throws ArithmeticException
     * when the offset does not fit in an int.
     */
    default ByteBuffer getBuffer(long offset, int blockSize) throws IOException {
      return getBuffer(Math.toIntExact(offset), blockSize);
    }
  }

  /** Serves buffers by memory-mapping read-only, little-endian regions of a file channel. */
  private static class FileChannelBufferProvider implements BufferProvider {
    private final FileChannel fileChannel;

    FileChannelBufferProvider(FileChannel fileChannel) {
      this.fileChannel = fileChannel;
    }

    @Override
    public ByteBuffer getBuffer(int offset, int blockSize) throws IOException {
      return getBuffer((long) offset, blockSize);
    }

    @Override
    public ByteBuffer getBuffer(long offset, int blockSize) throws IOException {
      return fileChannel
          .map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
          .order(ByteOrder.LITTLE_ENDIAN);
    }
  }
}

View File

@ -0,0 +1,185 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import org.apache.lucene.index.VectorValues;
import org.junit.Assert;
/** Tests {@link FullKnn} against small in-memory vector sets. */
public class TestFullKnn extends LuceneTestCase {

  /** Serves slices of a single in-memory buffer without touching that buffer's own state. */
  static class SimpleBufferProvider implements FullKnn.BufferProvider {
    private final ByteBuffer buffer;

    SimpleBufferProvider(ByteBuffer buffer) {
      this.buffer = buffer;
    }

    @Override
    public ByteBuffer getBuffer(int offset, int blockSize) throws IOException {
      // work on a duplicate so repeated calls never depend on a previously moved position
      ByteBuffer view = buffer.duplicate();
      view.position(offset);
      ByteBuffer slice = view.slice().order(ByteOrder.LITTLE_ENDIAN);
      slice.limit(blockSize);
      return slice;
    }
  }

  FullKnn fullKnn;
  float[] vec0, vec1, vec2;

  @Override
  public void setUp() throws Exception {
    super.setUp();
    // topK = 2: the smallest document sets used by the tests hold exactly two vectors
    fullKnn = new FullKnn(5, 2, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true);
    vec0 = new float[] {1, 2, 3, 4, 5};
    VectorUtil.l2normalize(vec0);
    vec1 = new float[] {6, 7, 8, 9, 10};
    VectorUtil.l2normalize(vec1);
    vec2 = new float[] {1, 2, 3, 4, 6};
    VectorUtil.l2normalize(vec2);
  }

  /** Copies the floats into a direct little-endian byte buffer positioned at 0. */
  private ByteBuffer floatArrayToByteBuffer(float[] floats) {
    ByteBuffer byteBuf =
        ByteBuffer.allocateDirect(floats.length * Float.BYTES); // 4 bytes per float
    byteBuf.order(ByteOrder.LITTLE_ENDIAN);
    FloatBuffer buffer = byteBuf.asFloatBuffer();
    buffer.put(floats);
    return byteBuf;
  }

  /** Asserts that the remaining bytes of {@code actual} equal {@code expected}. */
  private void assertBufferEqualsArray(byte[] expected, ByteBuffer actual) {
    final byte[] remaining = new byte[actual.remaining()];
    actual.get(remaining);
    Assert.assertArrayEquals(expected, remaining);
  }

  /** Concatenates the given arrays, in order, into one new array. */
  public float[] joinArrays(float[]... args) {
    int length = 0;
    for (float[] arr : args) {
      length += arr.length;
    }

    float[] merged = new float[length];
    int pos = 0;
    for (float[] arr : args) {
      System.arraycopy(arr, 0, merged, pos, arr.length);
      pos += arr.length;
    }
    return merged;
  }

  public void testSimpleFloatBufferProvider() throws IOException {
    byte[] fa = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    ByteBuffer fb = ByteBuffer.wrap(fa);
    final SimpleBufferProvider sf = new SimpleBufferProvider(fb);

    // the backing array holds bytes, so the whole buffer is fa.length bytes long
    final ByteBuffer wholeBuffer = sf.getBuffer(0, fa.length);
    assertBufferEqualsArray(fa, wholeBuffer);

    final ByteBuffer fb5_5 = sf.getBuffer(5, 5);
    assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 10), fb5_5);

    final ByteBuffer fb5_3 = sf.getBuffer(5, 3);
    assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 8), fb5_3);

    final ByteBuffer fb5_0 = sf.getBuffer(5, 0);
    assertBufferEqualsArray(new byte[] {}, fb5_0);

    final ByteBuffer fb5_1 = sf.getBuffer(5, 1);
    assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 6), fb5_1);
  }

  public void testSuccessFullKnn() throws IOException {
    float[] twoDocs = joinArrays(vec0, vec1);
    float[] threeDocs = joinArrays(vec0, vec1, vec2);
    float[] threeQueries = joinArrays(vec1, vec0, vec2);

    int[][] result =
        fullKnn.doFullKnn(
            2,
            1,
            1,
            new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
            new SimpleBufferProvider(floatArrayToByteBuffer(vec1)));
    Assert.assertArrayEquals(new int[] {1, 0}, result[0]);

    result =
        fullKnn.doFullKnn(
            2,
            1,
            2,
            new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
            new SimpleBufferProvider(floatArrayToByteBuffer(vec0)));
    Assert.assertArrayEquals(new int[] {0, 1}, result[0]);

    float[] twoQueries = joinArrays(vec1, vec0);
    result =
        fullKnn.doFullKnn(
            2,
            2,
            2,
            new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
            new SimpleBufferProvider(floatArrayToByteBuffer(twoQueries)));
    Assert.assertArrayEquals(new int[][] {{1, 0}, {0, 1}}, result);

    result =
        fullKnn.doFullKnn(
            3,
            3,
            3,
            new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)),
            new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries)));
    Assert.assertArrayEquals(new int[][] {{1, 0}, {0, 2}, {2, 0}}, result);

    FullKnn full3nn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true);
    result =
        full3nn.doFullKnn(
            3,
            3,
            3,
            new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)),
            new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries)));
    Assert.assertArrayEquals(new int[][] {{1, 0, 2}, {0, 2, 1}, {2, 0, 1}}, result);
  }

  public void testExceptionFullKnn() {
    float[] twoDocs = joinArrays(vec0, vec1);
    // asking for 3 neighbors out of only 2 documents must be rejected
    FullKnn full3nn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true);

    expectThrows(
        IllegalArgumentException.class,
        () -> {
          full3nn.doFullKnn(
              2,
              1,
              1,
              new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
              new SimpleBufferProvider(floatArrayToByteBuffer(vec1)));
        });
  }
}