mirror of https://github.com/apache/lucene.git
Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)
This commit is contained in:
parent
72968d30ba
commit
11f2bc2056
|
@ -29,6 +29,7 @@ import org.apache.lucene.index.CorruptIndexException;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -140,7 +141,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
IndexInput bytesSlice =
|
||||
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
return new SimpleTextVectorValues(fieldEntry, bytesSlice);
|
||||
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -187,7 +188,42 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
|
||||
VectorValues values = getVectorValues(field);
|
||||
if (target.length != values.dimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector query dimension: "
|
||||
+ target.length
|
||||
+ " differs from field dimension: "
|
||||
+ values.dimension());
|
||||
}
|
||||
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
|
||||
HitQueue topK = new HitQueue(k, false);
|
||||
|
||||
int numVisited = 0;
|
||||
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||
|
||||
int doc;
|
||||
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
if (acceptDocs != null && acceptDocs.get(doc) == false) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (numVisited >= visitedLimit) {
|
||||
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||
break;
|
||||
}
|
||||
|
||||
BytesRef vector = values.binaryValue();
|
||||
float score = vectorSimilarity.compare(vector, target);
|
||||
topK.insertWithOverflow(new ScoreDoc(doc, score));
|
||||
numVisited++;
|
||||
}
|
||||
ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
|
||||
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
||||
topScoreDocs[i] = topK.pop();
|
||||
}
|
||||
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -273,16 +309,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
private final IndexInput in;
|
||||
private final BytesRef binaryValue;
|
||||
private final float[][] values;
|
||||
private final VectorEncoding vectorEncoding;
|
||||
|
||||
int curOrd;
|
||||
|
||||
SimpleTextVectorValues(FieldEntry entry, IndexInput in) throws IOException {
|
||||
SimpleTextVectorValues(FieldEntry entry, IndexInput in, VectorEncoding vectorEncoding)
|
||||
throws IOException {
|
||||
this.entry = entry;
|
||||
this.in = in;
|
||||
values = new float[entry.size()][entry.dimension];
|
||||
binaryValue = new BytesRef(entry.dimension * Float.BYTES);
|
||||
binaryValue = new BytesRef(entry.dimension * vectorEncoding.byteSize);
|
||||
binaryValue.length = binaryValue.bytes.length;
|
||||
curOrd = -1;
|
||||
this.vectorEncoding = vectorEncoding;
|
||||
readAllVectors();
|
||||
}
|
||||
|
||||
|
@ -303,7 +342,17 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
|
||||
@Override
|
||||
public BytesRef binaryValue() {
|
||||
ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
|
||||
switch (vectorEncoding) {
|
||||
// we know that the floats are really just byte values
|
||||
case BYTE:
|
||||
for (int i = 0; i < values[curOrd].length; i++) {
|
||||
binaryValue.bytes[i + binaryValue.offset] = (byte) values[curOrd][i];
|
||||
}
|
||||
break;
|
||||
case FLOAT32:
|
||||
ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
|
||||
break;
|
||||
}
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.search;
|
|||
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
|
@ -36,13 +37,25 @@ import org.apache.lucene.util.BytesRef;
|
|||
public class TestVectorScorer extends LuceneTestCase {
|
||||
|
||||
public void testFindAll() throws IOException {
|
||||
VectorEncoding encoding = RandomPicks.randomFrom(random(), VectorEncoding.values());
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
getIndexStore(
|
||||
"field", encoding, new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
assert reader.leaves().size() == 1;
|
||||
LeafReaderContext context = reader.leaves().get(0);
|
||||
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
|
||||
VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
|
||||
final VectorScorer vectorScorer;
|
||||
switch (encoding) {
|
||||
case BYTE:
|
||||
vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
|
||||
break;
|
||||
case FLOAT32:
|
||||
vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
|
||||
}
|
||||
|
||||
int numDocs = 0;
|
||||
for (int i = 0; i < reader.maxDoc(); i++) {
|
||||
|
@ -55,11 +68,10 @@ public class TestVectorScorer extends LuceneTestCase {
|
|||
}
|
||||
|
||||
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
||||
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||
private Directory getIndexStore(String field, VectorEncoding encoding, float[]... contents)
|
||||
throws IOException {
|
||||
Directory indexStore = newDirectory();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
||||
VectorEncoding encoding =
|
||||
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
||||
for (int i = 0; i < contents.length; ++i) {
|
||||
Document doc = new Document();
|
||||
if (encoding == VectorEncoding.BYTE) {
|
||||
|
|
Loading…
Reference in New Issue