mirror of https://github.com/apache/lucene.git
Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)
This commit is contained in:
parent
72968d30ba
commit
11f2bc2056
|
@ -29,6 +29,7 @@ import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
@ -140,7 +141,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
IndexInput bytesSlice =
|
IndexInput bytesSlice =
|
||||||
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
return new SimpleTextVectorValues(fieldEntry, bytesSlice);
|
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -187,7 +188,42 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
|
VectorValues values = getVectorValues(field);
|
||||||
|
if (target.length != values.dimension()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"vector query dimension: "
|
||||||
|
+ target.length
|
||||||
|
+ " differs from field dimension: "
|
||||||
|
+ values.dimension());
|
||||||
|
}
|
||||||
|
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||||
|
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
|
||||||
|
HitQueue topK = new HitQueue(k, false);
|
||||||
|
|
||||||
|
int numVisited = 0;
|
||||||
|
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||||
|
|
||||||
|
int doc;
|
||||||
|
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||||
|
if (acceptDocs != null && acceptDocs.get(doc) == false) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (numVisited >= visitedLimit) {
|
||||||
|
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
BytesRef vector = values.binaryValue();
|
||||||
|
float score = vectorSimilarity.compare(vector, target);
|
||||||
|
topK.insertWithOverflow(new ScoreDoc(doc, score));
|
||||||
|
numVisited++;
|
||||||
|
}
|
||||||
|
ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
|
||||||
|
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
||||||
|
topScoreDocs[i] = topK.pop();
|
||||||
|
}
|
||||||
|
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -273,16 +309,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
private final IndexInput in;
|
private final IndexInput in;
|
||||||
private final BytesRef binaryValue;
|
private final BytesRef binaryValue;
|
||||||
private final float[][] values;
|
private final float[][] values;
|
||||||
|
private final VectorEncoding vectorEncoding;
|
||||||
|
|
||||||
int curOrd;
|
int curOrd;
|
||||||
|
|
||||||
SimpleTextVectorValues(FieldEntry entry, IndexInput in) throws IOException {
|
SimpleTextVectorValues(FieldEntry entry, IndexInput in, VectorEncoding vectorEncoding)
|
||||||
|
throws IOException {
|
||||||
this.entry = entry;
|
this.entry = entry;
|
||||||
this.in = in;
|
this.in = in;
|
||||||
values = new float[entry.size()][entry.dimension];
|
values = new float[entry.size()][entry.dimension];
|
||||||
binaryValue = new BytesRef(entry.dimension * Float.BYTES);
|
binaryValue = new BytesRef(entry.dimension * vectorEncoding.byteSize);
|
||||||
binaryValue.length = binaryValue.bytes.length;
|
binaryValue.length = binaryValue.bytes.length;
|
||||||
curOrd = -1;
|
curOrd = -1;
|
||||||
|
this.vectorEncoding = vectorEncoding;
|
||||||
readAllVectors();
|
readAllVectors();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,7 +342,17 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue() {
|
public BytesRef binaryValue() {
|
||||||
ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
|
switch (vectorEncoding) {
|
||||||
|
// we know that the floats are really just byte values
|
||||||
|
case BYTE:
|
||||||
|
for (int i = 0; i < values[curOrd].length; i++) {
|
||||||
|
binaryValue.bytes[i + binaryValue.offset] = (byte) values[curOrd][i];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case FLOAT32:
|
||||||
|
ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
return binaryValue;
|
return binaryValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.search;
|
||||||
|
|
||||||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
|
||||||
|
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
|
@ -36,13 +37,25 @@ import org.apache.lucene.util.BytesRef;
|
||||||
public class TestVectorScorer extends LuceneTestCase {
|
public class TestVectorScorer extends LuceneTestCase {
|
||||||
|
|
||||||
public void testFindAll() throws IOException {
|
public void testFindAll() throws IOException {
|
||||||
|
VectorEncoding encoding = RandomPicks.randomFrom(random(), VectorEncoding.values());
|
||||||
try (Directory indexStore =
|
try (Directory indexStore =
|
||||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
getIndexStore(
|
||||||
|
"field", encoding, new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||||
assert reader.leaves().size() == 1;
|
assert reader.leaves().size() == 1;
|
||||||
LeafReaderContext context = reader.leaves().get(0);
|
LeafReaderContext context = reader.leaves().get(0);
|
||||||
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
|
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
|
||||||
VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
|
final VectorScorer vectorScorer;
|
||||||
|
switch (encoding) {
|
||||||
|
case BYTE:
|
||||||
|
vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
|
||||||
|
break;
|
||||||
|
case FLOAT32:
|
||||||
|
vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
|
||||||
|
}
|
||||||
|
|
||||||
int numDocs = 0;
|
int numDocs = 0;
|
||||||
for (int i = 0; i < reader.maxDoc(); i++) {
|
for (int i = 0; i < reader.maxDoc(); i++) {
|
||||||
|
@ -55,11 +68,10 @@ public class TestVectorScorer extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
||||||
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
private Directory getIndexStore(String field, VectorEncoding encoding, float[]... contents)
|
||||||
|
throws IOException {
|
||||||
Directory indexStore = newDirectory();
|
Directory indexStore = newDirectory();
|
||||||
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
||||||
VectorEncoding encoding =
|
|
||||||
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
|
||||||
for (int i = 0; i < contents.length; ++i) {
|
for (int i = 0; i < contents.length; ++i) {
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
if (encoding == VectorEncoding.BYTE) {
|
if (encoding == VectorEncoding.BYTE) {
|
||||||
|
|
Loading…
Reference in New Issue