diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
index 3993e2e3bd5..4994cf692d4 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
@@ -29,6 +29,7 @@ import org.apache.lucene.index.CorruptIndexException;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -140,7 +141,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
     }
     IndexInput bytesSlice =
         dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
-    return new SimpleTextVectorValues(fieldEntry, bytesSlice);
+    return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
   }
 
   @Override
@@ -187,7 +188,42 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
   @Override
   public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
       throws IOException {
-    return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+    VectorValues values = getVectorValues(field);
+    if (target.length != values.dimension()) {
+      throw new IllegalArgumentException(
+          "vector query dimension: "
+              + target.length
+              + " differs from field dimension: "
+              + values.dimension());
+    }
+    FieldInfo info = readState.fieldInfos.fieldInfo(field);
+    VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
+    HitQueue topK = new HitQueue(k, false);
+
+    int numVisited = 0;
+    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
+
+    int doc;
+    while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+      if (acceptDocs != null && acceptDocs.get(doc) == false) {
+        continue;
+      }
+
+      if (numVisited >= visitedLimit) {
+        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
+        break;
+      }
+
+      BytesRef vector = values.binaryValue();
+      float score = vectorSimilarity.compare(vector, target);
+      topK.insertWithOverflow(new ScoreDoc(doc, score));
+      numVisited++;
+    }
+    ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
+    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
+      topScoreDocs[i] = topK.pop();
+    }
+    return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
   }
 
   @Override
@@ -273,16 +309,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
     private final IndexInput in;
     private final BytesRef binaryValue;
     private final float[][] values;
+    private final VectorEncoding vectorEncoding;
 
     int curOrd;
 
-    SimpleTextVectorValues(FieldEntry entry, IndexInput in) throws IOException {
+    SimpleTextVectorValues(FieldEntry entry, IndexInput in, VectorEncoding vectorEncoding)
+        throws IOException {
       this.entry = entry;
       this.in = in;
       values = new float[entry.size()][entry.dimension];
-      binaryValue = new BytesRef(entry.dimension * Float.BYTES);
+      binaryValue = new BytesRef(entry.dimension * vectorEncoding.byteSize);
       binaryValue.length = binaryValue.bytes.length;
       curOrd = -1;
+      this.vectorEncoding = vectorEncoding;
       readAllVectors();
     }
 
@@ -303,7 +342,18 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
 
     @Override
     public BytesRef binaryValue() {
-      ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
+      switch (vectorEncoding) {
+          // we know that the floats are really just byte values
+        case BYTE:
+          for (int i = 0; i < values[curOrd].length; i++) {
+            binaryValue.bytes[i + binaryValue.offset] = (byte) values[curOrd][i];
+          }
+          break;
+        case FLOAT32:
+          // put (not get): write the current vector into binaryValue instead of clobbering it
+          ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().put(values[curOrd]);
+          break;
+      }
       return binaryValue;
     }
 
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
index 979283c7701..2c72baa73d3 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
@@ -18,6 +18,7 @@ package org.apache.lucene.search;
 
 import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
 
+import com.carrotsearch.randomizedtesting.generators.RandomPicks;
 import java.io.IOException;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
@@ -36,13 +37,25 @@ import org.apache.lucene.util.BytesRef;
 public class TestVectorScorer extends LuceneTestCase {
 
   public void testFindAll() throws IOException {
+    VectorEncoding encoding = RandomPicks.randomFrom(random(), VectorEncoding.values());
     try (Directory indexStore =
-            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+            getIndexStore(
+                "field", encoding, new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
         IndexReader reader = DirectoryReader.open(indexStore)) {
       assert reader.leaves().size() == 1;
       LeafReaderContext context = reader.leaves().get(0);
       FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
-      VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
+      final VectorScorer vectorScorer;
+      switch (encoding) {
+        case BYTE:
+          vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
+          break;
+        case FLOAT32:
+          vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
+          break;
+        default:
+          throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
+      }
 
       int numDocs = 0;
       for (int i = 0; i < reader.maxDoc(); i++) {
@@ -55,11 +68,10 @@ public class TestVectorScorer extends LuceneTestCase {
   }
 
   /** Creates a new directory and adds documents with the given vectors as kNN vector fields */
-  private Directory getIndexStore(String field, float[]... contents) throws IOException {
+  private Directory getIndexStore(String field, VectorEncoding encoding, float[]... contents)
+      throws IOException {
     Directory indexStore = newDirectory();
     RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
-    VectorEncoding encoding =
-        VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
     for (int i = 0; i < contents.length; ++i) {
       Document doc = new Document();
       if (encoding == VectorEncoding.BYTE) {