mirror of https://github.com/apache/lucene.git
iter
This commit is contained in:
parent
f46920af9d
commit
32ce602213
|
@ -32,6 +32,7 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
|||
final int vectorByteSize;
|
||||
final MemorySegmentAccessInput input;
|
||||
final MemorySegment query;
|
||||
final float constMultiplier;
|
||||
byte[] scratch;
|
||||
|
||||
/**
|
||||
|
@ -56,16 +57,16 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
|||
}
|
||||
checkInvariants(values.size(), values.getVectorByteLength(), input);
|
||||
return switch (type) {
|
||||
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector));
|
||||
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector));
|
||||
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector));
|
||||
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
|
||||
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
|
||||
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector, constMultiplier));
|
||||
case MAXIMUM_INNER_PRODUCT -> Optional.of(
|
||||
new MaxInnerProductScorer(msInput, values, queryVector));
|
||||
new MaxInnerProductScorer(msInput, values, queryVector, offsetCorrection));
|
||||
};
|
||||
}
|
||||
|
||||
Lucene99MemorySegmentScalarQuantizedVectorScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector) {
|
||||
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector, float constMultiplier) {
|
||||
super(values);
|
||||
this.input = input;
|
||||
this.vectorByteLength = values.getVectorByteLength();
|
||||
|
@ -101,50 +102,57 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
|||
|
||||
static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||
DotProductScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
|
||||
super(input, values, query, constMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
|
||||
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
||||
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
|
||||
float vectorOffset = values.getScoreCorrectionConstant(node);
|
||||
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
|
||||
assert dotProduct >= 0;
|
||||
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
|
||||
return Math.max((1 + adjustedDistance) / 2, 0);
|
||||
}
|
||||
}
|
||||
|
||||
static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||
Int4DotProductScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
|
||||
super(input, values, query, constMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
|
||||
float raw = PanamaVectorUtilSupport.int4DotProduct(query, getSegment(node));
|
||||
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
|
||||
float raw = PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), false);
|
||||
float vectorOffset = values.getScoreCorrectionConstant(node);
|
||||
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
|
||||
assert dotProduct >= 0;
|
||||
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
|
||||
return Math.max((1 + adjustedDistance) / 2, 0);
|
||||
}
|
||||
}
|
||||
|
||||
static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier) {
|
||||
super(input, values, query, constMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
|
||||
return 1 / (1f + raw);
|
||||
float adjustedDistance = raw * constMultiplier;
|
||||
return 1 / (1f + adjustedDistance);
|
||||
}
|
||||
}
|
||||
|
||||
static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||
MaxInnerProductScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
|
@ -152,10 +160,14 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
|||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
||||
if (raw < 0) {
|
||||
return 1 / (1 + -1 * raw);
|
||||
float vectorOffset = values.getScoreCorrectionConstant(node);
|
||||
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
|
||||
assert dotProduct >= 0;
|
||||
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
|
||||
if (adjustedDistance < 0) {
|
||||
return 1 / (1 + -1 * adjustedDistance);
|
||||
}
|
||||
return raw + 1;
|
||||
return adjustedDistance + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -457,9 +457,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
int innerLimit = Math.min(limit - i, 4096);
|
||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
|
||||
// packed
|
||||
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j);
|
||||
var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j);
|
||||
// unpacked
|
||||
var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
|
||||
var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
|
||||
|
||||
// upper
|
||||
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
|
||||
|
@ -467,7 +467,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
acc0 = acc0.add(prod16);
|
||||
|
||||
// lower
|
||||
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j);
|
||||
ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j);
|
||||
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
|
||||
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
|
||||
acc1 = acc1.add(prod16a);
|
||||
|
@ -490,9 +490,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
int innerLimit = Math.min(limit - i, 2048);
|
||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
|
||||
// packed
|
||||
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j);
|
||||
var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j);
|
||||
// unpacked
|
||||
var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
|
||||
var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
|
||||
|
||||
// upper
|
||||
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
|
||||
|
@ -500,7 +500,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
acc0 = acc0.add(prod16);
|
||||
|
||||
// lower
|
||||
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j);
|
||||
ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j);
|
||||
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
|
||||
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
|
||||
acc1 = acc1.add(prod16a);
|
||||
|
@ -524,7 +524,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
int innerLimit = Math.min(limit - i, 1024);
|
||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
|
||||
// packed
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j);
|
||||
// unpacked
|
||||
ByteVector va8 =
|
||||
ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
|
||||
|
@ -536,7 +536,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
acc0 = acc0.add(prod16.and((short) 0xFF));
|
||||
|
||||
// lower
|
||||
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j);
|
||||
va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j);
|
||||
prod8 = vb8.lanewise(LSHR, 4).mul(va8);
|
||||
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
||||
acc1 = acc1.add(prod16.and((short) 0xFF));
|
||||
|
@ -558,15 +558,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
|
||||
int innerLimit = Math.min(limit - i, 1024);
|
||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j);
|
||||
ByteVector prod8 = va8.mul(vb8);
|
||||
ShortVector prod16 =
|
||||
prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
||||
acc0 = acc0.add(prod16.and((short) 0xFF));
|
||||
|
||||
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
|
||||
vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
|
||||
va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8);
|
||||
vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8);
|
||||
prod8 = va8.mul(vb8);
|
||||
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
||||
acc1 = acc1.add(prod16.and((short) 0xFF));
|
||||
|
|
Loading…
Reference in New Issue