This commit is contained in:
Benjamin Trent 2024-06-17 11:19:55 -04:00
parent f46920af9d
commit 32ce602213
2 changed files with 45 additions and 33 deletions

View File

@ -32,6 +32,7 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
final int vectorByteSize; final int vectorByteSize;
final MemorySegmentAccessInput input; final MemorySegmentAccessInput input;
final MemorySegment query; final MemorySegment query;
final float constMultiplier;
byte[] scratch; byte[] scratch;
/** /**
@ -56,16 +57,16 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
} }
checkInvariants(values.size(), values.getVectorByteLength(), input); checkInvariants(values.size(), values.getVectorByteLength(), input);
return switch (type) { return switch (type) {
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector)); case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector)); case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector)); case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector, constMultiplier));
case MAXIMUM_INNER_PRODUCT -> Optional.of( case MAXIMUM_INNER_PRODUCT -> Optional.of(
new MaxInnerProductScorer(msInput, values, queryVector)); new MaxInnerProductScorer(msInput, values, queryVector, offsetCorrection));
}; };
} }
Lucene99MemorySegmentScalarQuantizedVectorScorer( Lucene99MemorySegmentScalarQuantizedVectorScorer(
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector) { MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector, float constMultiplier) {
super(values); super(values);
this.input = input; this.input = input;
this.vectorByteLength = values.getVectorByteLength(); this.vectorByteLength = values.getVectorByteLength();
@ -101,50 +102,57 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
DotProductScorer( DotProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
super(input, values, query); super(input, values, query, constMultiplier);
} }
@Override @Override
public float score(int node) throws IOException { public float score(int node) throws IOException {
checkOrdinal(node); checkOrdinal(node);
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
return 0.5f + raw / (float) (query.byteSize() * (1 << 15)); float vectorOffset = values.getScoreCorrectionConstant(node);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return Math.max((1 + adjustedDistance) / 2, 0);
} }
} }
static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
Int4DotProductScorer( Int4DotProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
super(input, values, query); super(input, values, query, constMultiplier);
} }
@Override @Override
public float score(int node) throws IOException { public float score(int node) throws IOException {
checkOrdinal(node); checkOrdinal(node);
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len float raw = PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), false);
float raw = PanamaVectorUtilSupport.int4DotProduct(query, getSegment(node)); float vectorOffset = values.getScoreCorrectionConstant(node);
return 0.5f + raw / (float) (query.byteSize() * (1 << 15)); // For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return Math.max((1 + adjustedDistance) / 2, 0);
} }
} }
static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { EuclideanScorer(MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier) {
super(input, values, query); super(input, values, query, constMultiplier);
} }
@Override @Override
public float score(int node) throws IOException { public float score(int node) throws IOException {
checkOrdinal(node); checkOrdinal(node);
float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node)); float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
return 1 / (1f + raw); float adjustedDistance = raw * constMultiplier;
return 1 / (1f + adjustedDistance);
} }
} }
static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
MaxInnerProductScorer( MaxInnerProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
super(input, values, query); super(input, values, query);
} }
@ -152,10 +160,14 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
public float score(int node) throws IOException { public float score(int node) throws IOException {
checkOrdinal(node); checkOrdinal(node);
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
if (raw < 0) { float vectorOffset = values.getScoreCorrectionConstant(node);
return 1 / (1 + -1 * raw); // For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
if (adjustedDistance < 0) {
return 1 / (1 + -1 * adjustedDistance);
} }
return raw + 1; return adjustedDistance + 1;
} }
} }
} }

View File

@ -457,9 +457,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
int innerLimit = Math.min(limit - i, 4096); int innerLimit = Math.min(limit - i, 4096);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
// packed // packed
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j);
// unpacked // unpacked
var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
// upper // upper
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
@ -467,7 +467,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
acc0 = acc0.add(prod16); acc0 = acc0.add(prod16);
// lower // lower
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j);
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
acc1 = acc1.add(prod16a); acc1 = acc1.add(prod16a);
@ -490,9 +490,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
int innerLimit = Math.min(limit - i, 2048); int innerLimit = Math.min(limit - i, 2048);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
// packed // packed
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j);
// unpacked // unpacked
var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
// upper // upper
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
@ -500,7 +500,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
acc0 = acc0.add(prod16); acc0 = acc0.add(prod16);
// lower // lower
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j);
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
acc1 = acc1.add(prod16a); acc1 = acc1.add(prod16a);
@ -524,7 +524,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
int innerLimit = Math.min(limit - i, 1024); int innerLimit = Math.min(limit - i, 1024);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
// packed // packed
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j);
// unpacked // unpacked
ByteVector va8 = ByteVector va8 =
ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
@ -536,7 +536,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
acc0 = acc0.add(prod16.and((short) 0xFF)); acc0 = acc0.add(prod16.and((short) 0xFF));
// lower // lower
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j);
prod8 = vb8.lanewise(LSHR, 4).mul(va8); prod8 = vb8.lanewise(LSHR, 4).mul(va8);
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc1 = acc1.add(prod16.and((short) 0xFF)); acc1 = acc1.add(prod16.and((short) 0xFF));
@ -558,15 +558,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
int innerLimit = Math.min(limit - i, 1024); int innerLimit = Math.min(limit - i, 1024);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j); ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j); ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j);
ByteVector prod8 = va8.mul(vb8); ByteVector prod8 = va8.mul(vb8);
ShortVector prod16 = ShortVector prod16 =
prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc0 = acc0.add(prod16.and((short) 0xFF)); acc0 = acc0.add(prod16.and((short) 0xFF));
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8); va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8);
vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8); vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8);
prod8 = va8.mul(vb8); prod8 = va8.mul(vb8);
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc1 = acc1.add(prod16.and((short) 0xFF)); acc1 = acc1.add(prod16.and((short) 0xFF));