mirror of https://github.com/apache/lucene.git
Small edits for KnnGraphTester (#575)
1. Correct the remaining size for input files larger than Integer.MAX_VALUE: currently, every iteration tries to map the next blockSize bytes even when fewer than blockSize bytes are left in the file (the corrected block mapping is sketched below, after the commit metadata).
2. Correct the java.lang.ClassCastException thrown when retrieving KnnGraphValues for stats printing (the cause is sketched after the diff).
3. Add an option for the euclidean metric.
This commit is contained in:
parent 8d9fa6dba1
commit bd2cc4124d
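Point 1 of the commit message, sketched: the old code computed one blockSize up front and mapped that many bytes on every iteration, so on the last iteration it could try to map more bytes than remain in the file. The fix caps each block by the bytes actually remaining. Below is a minimal, self-contained sketch of the corrected loop, assuming a raw little-endian file of numDocs float vectors of dimension dim; the class and method names are illustrative, not part of KnnGraphTester.

import java.io.IOException;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;

class BlockMappedVectorReader {

  // Streams numDocs raw little-endian float vectors of dimension dim from docsPath,
  // mapping the file in blocks that stay under Integer.MAX_VALUE bytes.
  static void readAllVectors(Path docsPath, int numDocs, int dim) throws IOException {
    long totalBytes = (long) numDocs * dim * Float.BYTES;
    // Largest block that fits in an int and holds a whole number of vectors.
    final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
    float[] vector = new float[dim];
    try (FileChannel in = FileChannel.open(docsPath)) {
      long offset = 0;
      int i = 0;
      while (i < numDocs) {
        // The fix: cap each block by the bytes actually remaining, not only by maxBlockSize.
        int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
        FloatBuffer vectors =
            in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
                .order(ByteOrder.LITTLE_ENDIAN)
                .asFloatBuffer();
        offset += blockSize;
        for (; i < numDocs && vectors.hasRemaining(); i++) {
          vectors.get(vector); // consume one vector; a real caller would index or compare it here
        }
      }
    }
  }
}

The same pattern appears twice in the diff below: once in the brute-force nearest-neighbor computation and once in the indexing loop.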
@@ -36,9 +36,11 @@ import java.util.HashSet;
 import java.util.Locale;
 import java.util.Set;
 import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.lucene90.Lucene90Codec;
 import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsReader;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.FieldType;
 import org.apache.lucene.document.KnnVectorField;
@@ -74,8 +76,6 @@ public class KnnGraphTester {

   private static final String KNN_FIELD = "knn";
   private static final String ID_FIELD = "id";
-  private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
-      VectorSimilarityFunction.DOT_PRODUCT;

   private int numDocs;
   private int dim;
@@ -90,6 +90,7 @@ public class KnnGraphTester {
   private int reindexTimeMsec;
   private int beamWidth;
   private int maxConn;
+  private VectorSimilarityFunction similarityFunction;

   @SuppressForbidden(reason = "uses Random()")
   private KnnGraphTester() {
@@ -100,6 +101,7 @@ public class KnnGraphTester {
     topK = 100;
     warmCount = 1000;
     fanout = topK;
+    similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
   }

   public static void main(String... args) throws Exception {
@@ -183,6 +185,14 @@ public class KnnGraphTester {
         case "-docs":
           docVectorsPath = Paths.get(args[++iarg]);
           break;
+        case "-metric":
+          String metric = args[++iarg];
+          if (metric.equals("euclidean")) {
+            similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+          } else if (metric.equals("angular") == false) {
+            throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
+          }
+          break;
         case "-forceMerge":
           forceMerge = true;
           break;
@@ -237,12 +247,13 @@ public class KnnGraphTester {
   private void printFanoutHist(Path indexPath) throws IOException {
     try (Directory dir = FSDirectory.open(indexPath);
         DirectoryReader reader = DirectoryReader.open(dir)) {
       // int[] globalHist = new int[reader.maxDoc()];
       for (LeafReaderContext context : reader.leaves()) {
         LeafReader leafReader = context.reader();
+        KnnVectorsReader vectorsReader =
+            ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
+                .getFieldReader(KNN_FIELD);
         KnnGraphValues knnValues =
-            ((Lucene90HnswVectorsReader) ((CodecReader) leafReader).getVectorReader())
-                .getGraphValues(KNN_FIELD);
+            ((Lucene90HnswVectorsReader) vectorsReader).getGraphValues(KNN_FIELD);
         System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
         printGraphFanout(knnValues, leafReader.maxDoc());
       }
@@ -253,7 +264,7 @@ public class KnnGraphTester {
     try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
       RandomAccessVectorValues values = vectors.randomAccess();
       HnswGraphBuilder builder =
-          new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
+          new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, 0);
       // start at node 1
       for (int i = 1; i < numDocs; i++) {
         builder.addGraphNode(values.vectorValue(i));
@@ -533,25 +544,21 @@ public class KnnGraphTester {
       for (int i = 0; i < numIters; i++) {
         queries.get(query);
         long totalBytes = (long) numDocs * dim * Float.BYTES;
-        int
-            blockSize =
-                (int)
-                    Math.min(
-                        totalBytes,
-                        (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)),
-            offset = 0;
+        final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
+        int offset = 0;
         int j = 0;
         // System.out.println("totalBytes=" + totalBytes);
         while (j < numDocs) {
+          int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
           FloatBuffer vectors =
               in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
                   .order(ByteOrder.LITTLE_ENDIAN)
                   .asFloatBuffer();
           offset += blockSize;
-          NeighborQueue queue = new NeighborQueue(topK, SIMILARITY_FUNCTION.reversed);
+          NeighborQueue queue = new NeighborQueue(topK, similarityFunction.reversed);
           for (; j < numDocs && vectors.hasRemaining(); j++) {
             vectors.get(vector);
-            float d = SIMILARITY_FUNCTION.compare(query, vector);
+            float d = similarityFunction.compare(query, vector);
             queue.insertWithOverflow(j, d);
           }
           result[i] = new int[topK];
@@ -583,22 +590,22 @@ public class KnnGraphTester {
     iwc.setRAMBufferSizeMB(1994d);
     // iwc.setMaxBufferedDocs(10000);

-    FieldType fieldType = KnnVectorField.createFieldType(dim, VectorSimilarityFunction.DOT_PRODUCT);
+    FieldType fieldType = KnnVectorField.createFieldType(dim, similarityFunction);
     if (quiet == false) {
       iwc.setInfoStream(new PrintStreamInfoStream(System.out));
       System.out.println("creating index in " + indexPath);
     }
     long start = System.nanoTime();
     long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
+    final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);

     try (FSDirectory dir = FSDirectory.open(indexPath);
         IndexWriter iw = new IndexWriter(dir, iwc)) {
-      int blockSize =
-          (int)
-              Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
       float[] vector = new float[dim];
       try (FileChannel in = FileChannel.open(docsPath)) {
         int i = 0;
         while (i < numDocs) {
+          int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
           FloatBuffer vectors =
               in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
                   .order(ByteOrder.LITTLE_ENDIAN)
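For point 2 of the commit message, a hedged sketch of why the cast failed and how the diff above resolves it: with Lucene90Codec the vector format is applied per field, so CodecReader#getVectorReader() returns a PerFieldKnnVectorsFormat.FieldsReader wrapper rather than a Lucene90HnswVectorsReader, and the old direct cast threw ClassCastException; the per-field reader has to be looked up first. The helper below is illustrative and not part of the tester; class, method name, and import locations are assumed to match Lucene 9.0, where all of the types it uses appear in the diff above.

import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.KnnGraphValues;

class GraphValuesSketch {

  // Looks up the HNSW graph for one vector field of a single segment.
  static KnnGraphValues graphValuesFor(CodecReader leafReader, String field) throws IOException {
    // With a per-field vectors format, the segment-level reader is a FieldsReader wrapper;
    // casting it straight to Lucene90HnswVectorsReader is what threw ClassCastException.
    KnnVectorsReader perFieldReader = leafReader.getVectorReader();
    KnnVectorsReader fieldReader =
        ((PerFieldKnnVectorsFormat.FieldsReader) perFieldReader).getFieldReader(field);
    // The unwrapped per-field reader is the format-specific one that exposes the graph.
    return ((Lucene90HnswVectorsReader) fieldReader).getGraphValues(field);
  }
}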