mirror of https://github.com/apache/lucene.git
SimpleText codec to support writing byte vectors (#12111)
A recent test failure signaled that when the simple text codec was randomly selected, byte vectors could not be written. This commit addressed that by adding support for writing byte vectors to SimpleTextKnnVectorsWriter. Note that while support is added to the BufferingKnnVectorsWriter base class, 90, 91 and 92 writers don't need to support byte vectors and will throw unsupported operation exception when attempting to do that.
This commit is contained in:
parent
95e2cfcc1e
commit
5a51ce1d5d
|
@ -25,7 +25,7 @@ import java.nio.ByteOrder;
|
|||
import java.util.Arrays;
|
||||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
|
@ -107,10 +107,9 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
|
||||
public void writeField(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, int maxDoc)
|
||||
throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
FloatVectorValues vectors = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
|
||||
IndexOutput tempVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
|
@ -120,7 +119,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
try {
|
||||
// write the vector data to a temporary file
|
||||
// TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index
|
||||
int[] docIds = writeVectorData(tempVectorData, vectors);
|
||||
int[] docIds = writeVectorData(tempVectorData, floatVectorValues);
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
||||
|
@ -134,7 +133,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// build the graph using the temporary vector data
|
||||
Lucene90HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors =
|
||||
new Lucene90HnswVectorsReader.OffHeapFloatVectorValues(
|
||||
vectors.dimension(), docIds, vectorDataInput);
|
||||
floatVectorValues.dimension(), docIds, vectorDataInput);
|
||||
|
||||
long[] offsets = new long[docIds.length];
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
|
@ -170,6 +169,11 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, int maxDoc) {
|
||||
throw new UnsupportedOperationException("byte vectors not supported in this version");
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a mapping from dense ordinals to document
|
||||
* IDs. The length of the returned array matches the total number of documents with a vector
|
||||
|
|
|
@ -25,7 +25,7 @@ import java.nio.ByteOrder;
|
|||
import java.util.Arrays;
|
||||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
|
@ -109,10 +109,9 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
|
||||
public void writeField(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, int maxDoc)
|
||||
throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
FloatVectorValues vectors = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
|
||||
IndexOutput tempVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
|
@ -121,7 +120,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
boolean success = false;
|
||||
try {
|
||||
// write the vector data to a temporary file
|
||||
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors);
|
||||
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, floatVectorValues);
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
||||
|
@ -139,7 +138,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
Lucene91HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors =
|
||||
new Lucene91HnswVectorsReader.OffHeapFloatVectorValues(
|
||||
vectors.dimension(), docsWithField.cardinality(), null, vectorDataInput);
|
||||
floatVectorValues.dimension(), docsWithField.cardinality(), null, vectorDataInput);
|
||||
Lucene91OnHeapHnswGraph graph =
|
||||
offHeapVectors.size() == 0
|
||||
? null
|
||||
|
@ -167,6 +166,11 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, int maxDoc) {
|
||||
throw new UnsupportedOperationException("byte vectors not supported in this version");
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
|
|
|
@ -26,8 +26,8 @@ import java.nio.ByteOrder;
|
|||
import java.util.Arrays;
|
||||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
|
@ -115,10 +115,9 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
|
||||
public void writeField(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, int maxDoc)
|
||||
throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
FloatVectorValues vectors = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
|
||||
IndexOutput tempVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
|
@ -127,7 +126,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
boolean success = false;
|
||||
try {
|
||||
// write the vector data to a temporary file
|
||||
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors);
|
||||
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, floatVectorValues);
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
||||
|
@ -146,12 +145,11 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
OffHeapFloatVectorValues offHeapVectors =
|
||||
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput);
|
||||
floatVectorValues.dimension(), docsWithField.cardinality(), vectorDataInput);
|
||||
OnHeapHnswGraph graph =
|
||||
offHeapVectors.size() == 0
|
||||
? null
|
||||
: writeGraph(
|
||||
offHeapVectors, VectorEncoding.FLOAT32, fieldInfo.getVectorSimilarityFunction());
|
||||
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
writeMeta(
|
||||
fieldInfo,
|
||||
|
@ -175,6 +173,11 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, int maxDoc) {
|
||||
throw new UnsupportedOperationException("byte vectors not supported in this version");
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
|
@ -271,16 +274,14 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
private OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValues<float[]> vectorValues,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction)
|
||||
RandomAccessVectorValues<float[]> vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
// build graph
|
||||
HnswGraphBuilder<float[]> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
vectorValues,
|
||||
vectorEncoding,
|
||||
VectorEncoding.FLOAT32,
|
||||
similarityFunction,
|
||||
M,
|
||||
beamWidth,
|
||||
|
|
|
@ -390,7 +390,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE -> writeByteVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||
case FLOAT32 -> writeVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
|
||||
tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||
};
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
@ -656,7 +656,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
};
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||
throws IOException {
|
||||
this.fieldInfo = fieldInfo;
|
||||
|
|
|
@ -24,7 +24,7 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
|
@ -73,21 +73,21 @@ public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
|
||||
public void writeField(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, int maxDoc)
|
||||
throws IOException {
|
||||
FloatVectorValues vectors = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
long vectorDataOffset = vectorData.getFilePointer();
|
||||
List<Integer> docIds = new ArrayList<>();
|
||||
int docV;
|
||||
for (docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
writeVectorValue(vectors);
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
writeFloatVectorValue(floatVectorValues);
|
||||
docIds.add(docV);
|
||||
}
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds);
|
||||
}
|
||||
|
||||
private void writeVectorValue(FloatVectorValues vectors) throws IOException {
|
||||
private void writeFloatVectorValue(FloatVectorValues vectors) throws IOException {
|
||||
// write vector value
|
||||
float[] value = vectors.vectorValue();
|
||||
assert value.length == vectors.dimension();
|
||||
|
@ -95,6 +95,29 @@ public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
newline(vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, int maxDoc)
|
||||
throws IOException {
|
||||
long vectorDataOffset = vectorData.getFilePointer();
|
||||
List<Integer> docIds = new ArrayList<>();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
writeByteVectorValue(byteVectorValues);
|
||||
docIds.add(docV);
|
||||
}
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds);
|
||||
}
|
||||
|
||||
private void writeByteVectorValue(ByteVectorValues vectors) throws IOException {
|
||||
// write vector value
|
||||
byte[] value = vectors.vectorValue();
|
||||
assert value.length == vectors.dimension();
|
||||
write(vectorData, Arrays.toString(value));
|
||||
newline(vectorData);
|
||||
}
|
||||
|
||||
private void writeMeta(
|
||||
FieldInfo field, long vectorDataOffset, long vectorDataLength, List<Integer> docIds)
|
||||
throws IOException {
|
||||
|
|
|
@ -27,9 +27,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
/**
|
||||
|
@ -39,79 +37,81 @@ import org.apache.lucene.util.RamUsageEstimator;
|
|||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||
private final List<FieldWriter> fields = new ArrayList<>();
|
||||
private final List<FieldWriter<?>> fields = new ArrayList<>();
|
||||
|
||||
/** Sole constructor */
|
||||
protected BufferingKnnVectorsWriter() {}
|
||||
|
||||
@Override
|
||||
public KnnFieldVectorsWriter<float[]> addField(FieldInfo fieldInfo) throws IOException {
|
||||
FieldWriter newField = new FieldWriter(fieldInfo);
|
||||
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||
FieldWriter<?> newField;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32:
|
||||
newField =
|
||||
new FieldWriter<float[]>(fieldInfo) {
|
||||
@Override
|
||||
public float[] copyValue(float[] vectorValue) {
|
||||
return ArrayUtil.copyOfSubArray(vectorValue, 0, fieldInfo.getVectorDimension());
|
||||
}
|
||||
};
|
||||
break;
|
||||
case BYTE:
|
||||
newField =
|
||||
new FieldWriter<byte[]>(fieldInfo) {
|
||||
@Override
|
||||
public byte[] copyValue(byte[] vectorValue) {
|
||||
return ArrayUtil.copyOfSubArray(vectorValue, 0, fieldInfo.getVectorDimension());
|
||||
}
|
||||
};
|
||||
break;
|
||||
default:
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
fields.add(newField);
|
||||
return newField;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
for (FieldWriter fieldData : fields) {
|
||||
KnnVectorsReader knnVectorsReader =
|
||||
new KnnVectorsReader() {
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
BufferedVectorValues vectorValues =
|
||||
new BufferedVectorValues(
|
||||
fieldData.docsWithField,
|
||||
fieldData.vectors,
|
||||
fieldData.fieldInfo.getVectorDimension());
|
||||
return sortMap != null
|
||||
? new SortingVectorValues(vectorValues, sortMap)
|
||||
: vectorValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
|
||||
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
|
||||
for (FieldWriter<?> fieldData : fields) {
|
||||
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32:
|
||||
BufferedFloatVectorValues bufferedFloatVectorValues =
|
||||
new BufferedFloatVectorValues(
|
||||
fieldData.docsWithField,
|
||||
(List<float[]>) fieldData.vectors,
|
||||
fieldData.fieldInfo.getVectorDimension());
|
||||
FloatVectorValues floatVectorValues =
|
||||
sortMap != null
|
||||
? new SortingFloatVectorValues(bufferedFloatVectorValues, sortMap)
|
||||
: bufferedFloatVectorValues;
|
||||
writeField(fieldData.fieldInfo, floatVectorValues, maxDoc);
|
||||
break;
|
||||
case BYTE:
|
||||
BufferedByteVectorValues bufferedByteVectorValues =
|
||||
new BufferedByteVectorValues(
|
||||
fieldData.docsWithField,
|
||||
(List<byte[]>) fieldData.vectors,
|
||||
fieldData.fieldInfo.getVectorDimension());
|
||||
ByteVectorValues byteVectorValues =
|
||||
sortMap != null
|
||||
? new SortingByteVectorValues(bufferedByteVectorValues, sortMap)
|
||||
: bufferedByteVectorValues;
|
||||
writeField(fieldData.fieldInfo, byteVectorValues, maxDoc);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
|
||||
private static class SortingVectorValues extends FloatVectorValues {
|
||||
private final BufferedVectorValues randomAccess;
|
||||
private static class SortingFloatVectorValues extends FloatVectorValues {
|
||||
private final BufferedFloatVectorValues randomAccess;
|
||||
private final int[] docIdOffsets;
|
||||
private int docId = -1;
|
||||
|
||||
SortingVectorValues(BufferedVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
SortingFloatVectorValues(BufferedFloatVectorValues delegate, Sorter.DocMap sortMap)
|
||||
throws IOException {
|
||||
this.randomAccess = delegate.copy();
|
||||
this.docIdOffsets = new int[sortMap.size()];
|
||||
|
||||
|
@ -161,10 +161,67 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
|
||||
private static class SortingByteVectorValues extends ByteVectorValues {
|
||||
private final BufferedByteVectorValues randomAccess;
|
||||
private final int[] docIdOffsets;
|
||||
private int docId = -1;
|
||||
|
||||
SortingByteVectorValues(BufferedByteVectorValues delegate, Sorter.DocMap sortMap)
|
||||
throws IOException {
|
||||
this.randomAccess = delegate.copy();
|
||||
this.docIdOffsets = new int[sortMap.size()];
|
||||
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
int docID;
|
||||
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
while (docId < docIdOffsets.length - 1) {
|
||||
++docId;
|
||||
if (docIdOffsets[docId] != 0) {
|
||||
return docId;
|
||||
}
|
||||
}
|
||||
docId = NO_MORE_DOCS;
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return randomAccess.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return randomAccess.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long total = 0;
|
||||
for (FieldWriter field : fields) {
|
||||
for (FieldWriter<?> field : fields) {
|
||||
total += field.ramBytesUsed();
|
||||
}
|
||||
return total;
|
||||
|
@ -172,58 +229,37 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
|
||||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader =
|
||||
new KnnVectorsReader() {
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {}
|
||||
};
|
||||
writeField(fieldInfo, knnVectorsReader, mergeState.segmentInfo.maxDoc());
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32:
|
||||
FloatVectorValues floatVectorValues =
|
||||
MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
writeField(fieldInfo, floatVectorValues, mergeState.segmentInfo.maxDoc());
|
||||
break;
|
||||
case BYTE:
|
||||
ByteVectorValues byteVectorValues =
|
||||
MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
writeField(fieldInfo, byteVectorValues, mergeState.segmentInfo.maxDoc());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/** Write the provided field */
|
||||
/** Write the provided float vector field */
|
||||
protected abstract void writeField(
|
||||
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
|
||||
FieldInfo fieldInfo, FloatVectorValues floatVectorValues, int maxDoc) throws IOException;
|
||||
|
||||
private static class FieldWriter extends KnnFieldVectorsWriter<float[]> {
|
||||
/** Write the provided byte vector field */
|
||||
protected abstract void writeField(
|
||||
FieldInfo fieldInfo, ByteVectorValues byteVectorValues, int maxDoc) throws IOException;
|
||||
|
||||
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
|
||||
private final FieldInfo fieldInfo;
|
||||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<float[]> vectors;
|
||||
private final List<T> vectors;
|
||||
|
||||
private int lastDocID = -1;
|
||||
|
||||
public FieldWriter(FieldInfo fieldInfo) {
|
||||
FieldWriter(FieldInfo fieldInfo) {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
|
@ -231,7 +267,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void addValue(int docID, float[] value) {
|
||||
public final void addValue(int docID, T value) {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
|
@ -245,12 +281,10 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] copyValue(float[] vectorValue) {
|
||||
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
|
||||
}
|
||||
public abstract T copyValue(T vectorValue);
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
public final long ramBytesUsed() {
|
||||
if (vectors.size() == 0) return 0;
|
||||
return docsWithField.ramBytesUsed()
|
||||
+ vectors.size()
|
||||
|
@ -261,8 +295,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
private static class BufferedVectorValues extends FloatVectorValues {
|
||||
|
||||
private static class BufferedFloatVectorValues extends FloatVectorValues {
|
||||
final DocsWithFieldSet docsWithField;
|
||||
|
||||
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
|
||||
|
@ -272,15 +305,16 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
DocIdSetIterator docsWithFieldIter;
|
||||
int ord = -1;
|
||||
|
||||
BufferedVectorValues(DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
|
||||
BufferedFloatVectorValues(
|
||||
DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
|
||||
this.docsWithField = docsWithField;
|
||||
this.vectors = vectors;
|
||||
this.dimension = dimension;
|
||||
docsWithFieldIter = docsWithField.iterator();
|
||||
}
|
||||
|
||||
public BufferedVectorValues copy() {
|
||||
return new BufferedVectorValues(docsWithField, vectors, dimension);
|
||||
public BufferedFloatVectorValues copy() {
|
||||
return new BufferedFloatVectorValues(docsWithField, vectors, dimension);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -298,7 +332,67 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
return vectors.get(ord);
|
||||
}
|
||||
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
float[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docsWithFieldIter.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int docID = docsWithFieldIter.nextDoc();
|
||||
if (docID != NO_MORE_DOCS) {
|
||||
++ord;
|
||||
}
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
private static class BufferedByteVectorValues extends ByteVectorValues {
|
||||
final DocsWithFieldSet docsWithField;
|
||||
|
||||
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
|
||||
final List<byte[]> vectors;
|
||||
final int dimension;
|
||||
|
||||
DocIdSetIterator docsWithFieldIter;
|
||||
int ord = -1;
|
||||
|
||||
BufferedByteVectorValues(DocsWithFieldSet docsWithField, List<byte[]> vectors, int dimension) {
|
||||
this.docsWithField = docsWithField;
|
||||
this.vectors = vectors;
|
||||
this.dimension = dimension;
|
||||
docsWithFieldIter = docsWithField.iterator();
|
||||
}
|
||||
|
||||
public BufferedByteVectorValues copy() {
|
||||
return new BufferedByteVectorValues(docsWithField, vectors, dimension);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() {
|
||||
return vectors.get(ord);
|
||||
}
|
||||
|
||||
byte[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
KnnFieldVectorsWriter<float[]> floatWriter =
|
||||
(KnnFieldVectorsWriter<float[]>) addField(fieldInfo);
|
||||
FloatVectorValues mergedFloats =
|
||||
MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
for (int doc = mergedFloats.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = mergedFloats.nextDoc()) {
|
||||
|
@ -143,8 +143,8 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
private MergedVectorValues() {}
|
||||
|
||||
/** Returns a merged view over all the segment's {@link FloatVectorValues}. */
|
||||
public static FloatVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||
throws IOException {
|
||||
public static FloatVectorValues mergeFloatVectorValues(
|
||||
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
throw new UnsupportedOperationException(
|
||||
|
|
|
@ -402,7 +402,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE -> writeByteVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||
case FLOAT32 -> writeVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
|
||||
tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||
};
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
|
Loading…
Reference in New Issue