Small edits for KnnGraphTester (#575)

1. Correct the remaining size for input files larger
than Integer.MAX_VALUE: currently, on every iteration
we try to map the next blockSize bytes even when fewer
than blockSize bytes are left in the file (see the
first sketch after this list).

2. Correct the java.lang.ClassCastException thrown when
retrieving KnnGraphValues for stats printing (see the
second sketch after this list).

3. Add a -metric option to select the euclidean metric.
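
A minimal, standalone sketch of the corrected remainder arithmetic from point 1. The dim and numDocs values are hypothetical and the harness class is illustrative only; the real loop maps each block with FileChannel.map as in the hunks below.

public class BlockSizeSketch {
  public static void main(String[] args) {
    int dim = 256; // hypothetical vector dimension
    long numDocs = 3_000_000L; // hypothetical doc count; makes totalBytes exceed Integer.MAX_VALUE
    long totalBytes = numDocs * dim * Float.BYTES;
    // Largest mappable block that still holds a whole number of vectors.
    final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
    long offset = 0;
    while (offset < totalBytes) {
      // The fix: cap each block by the bytes actually remaining, not by totalBytes alone.
      int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
      System.out.println("would map " + blockSize + " bytes at offset " + offset);
      offset += blockSize;
    }
  }
}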
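For point 2, a condensed illustration of why the old cast failed, assuming an index written with Lucene90Codec: getVectorReader() returns the PerFieldKnnVectorsFormat.FieldsReader dispatcher rather than a Lucene90HnswVectorsReader, so the per-field reader has to be unwrapped first. This mirrors the printFanoutHist hunk below; the helper method itself is hypothetical.

import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReader;

class GraphValuesSketch {
  // Casting getVectorReader() straight to Lucene90HnswVectorsReader threw the
  // ClassCastException; ask the per-field FieldsReader for the field's reader instead.
  static KnnGraphValues graphValues(LeafReader leafReader, String field) throws IOException {
    KnnVectorsReader fieldReader =
        ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
            .getFieldReader(field);
    return ((Lucene90HnswVectorsReader) fieldReader).getGraphValues(field);
  }
}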
Mayya Sharipova 2022-01-12 17:23:10 -05:00 committed by GitHub
parent 8d9fa6dba1
commit bd2cc4124d
1 changed file with 26 additions and 19 deletions


@@ -36,9 +36,11 @@ import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsReader;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnVectorField;
@@ -74,8 +76,6 @@ public class KnnGraphTester {
private static final String KNN_FIELD = "knn";
private static final String ID_FIELD = "id";
-private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
-VectorSimilarityFunction.DOT_PRODUCT;
private int numDocs;
private int dim;
@@ -90,6 +90,7 @@
private int reindexTimeMsec;
private int beamWidth;
private int maxConn;
+private VectorSimilarityFunction similarityFunction;
@SuppressForbidden(reason = "uses Random()")
private KnnGraphTester() {
@@ -100,6 +101,7 @@
topK = 100;
warmCount = 1000;
fanout = topK;
+similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
}
public static void main(String... args) throws Exception {
@@ -183,6 +185,14 @@
case "-docs":
docVectorsPath = Paths.get(args[++iarg]);
break;
case "-metric":
String metric = args[++iarg];
if (metric.equals("euclidean")) {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
} else if (metric.equals("angular") == false) {
throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
}
break;
case "-forceMerge":
forceMerge = true;
break;
@@ -237,12 +247,13 @@
private void printFanoutHist(Path indexPath) throws IOException {
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader reader = DirectoryReader.open(dir)) {
// int[] globalHist = new int[reader.maxDoc()];
for (LeafReaderContext context : reader.leaves()) {
LeafReader leafReader = context.reader();
+KnnVectorsReader vectorsReader =
+((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
+.getFieldReader(KNN_FIELD);
KnnGraphValues knnValues =
-((Lucene90HnswVectorsReader) ((CodecReader) leafReader).getVectorReader())
-.getGraphValues(KNN_FIELD);
+((Lucene90HnswVectorsReader) vectorsReader).getGraphValues(KNN_FIELD);
System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
printGraphFanout(knnValues, leafReader.maxDoc());
}
@@ -253,7 +264,7 @@
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
RandomAccessVectorValues values = vectors.randomAccess();
HnswGraphBuilder builder =
-new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
+new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, 0);
// start at node 1
for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(values.vectorValue(i));
@@ -533,25 +544,21 @@
for (int i = 0; i < numIters; i++) {
queries.get(query);
long totalBytes = (long) numDocs * dim * Float.BYTES;
-int
-blockSize =
-(int)
-Math.min(
-totalBytes,
-(Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)),
-offset = 0;
+final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
+int offset = 0;
int j = 0;
// System.out.println("totalBytes=" + totalBytes);
while (j < numDocs) {
+int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
FloatBuffer vectors =
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
-NeighborQueue queue = new NeighborQueue(topK, SIMILARITY_FUNCTION.reversed);
+NeighborQueue queue = new NeighborQueue(topK, similarityFunction.reversed);
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
-float d = SIMILARITY_FUNCTION.compare(query, vector);
+float d = similarityFunction.compare(query, vector);
queue.insertWithOverflow(j, d);
}
result[i] = new int[topK];
@@ -583,22 +590,22 @@
iwc.setRAMBufferSizeMB(1994d);
// iwc.setMaxBufferedDocs(10000);
-FieldType fieldType = KnnVectorField.createFieldType(dim, VectorSimilarityFunction.DOT_PRODUCT);
+FieldType fieldType = KnnVectorField.createFieldType(dim, similarityFunction);
if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath);
}
long start = System.nanoTime();
long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
+final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
try (FSDirectory dir = FSDirectory.open(indexPath);
IndexWriter iw = new IndexWriter(dir, iwc)) {
-int blockSize =
-(int)
-Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
float[] vector = new float[dim];
try (FileChannel in = FileChannel.open(docsPath)) {
int i = 0;
while (i < numDocs) {
+int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
FloatBuffer vectors =
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
.order(ByteOrder.LITTLE_ENDIAN)