LUCENE-10559: Add Prefilter Option to KnnGraphTester (#932)

Added `prefilter` and `filterSelectivity` arguments to KnnGraphTester so that
pre-filtering and post-filtering benchmarks can be compared.

`filterSelectivity` expresses the selectivity of a filter as the proportion of
docs that pass it; the passing docs are selected at random. We store them in a
FixedBitSet and use it both to calculate the true KNN and in the HNSW search.

In the post-filter case, we over-select results as `topK / filterSelectivity` so
that the number of final hits is close to the requested `topK`. For pre-filter, we
wrap the FixedBitSet in a query and pass it as the prefilter argument to KnnVectorQuery.
This commit is contained in:
Kaival Parikh 2022-07-29 23:51:34 +05:30 committed by GitHub
parent eb7b7791ba
commit 1ad28a3136
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 126 additions and 16 deletions

View File

@ -112,7 +112,7 @@ Bug Fixes
Other
---------------------
(No changes)
* LUCENE-10559: Add Prefilter Option to KnnGraphTester (Kaival Parikh)
======================== Lucene 9.3.0 =======================

View File

@ -33,6 +33,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileTime;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
@ -56,13 +57,22 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;
@ -91,8 +101,10 @@ public class KnnGraphTester {
private int beamWidth;
private int maxConn;
private VectorSimilarityFunction similarityFunction;
private FixedBitSet matchDocs;
private float selectivity;
private boolean prefilter;
@SuppressForbidden(reason = "uses Random()")
private KnnGraphTester() {
// set defaults
numDocs = 1000;
@ -101,6 +113,8 @@ public class KnnGraphTester {
topK = 100;
fanout = topK;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
selectivity = 1f;
prefilter = false;
}
public static void main(String... args) throws Exception {
@ -192,6 +206,18 @@ public class KnnGraphTester {
case "-forceMerge":
forceMerge = true;
break;
case "-prefilter":
prefilter = true;
break;
case "-filterSelectivity":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-filterSelectivity requires a following float");
}
selectivity = Float.parseFloat(args[++iarg]);
if (selectivity <= 0 || selectivity >= 1) {
throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1");
}
break;
case "-quiet":
quiet = true;
break;
@ -203,6 +229,9 @@ public class KnnGraphTester {
if (operation == null && reindex == false) {
usage();
}
if (prefilter == true && selectivity == 1f) {
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
}
indexPath = Paths.get(formatIndexPath(docVectorsPath));
if (reindex) {
if (docVectorsPath == null) {
@ -219,6 +248,7 @@ public class KnnGraphTester {
if (docVectorsPath == null) {
throw new IllegalArgumentException("missing -docs arg");
}
matchDocs = generateRandomBitSet(numDocs, selectivity);
if (outputPath != null) {
testSearch(indexPath, queryPath, outputPath, null);
} else {
@ -362,17 +392,33 @@ public class KnnGraphTester {
DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
numDocs = reader.maxDoc();
Query bitSetQuery = new BitSetQuery(matchDocs);
for (int i = 0; i < numIters; i++) {
// warm up
targets.get(target);
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
if (prefilter) {
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
}
targets.position(0);
start = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numIters; i++) {
targets.get(target);
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
if (prefilter) {
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
results[i] =
doKnnVectorQuery(
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
results[i].scoreDocs =
Arrays.stream(results[i].scoreDocs)
.filter(scoreDoc -> matchDocs == null || matchDocs.get(scoreDoc.doc))
.toArray(ScoreDoc[]::new);
}
}
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
@ -417,7 +463,7 @@ public class KnnGraphTester {
totalVisited /= numIters;
System.out.printf(
Locale.ROOT,
"%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n",
"%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%s\n",
recall,
totalCpuTime / (float) numIters,
numDocs,
@ -425,21 +471,22 @@ public class KnnGraphTester {
maxConn,
beamWidth,
totalVisited,
reindexTimeMsec);
reindexTimeMsec,
selectivity,
prefilter ? "pre-filter" : "post-filter");
}
}
/**
 * Runs a {@link KnnVectorQuery} on {@code field} for {@code vector} and returns the searcher's
 * top {@code k} hits.
 *
 * <p>The query is constructed with {@code k + fanout} as its size argument, while the searcher is
 * asked for only {@code k} results. This span shows two forms of the method: one without a filter
 * and one that forwards {@code filter} as the last argument of the {@link KnnVectorQuery}
 * constructor.
 *
 * @param searcher searcher the query is executed against
 * @param field name of the vector field to search
 * @param vector query vector
 * @param k number of hits requested from the searcher
 * @param fanout added to {@code k} to form the size passed to the query
 * @param filter query forwarded to {@link KnnVectorQuery}; callers visible here pass {@code null}
 *     when no pre-filter is used
 * @return the hits returned by {@link IndexSearcher#search}
 * @throws IOException if the search throws one
 */
private static TopDocs doKnnVectorQuery(
IndexSearcher searcher, String field, float[] vector, int k, int fanout) throws IOException {
return searcher.search(new KnnVectorQuery(field, vector, k + fanout), k);
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
throws IOException {
return searcher.search(new KnnVectorQuery(field, vector, k + fanout, filter), k);
}
private float checkResults(TopDocs[] results, int[][] nn) {
int totalMatches = 0;
int totalResults = 0;
int totalResults = results.length * topK;
for (int i = 0; i < results.length; i++) {
int n = results[i].scoreDocs.length;
totalResults += n;
// System.out.println(Arrays.toString(nn[i]));
// System.out.println(Arrays.toString(results[i].scoreDocs));
totalMatches += compareNN(nn[i], results[i]);
@ -463,7 +510,7 @@ public class KnnGraphTester {
System.out.print('\n');
*/
Set<Integer> expectedSet = new HashSet<>();
for (int i = 0; i < results.scoreDocs.length; i++) {
for (int i = 0; i < topK; i++) {
expectedSet.add(expected[i]);
}
for (ScoreDoc scoreDoc : results.scoreDocs) {
@ -479,11 +526,13 @@ public class KnnGraphTester {
String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36);
String nnFileName = "nn-" + hash + ".bin";
Path nnPath = Paths.get(nnFileName);
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) {
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
return readNN(nnPath);
} else {
int[][] nn = computeNN(docPath, queryPath);
writeNN(nn, nnPath);
if (selectivity == 1f) {
writeNN(nn, nnPath);
}
return nn;
}
}
@ -527,6 +576,19 @@ public class KnnGraphTester {
}
}
/**
 * Builds a bit set with {@code size} bits in which each bit is decided by an independent {@link
 * Math#random()} draw: the bit is set when the draw is strictly below {@code selectivity} and
 * cleared otherwise.
 *
 * @param size number of bits in the returned set
 * @param selectivity threshold each random draw is compared against
 * @return a newly allocated bit set holding the randomly chosen bits
 */
@SuppressForbidden(reason = "Uses random()")
private static FixedBitSet generateRandomBitSet(int size, float selectivity) {
  final FixedBitSet passing = new FixedBitSet(size);
  for (int doc = 0; doc < size; doc++) {
    // exactly one random draw per bit, in index order
    final boolean selected = Math.random() < selectivity;
    if (selected == false) {
      passing.clear(doc);
      continue;
    }
    passing.set(doc);
  }
  return passing;
}
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
int[][] result = new int[numIters][];
if (quiet == false) {
@ -558,7 +620,9 @@ public class KnnGraphTester {
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
float d = similarityFunction.compare(query, vector);
queue.insertWithOverflow(j, d);
if (matchDocs == null || matchDocs.get(j)) {
queue.insertWithOverflow(j, d);
}
}
result[i] = new int[topK];
for (int k = topK - 1; k >= 0; k--) {
@ -633,7 +697,7 @@ public class KnnGraphTester {
/**
 * Prints the command-line usage string to {@code System.err} and then terminates the JVM with
 * exit status 1.
 *
 * <p>This span shows two versions of the usage text; the longer one additionally lists the
 * {@code -filterSelectivity N} and {@code -prefilter} options.
 */
private static void usage() {
String error =
"Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N]";
"Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-filterSelectivity N] [-prefilter]";
System.err.println(error);
// exits with a non-zero status after printing the usage text
System.exit(1);
}
@ -729,4 +793,50 @@ public class KnnGraphTester {
return Float.compare(score[pivot], score[j]);
}
}
/**
 * Query whose matches are the documents with a set bit in a supplied {@link FixedBitSet}; every
 * match is given the weight's constant score.
 *
 * <p>NOTE(review): the scorer hands the same bit set to every leaf and does not use the leaf
 * context (no doc-base offset), so this presumably expects a single-segment index; confirm
 * against how the index is built before relying on it with multiple segments.
 */
private static class BitSetQuery extends Query {

  /** Bits of the documents that this query matches. */
  private final FixedBitSet docs;

  /** Number of set bits in {@link #docs}, computed once and passed to each iterator as its cost. */
  private final int cardinality;

  /**
   * Creates a query over {@code docs}; the bit set is kept by reference, not copied.
   *
   * @param docs bits of the matching documents
   */
  BitSetQuery(FixedBitSet docs) {
    this.docs = docs;
    this.cardinality = docs.cardinality();
  }

  /** Returns a constant-score weight whose scorer iterates over the set bits of {@link #docs}. */
  @Override
  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
      throws IOException {
    return new ConstantScoreWeight(this, boost) {
      @Override
      public boolean isCacheable(LeafReaderContext ctx) {
        // never reported as cacheable
        return false;
      }

      @Override
      public Scorer scorer(LeafReaderContext context) throws IOException {
        // a fresh iterator per call; the leaf context itself is not consulted
        BitSetIterator matching = new BitSetIterator(docs, cardinality);
        return new ConstantScoreScorer(this, score(), scoreMode, matching);
      }
    };
  }

  /** Intentionally a no-op: nothing is reported to the visitor. */
  @Override
  public void visit(QueryVisitor visitor) {}

  @Override
  public String toString(String field) {
    return "BitSetQuery";
  }

  /** Two instances are equal when they are of the same class and their bit sets are equal. */
  @Override
  public boolean equals(Object other) {
    if (sameClassAs(other) == false) {
      return false;
    }
    BitSetQuery that = (BitSetQuery) other;
    return docs.equals(that.docs);
  }

  @Override
  public int hashCode() {
    int bitsHash = docs.hashCode();
    return 31 * classHash() + bitsHash;
  }
}
}