Implement Weight#count for vector values in the FieldExistsQuery (#13322)

* implement Weight#count for vector values

* add change log

* apply review comment

* apply review comment

* changelog

* remove null check
This commit is contained in:
panguixin 2024-06-06 03:02:51 +08:00 committed by GitHub
parent 05b4639c0c
commit fe50e86e36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 9 deletions

View File

@ -255,6 +255,8 @@ Optimizations
* GITHUB##13425: Rewrite SortedNumericDocValuesRangeQuery to MatchNoDocsQuery when the upper bound is smaller than the
lower bound. (Ioana Tagirta)
* GITHUB#13322: Implement Weight#count for vector values in the FieldExistsQuery. (Pan Guixin)
Bug Fixes
---------------------
(No changes)

View File

@ -19,10 +19,12 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
@ -35,7 +37,7 @@ import org.apache.lucene.index.Terms;
* org.apache.lucene.document.KnnByteVectorField} or a field that indexes norms or doc values.
*/
public class FieldExistsQuery extends Query {
private String field;
private final String field;
/** Create a query that will match that have a value for the given {@code field}. */
public FieldExistsQuery(String field) {
@ -128,13 +130,7 @@ public class FieldExistsQuery extends Query {
break;
}
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
DocIdSetIterator vectorValues =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> leaf.getFloatVectorValues(field);
case BYTE -> leaf.getByteVectorValues(field);
};
assert vectorValues != null : "unexpected null vector values";
if (vectorValues != null && vectorValues.cost() != leaf.maxDoc()) {
if (getVectorValuesSize(fieldInfo, leaf) != leaf.maxDoc()) {
allReadersRewritable = false;
break;
}
@ -238,7 +234,10 @@ public class FieldExistsQuery extends Query {
}
return super.count(context);
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
} else if (fieldInfo.hasVectorValues()) { // the field indexes vectors
if (reader.hasDeletions() == false) {
return getVectorValuesSize(fieldInfo, reader);
}
return super.count(context);
} else if (fieldInfo.getDocValuesType()
!= DocValuesType.NONE) { // the field indexes doc values
@ -277,4 +276,20 @@ public class FieldExistsQuery extends Query {
+ fieldInfo.name
+ "' exists and indexes neither of these data structures";
}
private int getVectorValuesSize(FieldInfo fi, LeafReader reader) throws IOException {
assert fi.name.equals(field);
return switch (fi.getVectorEncoding()) {
case FLOAT32 -> {
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
assert floatVectorValues != null : "unexpected null float vector values";
yield floatVectorValues.size();
}
case BYTE -> {
ByteVectorValues byteVectorValues = reader.getByteVectorValues(field);
assert byteVectorValues != null : "unexpected null byte vector values";
yield byteVectorValues.size();
}
};
}
}

View File

@ -40,7 +40,9 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
@ -646,6 +648,39 @@ public class TestFieldExistsQuery extends LuceneTestCase {
}
}
public void testDeleteKnnVector() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
final int numDocs = atLeast(100);
boolean allDocsHaveVector = random().nextBoolean();
BitSet docWithVector = new FixedBitSet(numDocs);
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
if (allDocsHaveVector || random().nextBoolean()) {
doc.add(new KnnFloatVectorField("vector", randomVector(5)));
docWithVector.set(i);
}
doc.add(new StringField("id", Integer.toString(i), Store.NO));
iw.addDocument(doc);
}
if (random().nextBoolean()) {
final int numDeleted = random().nextInt(numDocs) + 1;
for (int i = 0; i < numDeleted; ++i) {
iw.deleteDocuments(new Term("id", Integer.toString(i)));
docWithVector.clear(i);
}
}
try (IndexReader reader = iw.getReader()) {
final IndexSearcher searcher = newSearcher(reader);
final int count = searcher.count(new FieldExistsQuery("vector"));
assertEquals(docWithVector.cardinality(), count);
}
}
}
public void testKnnVectorConjunction() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {