mirror of https://github.com/apache/lucene.git
Revert "LUCENE-9798 : Fix looping bug and made Full Knn calculation parallelizable (#55)"
This reverts commit e7de06eb51.

parent df0780843a
commit 757da76919
VectorUtil.java:

@@ -17,8 +17,6 @@
 
 package org.apache.lucene.util;
 
-import java.util.Random;
-
 /** Utilities for computations with numeric arrays */
 public final class VectorUtil {
 
@@ -114,15 +112,6 @@ public final class VectorUtil {
     return squareSum;
   }
 
-  public static float[] randomVector(Random random, int dim) {
-    float[] vec = new float[dim];
-    for (int i = 0; i < dim; i++) {
-      vec[i] = random.nextFloat();
-    }
-    VectorUtil.l2normalize(vec);
-    return vec;
-  }
-
   /**
    * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
    * thrown for zero vectors.
KnnGraphTester.java:

@@ -56,7 +56,6 @@ import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.FullKnn;
 import org.apache.lucene.util.IntroSorter;
 import org.apache.lucene.util.PrintStreamInfoStream;
 import org.apache.lucene.util.SuppressForbidden;
@@ -477,9 +476,7 @@ public class KnnGraphTester {
     if (Files.exists(nnPath)) {
       return readNN(nnPath);
     } else {
-      int[][] nn =
-          new FullKnn(dim, topK, SEARCH_STRATEGY, quiet)
-              .computeNN(docPath, queryPath, Runtime.getRuntime().availableProcessors());
+      int[][] nn = computeNN(docPath, queryPath);
       writeNN(nn, nnPath);
       return nn;
     }
@@ -514,6 +511,59 @@ public class KnnGraphTester {
     }
   }
 
+  private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
+    int[][] result = new int[numIters][];
+    if (quiet == false) {
+      System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
+    }
+    try (FileChannel in = FileChannel.open(docPath);
+        FileChannel qIn = FileChannel.open(queryPath)) {
+      FloatBuffer queries =
+          qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
+              .order(ByteOrder.LITTLE_ENDIAN)
+              .asFloatBuffer();
+      float[] vector = new float[dim];
+      float[] query = new float[dim];
+      for (int i = 0; i < numIters; i++) {
+        queries.get(query);
+        long totalBytes = (long) numDocs * dim * Float.BYTES;
+        int
+            blockSize =
+                (int)
+                    Math.min(
+                        totalBytes,
+                        (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)),
+            offset = 0;
+        int j = 0;
+        // System.out.println("totalBytes=" + totalBytes);
+        while (j < numDocs) {
+          FloatBuffer vectors =
+              in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
+                  .order(ByteOrder.LITTLE_ENDIAN)
+                  .asFloatBuffer();
+          offset += blockSize;
+          NeighborQueue queue = new NeighborQueue(topK, SEARCH_STRATEGY.reversed);
+          for (; j < numDocs && vectors.hasRemaining(); j++) {
+            vectors.get(vector);
+            float d = SEARCH_STRATEGY.compare(query, vector);
+            queue.insertWithOverflow(j, d);
+          }
+          result[i] = new int[topK];
+          for (int k = topK - 1; k >= 0; k--) {
+            result[i][k] = queue.topNode();
+            queue.pop();
+            // System.out.print(" " + n);
+          }
+          if (quiet == false && (i + 1) % 10 == 0) {
+            System.out.print(" " + (i + 1));
+            System.out.flush();
+          }
+        }
+      }
+    }
+    return result;
+  }
+
   private int createIndex(Path docsPath, Path indexPath) throws IOException {
     IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
     // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
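Note (not part of the diff above): the restored computeNN maps the document file in blocks because a single FileChannel.map call cannot cover more than Integer.MAX_VALUE bytes, and the block size is rounded down to a whole number of vectors so no vector straddles two mappings. A minimal, self-contained sketch of that arithmetic follows; the dim and numDocs values are illustrative assumptions, not taken from the patch.

// Sketch of the block-size arithmetic used by computeNN above (illustrative only).
public class BlockSizeSketch {
  public static void main(String[] args) {
    int dim = 256;                            // example dimension (assumed)
    long numDocs = 10_000_000L;               // example corpus size (assumed)
    int bytesPerVector = dim * Float.BYTES;   // 1024 bytes per vector
    long totalBytes = numDocs * bytesPerVector;
    // Largest block that fits in an int and still holds only whole vectors:
    int maxBlock = (Integer.MAX_VALUE / bytesPerVector) * bytesPerVector;
    int blockSize = (int) Math.min(totalBytes, maxBlock);
    System.out.println(maxBlock);   // 2147482624 bytes = 2097151 whole vectors
    System.out.println(blockSize);  // 2147482624 here, since totalBytes is larger
  }
}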
TestHnsw.java:

@@ -235,7 +235,7 @@ public class TestHnsw extends LuceneTestCase {
     HnswGraph hnsw = builder.build(vectors);
     int totalMatches = 0;
     for (int i = 0; i < 100; i++) {
-      float[] query = VectorUtil.randomVector(random(), dim);
+      float[] query = randomVector(random(), dim);
       NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random());
       NeighborQueue expected = new NeighborQueue(topK, vectors.searchStrategy.reversed);
       for (int j = 0; j < size; j++) {
@@ -415,9 +415,18 @@ public class TestHnsw extends LuceneTestCase {
   private static float[][] createRandomVectors(int size, int dimension, Random random) {
     float[][] vectors = new float[size][];
     for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
-      vectors[offset] = VectorUtil.randomVector(random, dimension);
+      vectors[offset] = randomVector(random, dimension);
     }
     return vectors;
   }
 }
+
+  private static float[] randomVector(Random random, int dim) {
+    float[] vec = new float[dim];
+    for (int i = 0; i < dim; i++) {
+      vec[i] = random.nextFloat();
+    }
+    VectorUtil.l2normalize(vec);
+    return vec;
+  }
 }
FullKnn.java (deleted by this commit):

@@ -1,254 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.lucene.util;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-import java.nio.FloatBuffer;
-import java.nio.channels.FileChannel;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Locale;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
-import org.apache.lucene.index.VectorValues;
-
-/**
- * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
- * vectors.
- */
-public class FullKnn {
-
-  private final int dim;
-  private final int topK;
-  private final VectorValues.SearchStrategy searchStrategy;
-  private final boolean quiet;
-
-  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
-    this.dim = dim;
-    this.topK = topK;
-    this.searchStrategy = searchStrategy;
-    this.quiet = quiet;
-  }
-
-  /** internal object to track KNN calculation for one query */
-  private static class KnnJob {
-    public int currDocIndex;
-    float[] queryVector;
-    float[] currDocVector;
-    int queryIndex;
-    private LongHeap queue;
-    FloatBuffer docVectors;
-    VectorValues.SearchStrategy searchStrategy;
-
-    public KnnJob(
-        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
-      this.queryIndex = queryIndex;
-      this.queryVector = queryVector;
-      this.currDocVector = new float[queryVector.length];
-      if (searchStrategy.reversed) {
-        queue = LongHeap.create(LongHeap.Order.MAX, topK);
-      } else {
-        queue = LongHeap.create(LongHeap.Order.MIN, topK);
-      }
-      this.searchStrategy = searchStrategy;
-    }
-
-    public void execute() {
-      while (this.docVectors.hasRemaining()) {
-        this.docVectors.get(this.currDocVector);
-        float d = this.searchStrategy.compare(this.queryVector, this.currDocVector);
-        this.queue.insertWithOverflow(encodeNodeIdAndScore(this.currDocIndex, d));
-        this.currDocIndex++;
-      }
-    }
-  }
-
-  /**
-   * computes the exact KNN match for each query vector in queryPath for all the document vectors in
-   * docPath
-   *
-   * @param docPath : path to the file containing the float 32 document vectors in bytes with
-   *     little-endian byte order
-   * @param queryPath : path to the file containing the containing 32-bit floating point vectors in
-   *     little-endian byte order
-   * @param numThreads : create numThreads to parallelize work
-   * @return : returns an int 2D array ( int matches[][]) of size 'numIters x topK'. matches[i] is
-   *     an array containing the indexes of the topK most similar document vectors to the ith query
-   *     vector, and is sorted by similarity, with the most similar vector first. Similarity is
-   *     defined by the searchStrategy used to construct this FullKnn.
-   * @throws IllegalArgumentException : if topK is greater than number of documents in docPath file
-   *     IOException : In case of IO exception while reading files.
-   */
-  public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
-    assert numThreads > 0;
-    final int numDocs = (int) (Files.size(docPath) / (dim * Float.BYTES));
-    final int numQueries = (int) (Files.size(docPath) / (dim * Float.BYTES));
-
-    if (!quiet) {
-      System.out.println(
-          "computing true nearest neighbors of "
-              + numQueries
-              + " target vectors using "
-              + numThreads
-              + " threads.");
-    }
-
-    try (FileChannel docInput = FileChannel.open(docPath);
-        FileChannel queryInput = FileChannel.open(queryPath)) {
-      return doFullKnn(
-          numDocs,
-          numQueries,
-          numThreads,
-          new FileChannelBufferProvider(docInput),
-          new FileChannelBufferProvider(queryInput));
-    }
-  }
-
-  int[][] doFullKnn(
-      int numDocs,
-      int numQueries,
-      int numThreads,
-      BufferProvider docInput,
-      BufferProvider queryInput)
-      throws IOException {
-    if (numDocs < topK) {
-      throw new IllegalArgumentException(
-          String.format(
-              Locale.ROOT,
-              "topK (%d) cannot be greater than number of docs in docPath (%d)",
-              topK,
-              numDocs));
-    }
-
-    final ExecutorService executorService =
-        Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor"));
-    int[][] result = new int[numQueries][];
-
-    FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * Float.BYTES).asFloatBuffer();
-    float[] query = new float[dim];
-    List<KnnJob> jobList = new ArrayList<>(numThreads);
-    for (int i = 0; i < numQueries; ) {
-
-      for (int j = 0; j < numThreads && i < numQueries; i++, j++) {
-        queries.get(query);
-        jobList.add(
-            new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), topK, searchStrategy));
-      }
-
-      long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
-      int docsLeft = numDocs;
-      int currDocIndex = 0;
-      int offset = 0;
-      while (docsLeft > 0) {
-        long totalBytes = (long) docsLeft * dim * Float.BYTES;
-        int blockSize = (int) Math.min(totalBytes, maxBufferSize);
-
-        FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer();
-        offset += blockSize;
-
-        final List<CompletableFuture<Void>> completableFutures =
-            jobList.stream()
-                .peek(job -> job.docVectors = docVectors.duplicate())
-                .peek(job -> job.currDocIndex = currDocIndex)
-                .map(job -> CompletableFuture.runAsync(() -> job.execute(), executorService))
-                .collect(Collectors.toList());
-
-        CompletableFuture.allOf(
-                completableFutures.toArray(new CompletableFuture<?>[completableFutures.size()]))
-            .join();
-        docsLeft -= (blockSize / (dim * Float.BYTES));
-      }
-
-      jobList.forEach(
-          job -> {
-            result[job.queryIndex] = new int[topK];
-            for (int k = topK - 1; k >= 0; k--) {
-              result[job.queryIndex][k] = popNodeId(job.queue);
-              // System.out.print(" " + n);
-            }
-            if (!quiet && (job.queryIndex + 1) % 10 == 0) {
-              System.out.print(" " + (job.queryIndex + 1));
-              System.out.flush();
-            }
-          });
-
-      jobList.clear();
-    }
-
-    executorService.shutdown();
-    try {
-      if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) {
-        executorService.shutdownNow();
-      }
-    } catch (InterruptedException e) {
-      executorService.shutdownNow();
-      throw new RuntimeException(
-          "Exception occured while waiting for executor service to finish.", e);
-    }
-
-    return result;
-  }
-
-  /**
-   * pops the queue and returns the last 4 bytes of long where nodeId is stored.
-   *
-   * @param queue queue from which to pop the node id
-   * @return the node id
-   */
-  private int popNodeId(LongHeap queue) {
-    return (int) queue.pop();
-  }
-
-  /**
-   * encodes the score and nodeId into a single long with score in first 4 bytes, to make it
-   * sortable by score.
-   *
-   * @param node node id
-   * @param score float score of the node wrt incoming node/query
-   * @return a score sortable long that can be used in heap
-   */
-  private static long encodeNodeIdAndScore(int node, float score) {
-    return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node;
-  }
-
-  interface BufferProvider {
-    ByteBuffer getBuffer(int offset, int blockSize) throws IOException;
-  }
-
-  private static class FileChannelBufferProvider implements BufferProvider {
-    private FileChannel fileChannel;
-
-    FileChannelBufferProvider(FileChannel fileChannel) {
-      this.fileChannel = fileChannel;
-    }
-
-    @Override
-    public ByteBuffer getBuffer(int offset, int blockSize) throws IOException {
-      return fileChannel
-          .map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
-          .order(ByteOrder.LITTLE_ENDIAN);
-    }
-  }
-}
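Note (not part of the diff above): the deleted FullKnn keeps each query's candidates in a LongHeap by packing the (score, node id) pair into one long: the score, converted with NumericUtils.floatToSortableInt, occupies the high 32 bits and the node id the low 32 bits, so ordering the packed longs orders by score. A minimal round-trip sketch of that encoding follows; the node id and score values are illustrative.

// Sketch of the (score, nodeId) packing used by the deleted encodeNodeIdAndScore (illustrative only).
import org.apache.lucene.util.NumericUtils;

public class PackSketch {
  static long encode(int node, float score) {
    // sortable form of the score in the high 32 bits, node id in the low 32 bits
    return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node;
  }

  public static void main(String[] args) {
    long packed = encode(42, 0.8f);
    int node = (int) packed;                                              // low 32 bits
    float score = NumericUtils.sortableIntToFloat((int) (packed >> 32));  // high 32 bits
    System.out.println(node + " " + score);                 // 42 0.8
    System.out.println(encode(7, 0.9f) > encode(7, 0.1f));  // true: higher score sorts higher
  }
}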
TestFullKnn.java (deleted by this commit):

@@ -1,185 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.lucene.util;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-import java.nio.FloatBuffer;
-import org.apache.lucene.index.VectorValues;
-import org.junit.Assert;
-
-public class TestFullKnn extends LuceneTestCase {
-
-  static class SimpleBufferProvider implements FullKnn.BufferProvider {
-    private ByteBuffer buffer;
-
-    SimpleBufferProvider(ByteBuffer buffer) {
-      this.buffer = buffer;
-    }
-
-    @Override
-    public ByteBuffer getBuffer(int offset, int blockSize) throws IOException {
-      return buffer.position(offset).slice().order(ByteOrder.LITTLE_ENDIAN).limit(blockSize);
-    }
-  }
-
-  FullKnn fullKnn;
-  float[] vec0, vec1, vec2;
-
-  @Override
-  public void setUp() throws Exception {
-    super.setUp();
-    fullKnn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true);
-    vec0 = new float[] {1, 2, 3, 4, 5};
-    VectorUtil.l2normalize(vec0);
-    vec1 = new float[] {6, 7, 8, 9, 10};
-    VectorUtil.l2normalize(vec1);
-    vec2 = new float[] {1, 2, 3, 4, 6};
-    VectorUtil.l2normalize(vec2);
-  }
-
-  private ByteBuffer floatArrayToByteBuffer(float[] floats) {
-    ByteBuffer byteBuf =
-        ByteBuffer.allocateDirect(floats.length * Float.BYTES); // 4 bytes per float
-    byteBuf.order(ByteOrder.LITTLE_ENDIAN);
-    FloatBuffer buffer = byteBuf.asFloatBuffer();
-    buffer.put(floats);
-    return byteBuf;
-  }
-
-  private void assertBufferEqualsArray(byte[] fa, ByteBuffer wholeBuffer) {
-    final byte[] tempBufferFloats = new byte[wholeBuffer.remaining()];
-    wholeBuffer.get(tempBufferFloats);
-    Assert.assertArrayEquals(fa, tempBufferFloats);
-  }
-
-  public float[] joinArrays(float[]... args) {
-    int length = 0;
-    for (float[] arr : args) {
-      length += arr.length;
-    }
-
-    float[] merged = new float[length];
-    int i = 0;
-
-    for (float[] arr : args) {
-      for (float f : arr) {
-        merged[i++] = f;
-      }
-    }
-
-    return merged;
-  }
-
-  public void testSimpleFloatBufferProvider() throws IOException {
-    byte[] fa = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-
-    ByteBuffer fb = ByteBuffer.wrap(fa);
-
-    final SimpleBufferProvider sf = new SimpleBufferProvider(fb);
-
-    final ByteBuffer wholeBuffer = sf.getBuffer(0, 10 * Float.BYTES);
-    assertBufferEqualsArray(fa, wholeBuffer);
-
-    final ByteBuffer fb5_5 = sf.getBuffer(5, 5);
-    assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 10), fb5_5);
-
-    final ByteBuffer fb5_3 = sf.getBuffer(5, 3);
-    assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 8), fb5_3);
-
-    final ByteBuffer fb2_12 = sf.getBuffer(5, 0);
-    assertBufferEqualsArray(new byte[] {}, fb2_12);
-
-    final ByteBuffer fb5_1 = sf.getBuffer(5, 1);
-    assertBufferEqualsArray(ArrayUtil.copyOfSubArray(fa, 5, 6), fb5_1);
-  }
-
-  public void testSuccessFullKnn() throws IOException {
-    float[] twoDocs = joinArrays(vec0, vec1);
-    float[] threeDocs = joinArrays(vec0, vec1, vec2);
-    float[] threeQueries = joinArrays(vec1, vec0, vec2);
-
-    int[][] result =
-        fullKnn.doFullKnn(
-            2,
-            1,
-            1,
-            new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
-            new SimpleBufferProvider(floatArrayToByteBuffer(vec1)));
-
-    Assert.assertArrayEquals(result[0], new int[] {1, 0});
-
-    result =
-        fullKnn.doFullKnn(
-            2,
-            1,
-            2,
-            new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
-            new SimpleBufferProvider(floatArrayToByteBuffer(vec0)));
-
-    Assert.assertArrayEquals(result[0], new int[] {0, 1});
-
-    float[] twoQueries = joinArrays(vec1, vec0);
-    result =
-        fullKnn.doFullKnn(
-            2,
-            2,
-            2,
-            new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
-            new SimpleBufferProvider(floatArrayToByteBuffer(twoQueries)));
-
-    Assert.assertArrayEquals(result, new int[][] {{1, 0}, {0, 1}});
-
-    result =
-        fullKnn.doFullKnn(
-            3,
-            3,
-            3,
-            new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)),
-            new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries)));
-
-    Assert.assertArrayEquals(new int[][] {{1, 0}, {0, 2}, {2, 0}}, result);
-
-    FullKnn full3nn = new FullKnn(5, 3, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, true);
-    result =
-        full3nn.doFullKnn(
-            3,
-            3,
-            3,
-            new SimpleBufferProvider(floatArrayToByteBuffer(threeDocs)),
-            new SimpleBufferProvider(floatArrayToByteBuffer(threeQueries)));
-
-    Assert.assertArrayEquals(new int[][] {{1, 0, 2}, {0, 2, 1}, {2, 0, 1}}, result);
-  }
-
-  public void testExceptionFullKnn() {
-    float[] twoDocs = joinArrays(vec0, vec1);
-
-    expectThrows(
-        IllegalArgumentException.class,
-        () -> {
-          fullKnn.doFullKnn(
-              2,
-              1,
-              1,
-              new SimpleBufferProvider(floatArrayToByteBuffer(twoDocs)),
-              new SimpleBufferProvider(floatArrayToByteBuffer(vec1)));
-        });
-  }
-}
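Note (not part of the diff above): the deleted test's SimpleBufferProvider relies on ByteBuffer view semantics: position(offset) picks the start, slice() creates a view from that position, and limit(blockSize) caps how many bytes the view exposes. A small sketch of that chain follows; the byte values are illustrative.

// Sketch of the ByteBuffer slicing used by the deleted SimpleBufferProvider (illustrative only).
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class SliceSketch {
  public static void main(String[] args) {
    byte[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    ByteBuffer whole = ByteBuffer.wrap(data);

    // Equivalent of SimpleBufferProvider.getBuffer(5, 3): a 3-byte view starting at index 5.
    ByteBuffer view = whole.position(5).slice().order(ByteOrder.LITTLE_ENDIAN).limit(3);

    byte[] out = new byte[view.remaining()];
    view.get(out);
    System.out.println(Arrays.toString(out)); // [5, 6, 7]
  }
}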