Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)

This commit is contained in:
Benjamin Trent 2022-12-15 08:49:47 -05:00 committed by GitHub
parent 72968d30ba
commit 11f2bc2056
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 10 deletions

View File

@ -29,6 +29,7 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
@ -140,7 +141,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
IndexInput bytesSlice =
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
return new SimpleTextVectorValues(fieldEntry, bytesSlice);
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
}
@Override
@ -187,7 +188,42 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
VectorValues values = getVectorValues(field);
if (target.length != values.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ target.length
+ " differs from field dimension: "
+ values.dimension());
}
FieldInfo info = readState.fieldInfos.fieldInfo(field);
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
HitQueue topK = new HitQueue(k, false);
int numVisited = 0;
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
int doc;
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (acceptDocs != null && acceptDocs.get(doc) == false) {
continue;
}
if (numVisited >= visitedLimit) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}
BytesRef vector = values.binaryValue();
float score = vectorSimilarity.compare(vector, target);
topK.insertWithOverflow(new ScoreDoc(doc, score));
numVisited++;
}
ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = topK.pop();
}
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
}
@Override
@ -273,16 +309,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
private final IndexInput in;
private final BytesRef binaryValue;
private final float[][] values;
private final VectorEncoding vectorEncoding;
int curOrd;
SimpleTextVectorValues(FieldEntry entry, IndexInput in) throws IOException {
SimpleTextVectorValues(FieldEntry entry, IndexInput in, VectorEncoding vectorEncoding)
throws IOException {
this.entry = entry;
this.in = in;
values = new float[entry.size()][entry.dimension];
binaryValue = new BytesRef(entry.dimension * Float.BYTES);
binaryValue = new BytesRef(entry.dimension * vectorEncoding.byteSize);
binaryValue.length = binaryValue.bytes.length;
curOrd = -1;
this.vectorEncoding = vectorEncoding;
readAllVectors();
}
@ -303,7 +342,17 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
@Override
public BytesRef binaryValue() {
switch (vectorEncoding) {
// we know that the floats are really just byte values
case BYTE:
for (int i = 0; i < values[curOrd].length; i++) {
binaryValue.bytes[i + binaryValue.offset] = (byte) values[curOrd][i];
}
break;
case FLOAT32:
ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
break;
}
return binaryValue;
}

View File

@ -18,6 +18,7 @@ package org.apache.lucene.search;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
@ -36,13 +37,25 @@ import org.apache.lucene.util.BytesRef;
public class TestVectorScorer extends LuceneTestCase {
public void testFindAll() throws IOException {
VectorEncoding encoding = RandomPicks.randomFrom(random(), VectorEncoding.values());
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
getIndexStore(
"field", encoding, new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
assert reader.leaves().size() == 1;
LeafReaderContext context = reader.leaves().get(0);
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
final VectorScorer vectorScorer;
switch (encoding) {
case BYTE:
vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
break;
case FLOAT32:
vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
break;
default:
throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
}
int numDocs = 0;
for (int i = 0; i < reader.maxDoc(); i++) {
@ -55,11 +68,10 @@ public class TestVectorScorer extends LuceneTestCase {
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
private Directory getIndexStore(String field, float[]... contents) throws IOException {
private Directory getIndexStore(String field, VectorEncoding encoding, float[]... contents)
throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
VectorEncoding encoding =
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
if (encoding == VectorEncoding.BYTE) {