mirror of https://github.com/apache/lucene.git
LUCENE-9798 : Fix looping bug and made Full Knn calculation parallelizable (#55)
This commit is contained in:
parent
a7b0aadcfc
commit
e7de06eb51
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/** Utilities for computations with numeric arrays */
|
||||
public final class VectorUtil {
|
||||
|
||||
|
@ -112,6 +114,15 @@ public final class VectorUtil {
|
|||
return squareSum;
|
||||
}
|
||||
|
||||
public static float[] randomVector(Random random, int dim) {
|
||||
float[] vec = new float[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
vec[i] = random.nextFloat();
|
||||
}
|
||||
VectorUtil.l2normalize(vec);
|
||||
return vec;
|
||||
}
|
||||
|
||||
/**
|
||||
* Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
|
||||
* thrown for zero vectors.
|
||||
|
|
|
@ -56,6 +56,7 @@ import org.apache.lucene.search.TopDocs;
|
|||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.FSDirectory;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.FullKnn;
|
||||
import org.apache.lucene.util.IntroSorter;
|
||||
import org.apache.lucene.util.PrintStreamInfoStream;
|
||||
import org.apache.lucene.util.SuppressForbidden;
|
||||
|
@ -476,7 +477,9 @@ public class KnnGraphTester {
|
|||
if (Files.exists(nnPath)) {
|
||||
return readNN(nnPath);
|
||||
} else {
|
||||
int[][] nn = computeNN(docPath, queryPath);
|
||||
int[][] nn =
|
||||
new FullKnn(dim, topK, SEARCH_STRATEGY, quiet)
|
||||
.computeNN(docPath, queryPath, Runtime.getRuntime().availableProcessors());
|
||||
writeNN(nn, nnPath);
|
||||
return nn;
|
||||
}
|
||||
|
@ -511,59 +514,6 @@ public class KnnGraphTester {
|
|||
}
|
||||
}
|
||||
|
||||
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
|
||||
int[][] result = new int[numIters][];
|
||||
if (quiet == false) {
|
||||
System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
|
||||
}
|
||||
try (FileChannel in = FileChannel.open(docPath);
|
||||
FileChannel qIn = FileChannel.open(queryPath)) {
|
||||
FloatBuffer queries =
|
||||
qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
|
||||
.order(ByteOrder.LITTLE_ENDIAN)
|
||||
.asFloatBuffer();
|
||||
float[] vector = new float[dim];
|
||||
float[] query = new float[dim];
|
||||
for (int i = 0; i < numIters; i++) {
|
||||
queries.get(query);
|
||||
long totalBytes = (long) numDocs * dim * Float.BYTES;
|
||||
int
|
||||
blockSize =
|
||||
(int)
|
||||
Math.min(
|
||||
totalBytes,
|
||||
(Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)),
|
||||
offset = 0;
|
||||
int j = 0;
|
||||
// System.out.println("totalBytes=" + totalBytes);
|
||||
while (j < numDocs) {
|
||||
FloatBuffer vectors =
|
||||
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN)
|
||||
.asFloatBuffer();
|
||||
offset += blockSize;
|
||||
NeighborQueue queue = new NeighborQueue(topK, SEARCH_STRATEGY.reversed);
|
||||
for (; j < numDocs && vectors.hasRemaining(); j++) {
|
||||
vectors.get(vector);
|
||||
float d = SEARCH_STRATEGY.compare(query, vector);
|
||||
queue.insertWithOverflow(j, d);
|
||||
}
|
||||
result[i] = new int[topK];
|
||||
for (int k = topK - 1; k >= 0; k--) {
|
||||
result[i][k] = queue.topNode();
|
||||
queue.pop();
|
||||
// System.out.print(" " + n);
|
||||
}
|
||||
if (quiet == false && (i + 1) % 10 == 0) {
|
||||
System.out.print(" " + (i + 1));
|
||||
System.out.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private int createIndex(Path docsPath, Path indexPath) throws IOException {
|
||||
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
|
||||
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);
|
||||
|
|
|
@ -235,7 +235,7 @@ public class TestHnsw extends LuceneTestCase {
|
|||
HnswGraph hnsw = builder.build(vectors);
|
||||
int totalMatches = 0;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
float[] query = randomVector(random(), dim);
|
||||
float[] query = VectorUtil.randomVector(random(), dim);
|
||||
NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random());
|
||||
NeighborQueue expected = new NeighborQueue(topK, vectors.searchStrategy.reversed);
|
||||
for (int j = 0; j < size; j++) {
|
||||
|
@ -415,18 +415,9 @@ public class TestHnsw extends LuceneTestCase {
|
|||
private static float[][] createRandomVectors(int size, int dimension, Random random) {
|
||||
float[][] vectors = new float[size][];
|
||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||
vectors[offset] = randomVector(random, dimension);
|
||||
vectors[offset] = VectorUtil.randomVector(random, dimension);
|
||||
}
|
||||
return vectors;
|
||||
}
|
||||
}
|
||||
|
||||
private static float[] randomVector(Random random, int dim) {
|
||||
float[] vec = new float[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
vec[i] = random.nextFloat();
|
||||
}
|
||||
VectorUtil.l2normalize(vec);
|
||||
return vec;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,254 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.channels.FileChannel;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
||||
/**
|
||||
* A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
|
||||
* vectors.
|
||||
*/
|
||||
public class FullKnn {
|
||||
|
||||
private final int dim;
|
||||
private final int topK;
|
||||
private final VectorValues.SearchStrategy searchStrategy;
|
||||
private final boolean quiet;
|
||||
|
||||
public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
|
||||
this.dim = dim;
|
||||
this.topK = topK;
|
||||
this.searchStrategy = searchStrategy;
|
||||
this.quiet = quiet;
|
||||
}
|
||||
|
||||
/** internal object to track KNN calculation for one query */
|
||||
private static class KnnJob {
|
||||
public int currDocIndex;
|
||||
float[] queryVector;
|
||||
float[] currDocVector;
|
||||
int queryIndex;
|
||||
private LongHeap queue;
|
||||
FloatBuffer docVectors;
|
||||
VectorValues.SearchStrategy searchStrategy;
|
||||
|
||||
public KnnJob(
|
||||
int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
|
||||
this.queryIndex = queryIndex;
|
||||
this.queryVector = queryVector;
|
||||
this.currDocVector = new float[queryVector.length];
|
||||
if (searchStrategy.reversed) {
|
||||
queue = LongHeap.create(LongHeap.Order.MAX, topK);
|
||||
} else {
|
||||
queue = LongHeap.create(LongHeap.Order.MIN, topK);
|
||||
}
|
||||
this.searchStrategy = searchStrategy;
|
||||
}
|
||||
|
||||
public void execute() {
|
||||
while (this.docVectors.hasRemaining()) {
|
||||
this.docVectors.get(this.currDocVector);
|
||||
float d = this.searchStrategy.compare(this.queryVector, this.currDocVector);
|
||||
this.queue.insertWithOverflow(encodeNodeIdAndScore(this.currDocIndex, d));
|
||||
this.currDocIndex++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* computes the exact KNN match for each query vector in queryPath for all the document vectors in
|
||||
* docPath
|
||||
*
|
||||
* @param docPath : path to the file containing the float 32 document vectors in bytes with
|
||||
* little-endian byte order
|
||||
* @param queryPath : path to the file containing the containing 32-bit floating point vectors in
|
||||
* little-endian byte order
|
||||
* @param numThreads : create numThreads to parallelize work
|
||||
* @return : returns an int 2D array ( int matches[][]) of size 'numIters x topK'. matches[i] is
|
||||
* an array containing the indexes of the topK most similar document vectors to the ith query
|
||||
* vector, and is sorted by similarity, with the most similar vector first. Similarity is
|
||||
* defined by the searchStrategy used to construct this FullKnn.
|
||||
* @throws IllegalArgumentException : if topK is greater than number of documents in docPath file
|
||||
* IOException : In case of IO exception while reading files.
|
||||
*/
|
||||
public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
|
||||
assert numThreads > 0;
|
||||
final int numDocs = (int) (Files.size(docPath) / (dim * Float.BYTES));
|
||||
final int numQueries = (int) (Files.size(docPath) / (dim * Float.BYTES));
|
||||
|
||||
if (!quiet) {
|
||||
System.out.println(
|
||||
"computing true nearest neighbors of "
|
||||
+ numQueries
|
||||
+ " target vectors using "
|
||||
+ numThreads
|
||||
+ " threads.");
|
||||
}
|
||||
|
||||
try (FileChannel docInput = FileChannel.open(docPath);
|
||||
FileChannel queryInput = FileChannel.open(queryPath)) {
|
||||
return doFullKnn(
|
||||
numDocs,
|
||||
numQueries,
|
||||
numThreads,
|
||||
new FileChannelBufferProvider(docInput),
|
||||
new FileChannelBufferProvider(queryInput));
|
||||
}
|
||||
}
|
||||
|
||||
int[][] doFullKnn(
|
||||
int numDocs,
|
||||
int numQueries,
|
||||
int numThreads,
|
||||
BufferProvider docInput,
|
||||
BufferProvider queryInput)
|
||||
throws IOException {
|
||||
if (numDocs < topK) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"topK (%d) cannot be greater than number of docs in docPath (%d)",
|
||||
topK,
|
||||
numDocs));
|
||||
}
|
||||
|
||||
final ExecutorService executorService =
|
||||
Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor"));
|
||||
int[][] result = new int[numQueries][];
|
||||
|
||||
FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * Float.BYTES).asFloatBuffer();
|
||||
float[] query = new float[dim];
|
||||
List<KnnJob> jobList = new ArrayList<>(numThreads);
|
||||
for (int i = 0; i < numQueries; ) {
|
||||
|
||||
for (int j = 0; j < numThreads && i < numQueries; i++, j++) {
|
||||
queries.get(query);
|
||||
jobList.add(
|
||||
new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), topK, searchStrategy));
|
||||
}
|
||||
|
||||
long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
|
||||
int docsLeft = numDocs;
|
||||
int currDocIndex = 0;
|
||||
int offset = 0;
|
||||
while (docsLeft > 0) {
|
||||
long totalBytes = (long) docsLeft * dim * Float.BYTES;
|
||||
int blockSize = (int) Math.min(totalBytes, maxBufferSize);
|
||||
|
||||
FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer();
|
||||
offset += blockSize;
|
||||
|
||||
final List<CompletableFuture<Void>> completableFutures =
|
||||
jobList.stream()
|
||||
.peek(job -> job.docVectors = docVectors.duplicate())
|
||||
.peek(job -> job.currDocIndex = currDocIndex)
|
||||
.map(job -> CompletableFuture.runAsync(() -> job.execute(), executorService))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
CompletableFuture.allOf(
|
||||
completableFutures.toArray(new CompletableFuture<?>[completableFutures.size()]))
|
||||
.join();
|
||||
docsLeft -= (blockSize / (dim * Float.BYTES));
|
||||
}
|
||||
|
||||
jobList.forEach(
|
||||
job -> {
|
||||
result[job.queryIndex] = new int[topK];
|
||||
for (int k = topK - 1; k >= 0; k--) {
|
||||
result[job.queryIndex][k] = popNodeId(job.queue);
|
||||
// System.out.print(" " + n);
|
||||
}
|
||||
if (!quiet && (job.queryIndex + 1) % 10 == 0) {
|
||||
System.out.print(" " + (job.queryIndex + 1));
|
||||
System.out.flush();
|
||||
}
|
||||
});
|
||||
|
||||
jobList.clear();
|
||||
}
|
||||
|
||||
executorService.shutdown();
|
||||
try {
|
||||
if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) {
|
||||
executorService.shutdownNow();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
executorService.shutdownNow();
|
||||
throw new RuntimeException(
|
||||
"Exception occured while waiting for executor service to finish.", e);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* pops the queue and returns the last 4 bytes of long where nodeId is stored.
|
||||
*
|
||||
* @param queue queue from which to pop the node id
|
||||
* @return the node id
|
||||
*/
|
||||
private int popNodeId(LongHeap queue) {
|
||||
return (int) queue.pop();
|
||||
}
|
||||
|
||||
/**
|
||||
* encodes the score and nodeId into a single long with score in first 4 bytes, to make it
|
||||
* sortable by score.
|
||||
*
|
||||
* @param node node id
|
||||
* @param score float score of the node wrt incoming node/query
|
||||
* @return a score sortable long that can be used in heap
|
||||
*/
|
||||
private static long encodeNodeIdAndScore(int node, float score) {
|
||||
return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node;
|
||||
}
|
||||
|
||||
interface BufferProvider {
|
||||
ByteBuffer getBuffer(int offset, int blockSize) throws IOException;
|
||||
}
|
||||
|
||||
private static class FileChannelBufferProvider implements BufferProvider {
|
||||
private FileChannel fileChannel;
|
||||
|
||||
FileChannelBufferProvider(FileChannel fileChannel) {
|
||||
this.fileChannel = fileChannel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteBuffer getBuffer(int offset, int blockSize) throws IOException {
|
||||
return fileChannel
|
||||
.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.junit.Assert;
|
||||
|
||||
public class TestFullKnn extends LuceneTestCase {
|
||||
|
||||
static class SimpleBufferProvider implements FullKnn.BufferProvider {
|
||||
private ByteBuffer buffer;
|
||||
|
||||
SimpleBufferProvider(ByteBuffer buffer) {
|
||||
this.buffer = buffer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteBuffer getBuffer(int offset, int blockSize) throws IOException {
|
||||
return buffer.position(offset).slice().order(ByteOrder.LITTLE_ENDIAN).limit(blockSize);
|
||||
}
|
||||
}
|
||||
|
||||
FullKnn fullKnn;
|
||||
float[] vec0, vec1, vec2;
|
||||
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
fullKnn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true);
|
||||
vec0 = new float[] {1, 2, 3, 4, 5};
|
||||
VectorUtil.l2normalize(vec0);
|
||||
vec1 = new float[] {6, 7, 8, 9, 10};
|
||||
VectorUtil.l2normalize(vec1);
|
||||
vec2 = new float[] {1, 2, 3, 4, 6};
|
||||
VectorUtil.l2normalize(vec2);
|
||||
}
|
||||
|
||||
private ByteBuffer floatArrayToByteBuffer(float[] floats) {
|
||||
ByteBuffer byteBuf =
|
||||
ByteBuffer.allocateDirect(floats.length * Float.BYTES); // 4 bytes per float
|
||||
byteBuf.order(ByteOrder.LITTLE_ENDIAN);
|
||||
FloatBuffer buffer = byteBuf.asFloatBuffer();
|
||||
buffer.put(floats);
|
||||
return byteBuf;
|
||||
}
|
||||
|
||||
private void assertBufferEqualsArray(byte[] fa, ByteBuffer wholeBuffer) {
|
||||
final byte[] tempBufferFloats = new byte[wholeBuffer.remaining()];
|
||||
wholeBuffer.get(tempBufferFloats);
|
||||
Assert.assertArrayEquals(fa, tempBufferFloats);
|
||||
}
|
||||
|
||||
public float[] joinArrays(float[]... args) {
|
||||
int length = 0;
|
||||
for (float[] arr : args) {
|
||||
length += arr.length;
|
||||
}
|
||||
|
||||
float[] merged = new float[length];
|
||||
int i = 0;
|
||||
|
||||
for (float[] arr : args) {
|
||||
for (float f : arr) {
|
||||
merged[i++] = f;
|
||||
}
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
public void testSimpleFloatBufferProvider() throws IOException {
|
||||
byte[] fa = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
|
||||
ByteBuffer fb = ByteBuffer.wrap(fa);
|
||||
|
||||
final SimpleBufferProvider sf = new SimpleBufferProvider(fb);
|
||||
|
||||
final ByteBuffer wholeBuffer = sf.getBuffer(0, 10 * Float.BYTES);
|
||||
assertBufferEqualsArray(fa, wholeBuffer);
|
||||
|
||||
final ByteBuffer fb5_5 = sf.getBuffer(5, 5);
|
||||
assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 10), fb5_5);
|
||||
|
||||
final ByteBuffer fb5_3 = sf.getBuffer(5, 3);
|
||||
assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 8), fb5_3);
|
||||
|
||||
final ByteBuffer fb2_12 = sf.getBuffer(5, 0);
|
||||
assertBufferEqualsArray(new byte[] {}, fb2_12);
|
||||
|
||||
final ByteBuffer fb5_1 = sf.getBuffer(5, 1);
|
||||
assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 6), fb5_1);
|
||||
}
|
||||
|
||||
public void testSuccessFullKnn() throws IOException {
|
||||
float[] twoDocs = joinArrays(vec0, vec1);
|
||||
float[] threeDocs = joinArrays(vec0, vec1, vec2);
|
||||
float[] threeQueries = joinArrays(vec1, vec0, vec2);
|
||||
|
||||
int[][] result =
|
||||
fullKnn.doFullKnn(
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(vec1)));
|
||||
|
||||
Assert.assertArrayEquals(result[0], new int[] {1, 0});
|
||||
|
||||
result =
|
||||
fullKnn.doFullKnn(
|
||||
2,
|
||||
1,
|
||||
2,
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(vec0)));
|
||||
|
||||
Assert.assertArrayEquals(result[0], new int[] {0, 1});
|
||||
|
||||
float[] twoQueries = joinArrays(vec1, vec0);
|
||||
result =
|
||||
fullKnn.doFullKnn(
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(twoQueries)));
|
||||
|
||||
Assert.assertArrayEquals(result, new int[][] {{1, 0}, {0, 1}});
|
||||
|
||||
result =
|
||||
fullKnn.doFullKnn(
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)),
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries)));
|
||||
|
||||
Assert.assertArrayEquals(new int[][] {{1, 0}, {0, 2}, {2, 0}}, result);
|
||||
|
||||
FullKnn full3nn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true);
|
||||
result =
|
||||
full3nn.doFullKnn(
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)),
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries)));
|
||||
|
||||
Assert.assertArrayEquals(new int[][] {{1, 0, 2}, {0, 2, 1}, {2, 0, 1}}, result);
|
||||
}
|
||||
|
||||
public void testExceptionFullKnn() {
|
||||
float[] twoDocs = joinArrays(vec0, vec1);
|
||||
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
fullKnn.doFullKnn(
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
|
||||
new SimpleBufferProvider(floatArrayToByteBuffer(vec1)));
|
||||
});
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue