Fix for bad cast when sorting a KnnVectors index over BytesRef (#1074)

This commit is contained in:
Michael Sokolov 2022-08-20 17:23:47 -04:00 committed by GitHub
parent 798c02dd70
commit 0a58318e16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 207 additions and 7 deletions

View File

@ -51,4 +51,14 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
.toString()
.equals(expectedString));
}
@Override
public void testRandomBytes() throws Exception {
// unimplemented
}
@Override
public void testSortedIndexBytes() throws Exception {
// unimplemented
}
}

View File

@ -50,4 +50,14 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
expectedString,
((Lucene91Codec) customCodec).getKnnVectorsFormatForField("bogus_field").toString());
}
@Override
public void testRandomBytes() throws Exception {
// unimplemented
}
@Override
public void testSortedIndexBytes() throws Exception {
// unimplemented
}
}

View File

@ -40,4 +40,14 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
expectedString,
((Lucene92Codec) customCodec).getKnnVectorsFormatForField("bogus_field").toString());
}
@Override
public void testRandomBytes() throws Exception {
// unimplemented
}
@Override
public void testSortedIndexBytes() throws Exception {
// unimplemented
}
}

View File

@ -28,6 +28,7 @@ import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
@ -76,6 +77,10 @@ public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"SimpleText codec does not support vector encoding: " + fieldInfo.getVectorEncoding());
}
long vectorDataOffset = vectorData.getFilePointer();
List<Integer> docIds = new ArrayList<>();
int docV;

View File

@ -17,6 +17,11 @@
package org.apache.lucene.codecs.simpletext;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
public class TestSimpleTextKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
@ -24,4 +29,24 @@ public class TestSimpleTextKnnVectorsFormat extends BaseKnnVectorsFormatTestCase
protected Codec getCodec() {
return new SimpleTextCodec();
}
public void testUnsupportedEncoding() throws Exception {
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new KnnVectorField("field", newBytesRef(2), VectorSimilarityFunction.DOT_PRODUCT));
iw.addDocument(doc);
expectThrows(IllegalArgumentException.class, () -> iw.commit());
}
}
@Override
public void testRandomBytes() throws Exception {
// unimplemented
}
@Override
public void testSortedIndexBytes() throws Exception {
// unimplemented
}
}

View File

@ -278,8 +278,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, 0, vector.length);
BytesRef vector = (BytesRef) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
}
return vectorDataOffset;
}

View File

@ -245,9 +245,8 @@ public class TestKnnVectorQuery extends LuceneTestCase {
for (int j = 0; j < 5; j++) {
vectors[j] = new float[] {j, j};
}
try (Directory d = getIndexStore("field", vectors);
try (Directory d = getIndexStore("field", 1, vectors);
IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(reader);
@ -757,8 +756,13 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
private Directory getIndexStore(String field, float[]... contents) throws IOException {
return getIndexStore(field, -1, contents);
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
private Directory getIndexStore(String field, int forceMerge, float[]... contents)
throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
VectorEncoding encoding = randomVectorEncoding();
@ -782,7 +786,9 @@ public class TestKnnVectorQuery extends LuceneTestCase {
doc.add(new StringField("other", "value", Field.Store.NO));
writer.addDocument(doc);
}
if (forceMerge > 0) {
writer.forceMerge(forceMerge);
}
writer.close();
return indexStore;
}

View File

@ -720,7 +720,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IndexWriter iw = new IndexWriter(dir, iwc)) {
add(iw, fieldName, 1, 1, new float[] {-1, 0});
add(iw, fieldName, 4, 4, new float[] {0, 1});
add(iw, fieldName, 3, 3, null);
add(iw, fieldName, 3, 3, (float[]) null);
add(iw, fieldName, 2, 2, new float[] {1, 0});
iw.forceMerge(1);
try (IndexReader reader = DirectoryReader.open(iw)) {
@ -740,6 +740,34 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
public void testSortedIndexBytes() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
String fieldName = "field";
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, iwc)) {
add(iw, fieldName, 1, 1, new BytesRef(new byte[] {-1, 0}));
add(iw, fieldName, 4, 4, new BytesRef(new byte[] {0, 1}));
add(iw, fieldName, 3, 3, (BytesRef) null);
add(iw, fieldName, 2, 2, new BytesRef(new byte[] {1, 0}));
iw.forceMerge(1);
try (IndexReader reader = DirectoryReader.open(iw)) {
LeafReader leaf = getOnlyLeafReader(reader);
VectorValues vectorValues = leaf.getVectorValues(fieldName);
assertEquals(2, vectorValues.dimension());
assertEquals(3, vectorValues.size());
assertEquals("1", leaf.document(vectorValues.nextDoc()).get("id"));
assertEquals(-1f, vectorValues.vectorValue()[0], 0);
assertEquals("2", leaf.document(vectorValues.nextDoc()).get("id"));
assertEquals(1, vectorValues.vectorValue()[0], 0);
assertEquals("4", leaf.document(vectorValues.nextDoc()).get("id"));
assertEquals(0, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
}
}
}
public void testIndexMultipleKnnVectorFields() throws Exception {
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
@ -863,6 +891,82 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
/**
* Index random vectors as bytes, sometimes skipping documents, sometimes deleting a document,
* sometimes merging, sometimes sorting the index, and verify that the expected values can be read
* back consistently.
*/
public void testRandomBytes() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
if (random().nextBoolean()) {
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
}
String fieldName = "field";
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
BytesRef scratch = new BytesRef(dimension);
scratch.length = dimension;
int numValues = 0;
BytesRef[] values = new BytesRef[numDoc];
for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) {
// usually index a vector value for a doc
values[i] = randomVector8(dimension);
++numValues;
}
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i].bytes, 0, scratch.bytes, 0, dimension);
add(iw, fieldName, i, scratch, similarityFunction);
} else {
add(iw, fieldName, i, values[i], similarityFunction);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
int idToDelete = random().nextInt(i + 1);
iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
// and remember that it was deleted
if (values[idToDelete] != null) {
values[idToDelete] = null;
--numValues;
}
}
if (random().nextInt(10) == 3) {
iw.commit();
}
}
int numDeletes = 0;
try (IndexReader reader = DirectoryReader.open(iw)) {
int valueCount = 0, totalSize = 0;
for (LeafReaderContext ctx : reader.leaves()) {
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
if (vectorValues == null) {
continue;
}
totalSize += vectorValues.size();
int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
BytesRef v = vectorValues.binaryValue();
assertEquals(dimension, v.length);
String idString = ctx.reader().document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
assertEquals(idString, 0, values[id].compareTo(v));
++valueCount;
} else {
++numDeletes;
assertNull(values[id]);
}
}
}
assertEquals(numValues, valueCount);
assertEquals(numValues, totalSize - numDeletes);
}
}
}
/**
* Tests whether {@link KnnVectorsReader#search} implementations obey the limit on the number of
* visited vectors. This test is a best-effort attempt to capture the right behavior, and isn't
@ -1012,6 +1116,36 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
add(iw, field, id, random().nextInt(100), vector, similarityFunction);
}
private void add(
IndexWriter iw, String field, int id, BytesRef vector, VectorSimilarityFunction similarity)
throws IOException {
add(iw, field, id, random().nextInt(100), vector, similarity);
}
private void add(IndexWriter iw, String field, int id, int sortKey, BytesRef vector)
throws IOException {
add(iw, field, id, sortKey, vector, VectorSimilarityFunction.EUCLIDEAN);
}
private void add(
IndexWriter iw,
String field,
int id,
int sortKey,
BytesRef vector,
VectorSimilarityFunction similarityFunction)
throws IOException {
Document doc = new Document();
if (vector != null) {
doc.add(new KnnVectorField(field, vector, similarityFunction));
}
doc.add(new NumericDocValuesField("sortkey", sortKey));
String idString = Integer.toString(id);
doc.add(new StringField("id", idString, Field.Store.YES));
Term idTerm = new Term("id", idString);
iw.updateDocument(idTerm, doc);
}
private void add(IndexWriter iw, String field, int id, int sortkey, float[] vector)
throws IOException {
add(iw, field, id, sortkey, vector, VectorSimilarityFunction.EUCLIDEAN);