From e7de06eb51e70de39625718354a8e1088f283d1d Mon Sep 17 00:00:00 2001 From: nitirajrathore Date: Mon, 12 Apr 2021 22:08:29 +0530 Subject: [PATCH] LUCENE-9798 : Fix looping bug and made Full Knn calculation parallelizable (#55) --- .../org/apache/lucene/util/VectorUtil.java | 11 + .../lucene/util/hnsw/KnnGraphTester.java | 58 +--- .../org/apache/lucene/util/hnsw/TestHnsw.java | 13 +- .../java/org/apache/lucene/util/FullKnn.java | 254 ++++++++++++++++++ .../org/apache/lucene/util/TestFullKnn.java | 185 +++++++++++++ 5 files changed, 456 insertions(+), 65 deletions(-) create mode 100644 lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java create mode 100644 lucene/test-framework/src/test/org/apache/lucene/util/TestFullKnn.java diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 546d13de7fd..493a513a995 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -17,6 +17,8 @@ package org.apache.lucene.util; +import java.util.Random; + /** Utilities for computations with numeric arrays */ public final class VectorUtil { @@ -112,6 +114,15 @@ public final class VectorUtil { return squareSum; } + public static float[] randomVector(Random random, int dim) { + float[] vec = new float[dim]; + for (int i = 0; i < dim; i++) { + vec[i] = random.nextFloat(); + } + VectorUtil.l2normalize(vec); + return vec; + } + /** * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is * thrown for zero vectors. diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java index 94febfb0f6c..35145e4003c 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java @@ -56,6 +56,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FullKnn; import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.PrintStreamInfoStream; import org.apache.lucene.util.SuppressForbidden; @@ -476,7 +477,9 @@ public class KnnGraphTester { if (Files.exists(nnPath)) { return readNN(nnPath); } else { - int[][] nn = computeNN(docPath, queryPath); + int[][] nn = + new FullKnn(dim, topK, SEARCH_STRATEGY, quiet) + .computeNN(docPath, queryPath, Runtime.getRuntime().availableProcessors()); writeNN(nn, nnPath); return nn; } @@ -511,59 +514,6 @@ public class KnnGraphTester { } } - private int[][] computeNN(Path docPath, Path queryPath) throws IOException { - int[][] result = new int[numIters][]; - if (quiet == false) { - System.out.println("computing true nearest neighbors of " + numIters + " target vectors"); - } - try (FileChannel in = FileChannel.open(docPath); - FileChannel qIn = FileChannel.open(queryPath)) { - FloatBuffer queries = - qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES) - .order(ByteOrder.LITTLE_ENDIAN) - .asFloatBuffer(); - float[] vector = new float[dim]; - float[] query = new float[dim]; - for (int i = 0; i < numIters; i++) { - queries.get(query); - long totalBytes = (long) numDocs * dim * Float.BYTES; - int - blockSize = - (int) - Math.min( - totalBytes, - (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)), - offset = 0; - int j = 0; - // System.out.println("totalBytes=" + totalBytes); - while (j < numDocs) { - FloatBuffer vectors = - in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize) - .order(ByteOrder.LITTLE_ENDIAN) - .asFloatBuffer(); - offset += blockSize; - NeighborQueue queue = new NeighborQueue(topK, SEARCH_STRATEGY.reversed); - for (; j < numDocs && vectors.hasRemaining(); j++) { - vectors.get(vector); - float d = SEARCH_STRATEGY.compare(query, vector); - queue.insertWithOverflow(j, d); - } - result[i] = new int[topK]; - for (int k = topK - 1; k >= 0; k--) { - result[i][k] = queue.topNode(); - queue.pop(); - // System.out.print(" " + n); - } - if (quiet == false && (i + 1) % 10 == 0) { - System.out.print(" " + (i + 1)); - System.out.flush(); - } - } - } - } - return result; - } - private int createIndex(Path docsPath, Path indexPath) throws IOException { IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE); // iwc.setMergePolicy(NoMergePolicy.INSTANCE); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java index 26d01d6faf0..f76d5a509b3 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java @@ -235,7 +235,7 @@ public class TestHnsw extends LuceneTestCase { HnswGraph hnsw = builder.build(vectors); int totalMatches = 0; for (int i = 0; i < 100; i++) { - float[] query = randomVector(random(), dim); + float[] query = VectorUtil.randomVector(random(), dim); NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random()); NeighborQueue expected = new NeighborQueue(topK, vectors.searchStrategy.reversed); for (int j = 0; j < size; j++) { @@ -415,18 +415,9 @@ public class TestHnsw extends LuceneTestCase { private static float[][] createRandomVectors(int size, int dimension, Random random) { float[][] vectors = new float[size][]; for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) { - vectors[offset] = randomVector(random, dimension); + vectors[offset] = VectorUtil.randomVector(random, dimension); } return vectors; } } - - private static float[] randomVector(Random random, int dim) { - float[] vec = new float[dim]; - for (int i = 0; i < dim; i++) { - vec[i] = random.nextFloat(); - } - VectorUtil.l2normalize(vec); - return vec; - } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java b/lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java new file mode 100644 index 00000000000..61624a0e4e3 --- /dev/null +++ b/lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.apache.lucene.index.VectorValues; + +/** + * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document + * vectors. + */ +public class FullKnn { + + private final int dim; + private final int topK; + private final VectorValues.SearchStrategy searchStrategy; + private final boolean quiet; + + public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) { + this.dim = dim; + this.topK = topK; + this.searchStrategy = searchStrategy; + this.quiet = quiet; + } + + /** internal object to track KNN calculation for one query */ + private static class KnnJob { + public int currDocIndex; + float[] queryVector; + float[] currDocVector; + int queryIndex; + private LongHeap queue; + FloatBuffer docVectors; + VectorValues.SearchStrategy searchStrategy; + + public KnnJob( + int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) { + this.queryIndex = queryIndex; + this.queryVector = queryVector; + this.currDocVector = new float[queryVector.length]; + if (searchStrategy.reversed) { + queue = LongHeap.create(LongHeap.Order.MAX, topK); + } else { + queue = LongHeap.create(LongHeap.Order.MIN, topK); + } + this.searchStrategy = searchStrategy; + } + + public void execute() { + while (this.docVectors.hasRemaining()) { + this.docVectors.get(this.currDocVector); + float d = this.searchStrategy.compare(this.queryVector, this.currDocVector); + this.queue.insertWithOverflow(encodeNodeIdAndScore(this.currDocIndex, d)); + this.currDocIndex++; + } + } + } + + /** + * computes the exact KNN match for each query vector in queryPath for all the document vectors in + * docPath + * + * @param docPath : path to the file containing the float 32 document vectors in bytes with + * little-endian byte order + * @param queryPath : path to the file containing the containing 32-bit floating point vectors in + * little-endian byte order + * @param numThreads : create numThreads to parallelize work + * @return : returns an int 2D array ( int matches[][]) of size 'numIters x topK'. matches[i] is + * an array containing the indexes of the topK most similar document vectors to the ith query + * vector, and is sorted by similarity, with the most similar vector first. Similarity is + * defined by the searchStrategy used to construct this FullKnn. + * @throws IllegalArgumentException : if topK is greater than number of documents in docPath file + * IOException : In case of IO exception while reading files. + */ + public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException { + assert numThreads > 0; + final int numDocs = (int) (Files.size(docPath) / (dim * Float.BYTES)); + final int numQueries = (int) (Files.size(docPath) / (dim * Float.BYTES)); + + if (!quiet) { + System.out.println( + "computing true nearest neighbors of " + + numQueries + + " target vectors using " + + numThreads + + " threads."); + } + + try (FileChannel docInput = FileChannel.open(docPath); + FileChannel queryInput = FileChannel.open(queryPath)) { + return doFullKnn( + numDocs, + numQueries, + numThreads, + new FileChannelBufferProvider(docInput), + new FileChannelBufferProvider(queryInput)); + } + } + + int[][] doFullKnn( + int numDocs, + int numQueries, + int numThreads, + BufferProvider docInput, + BufferProvider queryInput) + throws IOException { + if (numDocs < topK) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "topK (%d) cannot be greater than number of docs in docPath (%d)", + topK, + numDocs)); + } + + final ExecutorService executorService = + Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor")); + int[][] result = new int[numQueries][]; + + FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * Float.BYTES).asFloatBuffer(); + float[] query = new float[dim]; + List jobList = new ArrayList<>(numThreads); + for (int i = 0; i < numQueries; ) { + + for (int j = 0; j < numThreads && i < numQueries; i++, j++) { + queries.get(query); + jobList.add( + new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), topK, searchStrategy)); + } + + long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES); + int docsLeft = numDocs; + int currDocIndex = 0; + int offset = 0; + while (docsLeft > 0) { + long totalBytes = (long) docsLeft * dim * Float.BYTES; + int blockSize = (int) Math.min(totalBytes, maxBufferSize); + + FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer(); + offset += blockSize; + + final List> completableFutures = + jobList.stream() + .peek(job -> job.docVectors = docVectors.duplicate()) + .peek(job -> job.currDocIndex = currDocIndex) + .map(job -> CompletableFuture.runAsync(() -> job.execute(), executorService)) + .collect(Collectors.toList()); + + CompletableFuture.allOf( + completableFutures.toArray(new CompletableFuture[completableFutures.size()])) + .join(); + docsLeft -= (blockSize / (dim * Float.BYTES)); + } + + jobList.forEach( + job -> { + result[job.queryIndex] = new int[topK]; + for (int k = topK - 1; k >= 0; k--) { + result[job.queryIndex][k] = popNodeId(job.queue); + // System.out.print(" " + n); + } + if (!quiet && (job.queryIndex + 1) % 10 == 0) { + System.out.print(" " + (job.queryIndex + 1)); + System.out.flush(); + } + }); + + jobList.clear(); + } + + executorService.shutdown(); + try { + if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) { + executorService.shutdownNow(); + } + } catch (InterruptedException e) { + executorService.shutdownNow(); + throw new RuntimeException( + "Exception occured while waiting for executor service to finish.", e); + } + + return result; + } + + /** + * pops the queue and returns the last 4 bytes of long where nodeId is stored. + * + * @param queue queue from which to pop the node id + * @return the node id + */ + private int popNodeId(LongHeap queue) { + return (int) queue.pop(); + } + + /** + * encodes the score and nodeId into a single long with score in first 4 bytes, to make it + * sortable by score. + * + * @param node node id + * @param score float score of the node wrt incoming node/query + * @return a score sortable long that can be used in heap + */ + private static long encodeNodeIdAndScore(int node, float score) { + return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node; + } + + interface BufferProvider { + ByteBuffer getBuffer(int offset, int blockSize) throws IOException; + } + + private static class FileChannelBufferProvider implements BufferProvider { + private FileChannel fileChannel; + + FileChannelBufferProvider(FileChannel fileChannel) { + this.fileChannel = fileChannel; + } + + @Override + public ByteBuffer getBuffer(int offset, int blockSize) throws IOException { + return fileChannel + .map(FileChannel.MapMode.READ_ONLY, offset, blockSize) + .order(ByteOrder.LITTLE_ENDIAN); + } + } +} diff --git a/lucene/test-framework/src/test/org/apache/lucene/util/TestFullKnn.java b/lucene/test-framework/src/test/org/apache/lucene/util/TestFullKnn.java new file mode 100644 index 00000000000..58f1e1a6d80 --- /dev/null +++ b/lucene/test-framework/src/test/org/apache/lucene/util/TestFullKnn.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import org.apache.lucene.index.VectorValues; +import org.junit.Assert; + +public class TestFullKnn extends LuceneTestCase { + + static class SimpleBufferProvider implements FullKnn.BufferProvider { + private ByteBuffer buffer; + + SimpleBufferProvider(ByteBuffer buffer) { + this.buffer = buffer; + } + + @Override + public ByteBuffer getBuffer(int offset, int blockSize) throws IOException { + return buffer.position(offset).slice().order(ByteOrder.LITTLE_ENDIAN).limit(blockSize); + } + } + + FullKnn fullKnn; + float[] vec0, vec1, vec2; + + @Override + public void setUp() throws Exception { + super.setUp(); + fullKnn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true); + vec0 = new float[] {1, 2, 3, 4, 5}; + VectorUtil.l2normalize(vec0); + vec1 = new float[] {6, 7, 8, 9, 10}; + VectorUtil.l2normalize(vec1); + vec2 = new float[] {1, 2, 3, 4, 6}; + VectorUtil.l2normalize(vec2); + } + + private ByteBuffer floatArrayToByteBuffer(float[] floats) { + ByteBuffer byteBuf = + ByteBuffer.allocateDirect(floats.length * Float.BYTES); // 4 bytes per float + byteBuf.order(ByteOrder.LITTLE_ENDIAN); + FloatBuffer buffer = byteBuf.asFloatBuffer(); + buffer.put(floats); + return byteBuf; + } + + private void assertBufferEqualsArray(byte[] fa, ByteBuffer wholeBuffer) { + final byte[] tempBufferFloats = new byte[wholeBuffer.remaining()]; + wholeBuffer.get(tempBufferFloats); + Assert.assertArrayEquals(fa, tempBufferFloats); + } + + public float[] joinArrays(float[]... args) { + int length = 0; + for (float[] arr : args) { + length += arr.length; + } + + float[] merged = new float[length]; + int i = 0; + + for (float[] arr : args) { + for (float f : arr) { + merged[i++] = f; + } + } + + return merged; + } + + public void testSimpleFloatBufferProvider() throws IOException { + byte[] fa = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + + ByteBuffer fb = ByteBuffer.wrap(fa); + + final SimpleBufferProvider sf = new SimpleBufferProvider(fb); + + final ByteBuffer wholeBuffer = sf.getBuffer(0, 10 * Float.BYTES); + assertBufferEqualsArray(fa, wholeBuffer); + + final ByteBuffer fb5_5 = sf.getBuffer(5, 5); + assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 10), fb5_5); + + final ByteBuffer fb5_3 = sf.getBuffer(5, 3); + assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 8), fb5_3); + + final ByteBuffer fb2_12 = sf.getBuffer(5, 0); + assertBufferEqualsArray(new byte[] {}, fb2_12); + + final ByteBuffer fb5_1 = sf.getBuffer(5, 1); + assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 6), fb5_1); + } + + public void testSuccessFullKnn() throws IOException { + float[] twoDocs = joinArrays(vec0, vec1); + float[] threeDocs = joinArrays(vec0, vec1, vec2); + float[] threeQueries = joinArrays(vec1, vec0, vec2); + + int[][] result = + fullKnn.doFullKnn( + 2, + 1, + 1, + new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)), + new SimpleBufferProvider(floatArrayToByteBuffer(vec1))); + + Assert.assertArrayEquals(result[0], new int[] {1, 0}); + + result = + fullKnn.doFullKnn( + 2, + 1, + 2, + new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)), + new SimpleBufferProvider(floatArrayToByteBuffer(vec0))); + + Assert.assertArrayEquals(result[0], new int[] {0, 1}); + + float[] twoQueries = joinArrays(vec1, vec0); + result = + fullKnn.doFullKnn( + 2, + 2, + 2, + new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)), + new SimpleBufferProvider(floatArrayToByteBuffer(twoQueries))); + + Assert.assertArrayEquals(result, new int[][] {{1, 0}, {0, 1}}); + + result = + fullKnn.doFullKnn( + 3, + 3, + 3, + new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)), + new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries))); + + Assert.assertArrayEquals(new int[][] {{1, 0}, {0, 2}, {2, 0}}, result); + + FullKnn full3nn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true); + result = + full3nn.doFullKnn( + 3, + 3, + 3, + new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)), + new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries))); + + Assert.assertArrayEquals(new int[][] {{1, 0, 2}, {0, 2, 1}, {2, 0, 1}}, result); + } + + public void testExceptionFullKnn() { + float[] twoDocs = joinArrays(vec0, vec1); + + expectThrows( + IllegalArgumentException.class, + () -> { + fullKnn.doFullKnn( + 2, + 1, + 1, + new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)), + new SimpleBufferProvider(floatArrayToByteBuffer(vec1))); + }); + } +}