SlowCompositeCodecReaderWrapper must copy its sub-vector values to maintain thread-safety (#14092)

This commit is contained in:
Michael Sokolov 2024-12-31 12:05:12 -05:00 committed by GitHub
parent 525b963be0
commit 68051f1b9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 19 additions and 14 deletions

View File

@ -301,7 +301,12 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
} }
} }
/**
 * Pairs one sub-reader's vector values with the doc and ordinal offsets at which that sub starts
 * inside the merged view.
 *
 * @param sub the vector values of a single sub
 * @param docStart first merged doc id belonging to this sub
 * @param ordStart first merged ordinal belonging to this sub
 */
private record DocValuesSub<T extends KnnVectorValues>(T sub, int docStart, int ordStart) {

  /**
   * Builds a new {@code DocValuesSub} around {@code sub.copy()}, keeping the same offsets.
   *
   * @return a sub holding a copy of the wrapped values
   * @throws IOException if copying the wrapped values fails
   */
  DocValuesSub<T> copy() throws IOException {
    // copy() is declared on the base type; the copy is expected to have the same concrete type
    // as the instance it was made from, so the cast back to T is assumed safe.
    @SuppressWarnings("unchecked")
    T copiedSub = (T) sub.copy();
    return new DocValuesSub<>(copiedSub, docStart, ordStart);
  }
}
private static class MergedDocIterator<T extends KnnVectorValues> private static class MergedDocIterator<T extends KnnVectorValues>
extends KnnVectorValues.DocIndexIterator { extends KnnVectorValues.DocIndexIterator {
@ -850,7 +855,7 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
class MergedFloatVectorValues extends FloatVectorValues { class MergedFloatVectorValues extends FloatVectorValues {
final int dimension; final int dimension;
final int size; final int size;
final DocValuesSub<?>[] subs; final List<DocValuesSub<FloatVectorValues>> subs;
final MergedDocIterator<FloatVectorValues> iter; final MergedDocIterator<FloatVectorValues> iter;
final int[] starts; final int[] starts;
int lastSubIndex; int lastSubIndex;
@ -858,7 +863,7 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
MergedFloatVectorValues(int dimension, int size, List<DocValuesSub<FloatVectorValues>> subs) { MergedFloatVectorValues(int dimension, int size, List<DocValuesSub<FloatVectorValues>> subs) {
this.dimension = dimension; this.dimension = dimension;
this.size = size; this.size = size;
this.subs = subs.toArray(new DocValuesSub<?>[0]); this.subs = subs;
iter = new MergedDocIterator<>(subs); iter = new MergedDocIterator<>(subs);
// [0, start(1), ..., size] - we want the extra element // [0, start(1), ..., size] - we want the extra element
// to avoid checking for out-of-array bounds // to avoid checking for out-of-array bounds
@ -888,8 +893,8 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
@Override @Override
public FloatVectorValues copy() throws IOException { public FloatVectorValues copy() throws IOException {
List<DocValuesSub<FloatVectorValues>> subsCopy = new ArrayList<>(); List<DocValuesSub<FloatVectorValues>> subsCopy = new ArrayList<>();
for (Object sub : subs) { for (DocValuesSub<FloatVectorValues> sub : subs) {
subsCopy.add((DocValuesSub<FloatVectorValues>) sub); subsCopy.add(sub.copy());
} }
return new MergedFloatVectorValues(dimension, size, subsCopy); return new MergedFloatVectorValues(dimension, size, subsCopy);
} }
@ -900,9 +905,9 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
// We need to implement fully random-access API here in order to support callers like // We need to implement fully random-access API here in order to support callers like
// SortingCodecReader that rely on it. // SortingCodecReader that rely on it.
lastSubIndex = findSub(ord, lastSubIndex, starts); lastSubIndex = findSub(ord, lastSubIndex, starts);
assert subs[lastSubIndex].sub != null; DocValuesSub<FloatVectorValues> sub = subs.get(lastSubIndex);
return ((FloatVectorValues) subs[lastSubIndex].sub) assert sub.sub != null;
.vectorValue(ord - subs[lastSubIndex].ordStart); return (sub.sub).vectorValue(ord - sub.ordStart);
} }
} }
@ -929,7 +934,7 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
class MergedByteVectorValues extends ByteVectorValues { class MergedByteVectorValues extends ByteVectorValues {
final int dimension; final int dimension;
final int size; final int size;
final DocValuesSub<?>[] subs; final List<DocValuesSub<ByteVectorValues>> subs;
final MergedDocIterator<ByteVectorValues> iter; final MergedDocIterator<ByteVectorValues> iter;
final int[] starts; final int[] starts;
int lastSubIndex; int lastSubIndex;
@ -937,7 +942,7 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
MergedByteVectorValues(int dimension, int size, List<DocValuesSub<ByteVectorValues>> subs) { MergedByteVectorValues(int dimension, int size, List<DocValuesSub<ByteVectorValues>> subs) {
this.dimension = dimension; this.dimension = dimension;
this.size = size; this.size = size;
this.subs = subs.toArray(new DocValuesSub<?>[0]); this.subs = subs;
iter = new MergedDocIterator<>(subs); iter = new MergedDocIterator<>(subs);
// [0, start(1), ..., size] - we want the extra element // [0, start(1), ..., size] - we want the extra element
// to avoid checking for out-of-array bounds // to avoid checking for out-of-array bounds
@ -970,16 +975,16 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
// SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some // SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some
// repetition. // repetition.
lastSubIndex = findSub(ord, lastSubIndex, starts); lastSubIndex = findSub(ord, lastSubIndex, starts);
return ((ByteVectorValues) subs[lastSubIndex].sub) DocValuesSub<ByteVectorValues> sub = subs.get(lastSubIndex);
.vectorValue(ord - subs[lastSubIndex].ordStart); return sub.sub.vectorValue(ord - sub.ordStart);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public ByteVectorValues copy() throws IOException { public ByteVectorValues copy() throws IOException {
List<DocValuesSub<ByteVectorValues>> newSubs = new ArrayList<>(); List<DocValuesSub<ByteVectorValues>> newSubs = new ArrayList<>();
for (Object sub : subs) { for (DocValuesSub<ByteVectorValues> sub : subs) {
newSubs.add((DocValuesSub<ByteVectorValues>) sub); newSubs.add(sub.copy());
} }
return new MergedByteVectorValues(dimension, size, newSubs); return new MergedByteVectorValues(dimension, size, newSubs);
} }