mirror of https://github.com/apache/lucene.git
LUCENE-10559: Add Prefilter Option to KnnGraphTester (#932)
Added `prefilter` and `filterSelectivity` arguments to KnnGraphTester so that pre-filtering and post-filtering benchmarks can be compared. `filterSelectivity` expresses the selectivity of a filter as the proportion of documents, selected at random, that pass it. We store the passing documents in a FixedBitSet and use it both to calculate the true KNN and in the HNSW search. In the post-filter case, we over-select results as `topK / filterSelectivity` so that the number of final hits is close to the actually requested `topK`. In the pre-filter case, we wrap the FixedBitSet in a query and pass it as the prefilter argument to KnnVectorQuery.
This commit is contained in:
parent
eb7b7791ba
commit
1ad28a3136
|
@ -112,7 +112,7 @@ Bug Fixes
|
|||
|
||||
Other
|
||||
---------------------
|
||||
(No changes)
|
||||
* LUCENE-10559: Add Prefilter Option to KnnGraphTester (Kaival Parikh)
|
||||
|
||||
======================== Lucene 9.3.0 =======================
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ import java.nio.file.Files;
|
|||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.attribute.FileTime;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
|
@ -56,13 +57,22 @@ import org.apache.lucene.index.LeafReaderContext;
|
|||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.ConstantScoreScorer;
|
||||
import org.apache.lucene.search.ConstantScoreWeight;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.QueryVisitor;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.FSDirectory;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.IntroSorter;
|
||||
import org.apache.lucene.util.PrintStreamInfoStream;
|
||||
import org.apache.lucene.util.SuppressForbidden;
|
||||
|
@ -91,8 +101,10 @@ public class KnnGraphTester {
|
|||
private int beamWidth;
|
||||
private int maxConn;
|
||||
private VectorSimilarityFunction similarityFunction;
|
||||
private FixedBitSet matchDocs;
|
||||
private float selectivity;
|
||||
private boolean prefilter;
|
||||
|
||||
@SuppressForbidden(reason = "uses Random()")
|
||||
private KnnGraphTester() {
|
||||
// set defaults
|
||||
numDocs = 1000;
|
||||
|
@ -101,6 +113,8 @@ public class KnnGraphTester {
|
|||
topK = 100;
|
||||
fanout = topK;
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
selectivity = 1f;
|
||||
prefilter = false;
|
||||
}
|
||||
|
||||
public static void main(String... args) throws Exception {
|
||||
|
@ -192,6 +206,18 @@ public class KnnGraphTester {
|
|||
case "-forceMerge":
|
||||
forceMerge = true;
|
||||
break;
|
||||
case "-prefilter":
|
||||
prefilter = true;
|
||||
break;
|
||||
case "-filterSelectivity":
|
||||
if (iarg == args.length - 1) {
|
||||
throw new IllegalArgumentException("-filterSelectivity requires a following float");
|
||||
}
|
||||
selectivity = Float.parseFloat(args[++iarg]);
|
||||
if (selectivity <= 0 || selectivity >= 1) {
|
||||
throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1");
|
||||
}
|
||||
break;
|
||||
case "-quiet":
|
||||
quiet = true;
|
||||
break;
|
||||
|
@ -203,6 +229,9 @@ public class KnnGraphTester {
|
|||
if (operation == null && reindex == false) {
|
||||
usage();
|
||||
}
|
||||
if (prefilter == true && selectivity == 1f) {
|
||||
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
|
||||
}
|
||||
indexPath = Paths.get(formatIndexPath(docVectorsPath));
|
||||
if (reindex) {
|
||||
if (docVectorsPath == null) {
|
||||
|
@ -219,6 +248,7 @@ public class KnnGraphTester {
|
|||
if (docVectorsPath == null) {
|
||||
throw new IllegalArgumentException("missing -docs arg");
|
||||
}
|
||||
matchDocs = generateRandomBitSet(numDocs, selectivity);
|
||||
if (outputPath != null) {
|
||||
testSearch(indexPath, queryPath, outputPath, null);
|
||||
} else {
|
||||
|
@ -362,17 +392,33 @@ public class KnnGraphTester {
|
|||
DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
numDocs = reader.maxDoc();
|
||||
Query bitSetQuery = new BitSetQuery(matchDocs);
|
||||
for (int i = 0; i < numIters; i++) {
|
||||
// warm up
|
||||
targets.get(target);
|
||||
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
|
||||
if (prefilter) {
|
||||
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
||||
} else {
|
||||
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
||||
}
|
||||
}
|
||||
targets.position(0);
|
||||
start = System.nanoTime();
|
||||
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
|
||||
for (int i = 0; i < numIters; i++) {
|
||||
targets.get(target);
|
||||
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
|
||||
if (prefilter) {
|
||||
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
||||
} else {
|
||||
results[i] =
|
||||
doKnnVectorQuery(
|
||||
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
||||
|
||||
results[i].scoreDocs =
|
||||
Arrays.stream(results[i].scoreDocs)
|
||||
.filter(scoreDoc -> matchDocs == null || matchDocs.get(scoreDoc.doc))
|
||||
.toArray(ScoreDoc[]::new);
|
||||
}
|
||||
}
|
||||
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
|
||||
elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
|
||||
|
@ -417,7 +463,7 @@ public class KnnGraphTester {
|
|||
totalVisited /= numIters;
|
||||
System.out.printf(
|
||||
Locale.ROOT,
|
||||
"%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n",
|
||||
"%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%s\n",
|
||||
recall,
|
||||
totalCpuTime / (float) numIters,
|
||||
numDocs,
|
||||
|
@ -425,21 +471,22 @@ public class KnnGraphTester {
|
|||
maxConn,
|
||||
beamWidth,
|
||||
totalVisited,
|
||||
reindexTimeMsec);
|
||||
reindexTimeMsec,
|
||||
selectivity,
|
||||
prefilter ? "pre-filter" : "post-filter");
|
||||
}
|
||||
}
|
||||
|
||||
private static TopDocs doKnnVectorQuery(
|
||||
IndexSearcher searcher, String field, float[] vector, int k, int fanout) throws IOException {
|
||||
return searcher.search(new KnnVectorQuery(field, vector, k + fanout), k);
|
||||
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
|
||||
throws IOException {
|
||||
return searcher.search(new KnnVectorQuery(field, vector, k + fanout, filter), k);
|
||||
}
|
||||
|
||||
private float checkResults(TopDocs[] results, int[][] nn) {
|
||||
int totalMatches = 0;
|
||||
int totalResults = 0;
|
||||
int totalResults = results.length * topK;
|
||||
for (int i = 0; i < results.length; i++) {
|
||||
int n = results[i].scoreDocs.length;
|
||||
totalResults += n;
|
||||
// System.out.println(Arrays.toString(nn[i]));
|
||||
// System.out.println(Arrays.toString(results[i].scoreDocs));
|
||||
totalMatches += compareNN(nn[i], results[i]);
|
||||
|
@ -463,7 +510,7 @@ public class KnnGraphTester {
|
|||
System.out.print('\n');
|
||||
*/
|
||||
Set<Integer> expectedSet = new HashSet<>();
|
||||
for (int i = 0; i < results.scoreDocs.length; i++) {
|
||||
for (int i = 0; i < topK; i++) {
|
||||
expectedSet.add(expected[i]);
|
||||
}
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
|
@ -479,11 +526,13 @@ public class KnnGraphTester {
|
|||
String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36);
|
||||
String nnFileName = "nn-" + hash + ".bin";
|
||||
Path nnPath = Paths.get(nnFileName);
|
||||
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) {
|
||||
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
|
||||
return readNN(nnPath);
|
||||
} else {
|
||||
int[][] nn = computeNN(docPath, queryPath);
|
||||
writeNN(nn, nnPath);
|
||||
if (selectivity == 1f) {
|
||||
writeNN(nn, nnPath);
|
||||
}
|
||||
return nn;
|
||||
}
|
||||
}
|
||||
|
@ -527,6 +576,19 @@ public class KnnGraphTester {
|
|||
}
|
||||
}
|
||||
|
||||
@SuppressForbidden(reason = "Uses random()")
|
||||
private static FixedBitSet generateRandomBitSet(int size, float selectivity) {
|
||||
FixedBitSet bitSet = new FixedBitSet(size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (Math.random() < selectivity) {
|
||||
bitSet.set(i);
|
||||
} else {
|
||||
bitSet.clear(i);
|
||||
}
|
||||
}
|
||||
return bitSet;
|
||||
}
|
||||
|
||||
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
|
||||
int[][] result = new int[numIters][];
|
||||
if (quiet == false) {
|
||||
|
@ -558,7 +620,9 @@ public class KnnGraphTester {
|
|||
for (; j < numDocs && vectors.hasRemaining(); j++) {
|
||||
vectors.get(vector);
|
||||
float d = similarityFunction.compare(query, vector);
|
||||
queue.insertWithOverflow(j, d);
|
||||
if (matchDocs == null || matchDocs.get(j)) {
|
||||
queue.insertWithOverflow(j, d);
|
||||
}
|
||||
}
|
||||
result[i] = new int[topK];
|
||||
for (int k = topK - 1; k >= 0; k--) {
|
||||
|
@ -633,7 +697,7 @@ public class KnnGraphTester {
|
|||
|
||||
private static void usage() {
|
||||
String error =
|
||||
"Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N]";
|
||||
"Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-filterSelectivity N] [-prefilter]";
|
||||
System.err.println(error);
|
||||
System.exit(1);
|
||||
}
|
||||
|
@ -729,4 +793,50 @@ public class KnnGraphTester {
|
|||
return Float.compare(score[pivot], score[j]);
|
||||
}
|
||||
}
|
||||
|
||||
private static class BitSetQuery extends Query {
|
||||
|
||||
private final FixedBitSet docs;
|
||||
private final int cardinality;
|
||||
|
||||
BitSetQuery(FixedBitSet docs) {
|
||||
this.docs = docs;
|
||||
this.cardinality = docs.cardinality();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
|
||||
throws IOException {
|
||||
return new ConstantScoreWeight(this, boost) {
|
||||
@Override
|
||||
public Scorer scorer(LeafReaderContext context) throws IOException {
|
||||
return new ConstantScoreScorer(
|
||||
this, score(), scoreMode, new BitSetIterator(docs, cardinality));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "BitSetQuery";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
return sameClassAs(other) && docs.equals(((BitSetQuery) other).docs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + docs.hashCode();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue