This commit is contained in:
Benjamin Trent 2024-06-17 11:19:55 -04:00
parent f46920af9d
commit 32ce602213
2 changed files with 45 additions and 33 deletions

View File

@ -32,6 +32,7 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
final int vectorByteSize;
final MemorySegmentAccessInput input;
final MemorySegment query;
final float constMultiplier;
byte[] scratch;
/**
@ -56,16 +57,16 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
}
checkInvariants(values.size(), values.getVectorByteLength(), input);
return switch (type) {
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector));
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector));
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector));
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector, constMultiplier));
case MAXIMUM_INNER_PRODUCT -> Optional.of(
new MaxInnerProductScorer(msInput, values, queryVector));
new MaxInnerProductScorer(msInput, values, queryVector, offsetCorrection));
};
}
Lucene99MemorySegmentScalarQuantizedVectorScorer(
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector) {
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector, float constMultiplier) {
super(values);
this.input = input;
this.vectorByteLength = values.getVectorByteLength();
@ -101,50 +102,57 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
DotProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
super(input, values, query);
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
super(input, values, query, constMultiplier);
}
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
float vectorOffset = values.getScoreCorrectionConstant(node);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return Math.max((1 + adjustedDistance) / 2, 0);
}
}
static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
Int4DotProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
super(input, values, query);
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
super(input, values, query, constMultiplier);
}
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float raw = PanamaVectorUtilSupport.int4DotProduct(query, getSegment(node));
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
float raw = PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), false);
float vectorOffset = values.getScoreCorrectionConstant(node);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return Math.max((1 + adjustedDistance) / 2, 0);
}
}
static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
super(input, values, query);
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier) {
super(input, values, query, constMultiplier);
}
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
return 1 / (1f + raw);
float adjustedDistance = raw * constMultiplier;
return 1 / (1f + adjustedDistance);
}
}
static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
MaxInnerProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
super(input, values, query);
}
@ -152,10 +160,14 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
public float score(int node) throws IOException {
checkOrdinal(node);
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
if (raw < 0) {
return 1 / (1 + -1 * raw);
float vectorOffset = values.getScoreCorrectionConstant(node);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
if (adjustedDistance < 0) {
return 1 / (1 + -1 * adjustedDistance);
}
return raw + 1;
return adjustedDistance + 1;
}
}
}

View File

@ -457,9 +457,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
int innerLimit = Math.min(limit - i, 4096);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
// packed
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j);
var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j);
// unpacked
var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
// upper
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
@ -467,7 +467,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
acc0 = acc0.add(prod16);
// lower
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j);
ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j);
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
acc1 = acc1.add(prod16a);
@ -490,9 +490,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
int innerLimit = Math.min(limit - i, 2048);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
// packed
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j);
var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j);
// unpacked
var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
// upper
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
@ -500,7 +500,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
acc0 = acc0.add(prod16);
// lower
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j);
ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j);
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
acc1 = acc1.add(prod16a);
@ -524,7 +524,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
int innerLimit = Math.min(limit - i, 1024);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
// packed
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j);
// unpacked
ByteVector va8 =
ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
@ -536,7 +536,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
acc0 = acc0.add(prod16.and((short) 0xFF));
// lower
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j);
va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j);
prod8 = vb8.lanewise(LSHR, 4).mul(va8);
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc1 = acc1.add(prod16.and((short) 0xFF));
@ -558,15 +558,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
int innerLimit = Math.min(limit - i, 1024);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j);
ByteVector prod8 = va8.mul(vb8);
ShortVector prod16 =
prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc0 = acc0.add(prod16.and((short) 0xFF));
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8);
vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8);
prod8 = va8.mul(vb8);
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc1 = acc1.add(prod16.and((short) 0xFF));