mirror of https://github.com/apache/lucene.git
iter
This commit is contained in:
parent
f46920af9d
commit
32ce602213
|
@ -32,6 +32,7 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
||||||
final int vectorByteSize;
|
final int vectorByteSize;
|
||||||
final MemorySegmentAccessInput input;
|
final MemorySegmentAccessInput input;
|
||||||
final MemorySegment query;
|
final MemorySegment query;
|
||||||
|
final float constMultiplier;
|
||||||
byte[] scratch;
|
byte[] scratch;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -56,16 +57,16 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
||||||
}
|
}
|
||||||
checkInvariants(values.size(), values.getVectorByteLength(), input);
|
checkInvariants(values.size(), values.getVectorByteLength(), input);
|
||||||
return switch (type) {
|
return switch (type) {
|
||||||
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector));
|
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
|
||||||
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector));
|
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
|
||||||
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector));
|
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector, constMultiplier));
|
||||||
case MAXIMUM_INNER_PRODUCT -> Optional.of(
|
case MAXIMUM_INNER_PRODUCT -> Optional.of(
|
||||||
new MaxInnerProductScorer(msInput, values, queryVector));
|
new MaxInnerProductScorer(msInput, values, queryVector, offsetCorrection));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Lucene99MemorySegmentScalarQuantizedVectorScorer(
|
Lucene99MemorySegmentScalarQuantizedVectorScorer(
|
||||||
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector) {
|
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector, float constMultiplier) {
|
||||||
super(values);
|
super(values);
|
||||||
this.input = input;
|
this.input = input;
|
||||||
this.vectorByteLength = values.getVectorByteLength();
|
this.vectorByteLength = values.getVectorByteLength();
|
||||||
|
@ -101,50 +102,57 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
||||||
|
|
||||||
static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||||
DotProductScorer(
|
DotProductScorer(
|
||||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
|
||||||
super(input, values, query);
|
super(input, values, query, constMultiplier);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float score(int node) throws IOException {
|
public float score(int node) throws IOException {
|
||||||
checkOrdinal(node);
|
checkOrdinal(node);
|
||||||
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
|
|
||||||
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
||||||
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
|
float vectorOffset = values.getScoreCorrectionConstant(node);
|
||||||
|
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
|
||||||
|
assert dotProduct >= 0;
|
||||||
|
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
|
||||||
|
return Math.max((1 + adjustedDistance) / 2, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||||
Int4DotProductScorer(
|
Int4DotProductScorer(
|
||||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
|
||||||
super(input, values, query);
|
super(input, values, query, constMultiplier);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float score(int node) throws IOException {
|
public float score(int node) throws IOException {
|
||||||
checkOrdinal(node);
|
checkOrdinal(node);
|
||||||
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
|
float raw = PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), false);
|
||||||
float raw = PanamaVectorUtilSupport.int4DotProduct(query, getSegment(node));
|
float vectorOffset = values.getScoreCorrectionConstant(node);
|
||||||
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
|
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
|
||||||
|
assert dotProduct >= 0;
|
||||||
|
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
|
||||||
|
return Math.max((1 + adjustedDistance) / 2, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||||
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier) {
|
||||||
super(input, values, query);
|
super(input, values, query, constMultiplier);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float score(int node) throws IOException {
|
public float score(int node) throws IOException {
|
||||||
checkOrdinal(node);
|
checkOrdinal(node);
|
||||||
float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
|
float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
|
||||||
return 1 / (1f + raw);
|
float adjustedDistance = raw * constMultiplier;
|
||||||
|
return 1 / (1f + adjustedDistance);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
|
||||||
MaxInnerProductScorer(
|
MaxInnerProductScorer(
|
||||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
|
||||||
super(input, values, query);
|
super(input, values, query);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,10 +160,14 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
|
||||||
public float score(int node) throws IOException {
|
public float score(int node) throws IOException {
|
||||||
checkOrdinal(node);
|
checkOrdinal(node);
|
||||||
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
||||||
if (raw < 0) {
|
float vectorOffset = values.getScoreCorrectionConstant(node);
|
||||||
return 1 / (1 + -1 * raw);
|
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
|
||||||
|
assert dotProduct >= 0;
|
||||||
|
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
|
||||||
|
if (adjustedDistance < 0) {
|
||||||
|
return 1 / (1 + -1 * adjustedDistance);
|
||||||
}
|
}
|
||||||
return raw + 1;
|
return adjustedDistance + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -457,9 +457,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||||
int innerLimit = Math.min(limit - i, 4096);
|
int innerLimit = Math.min(limit - i, 4096);
|
||||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
|
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
|
||||||
// packed
|
// packed
|
||||||
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j);
|
var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j);
|
||||||
// unpacked
|
// unpacked
|
||||||
var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
|
var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
|
||||||
|
|
||||||
// upper
|
// upper
|
||||||
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
|
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
|
||||||
|
@ -467,7 +467,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||||
acc0 = acc0.add(prod16);
|
acc0 = acc0.add(prod16);
|
||||||
|
|
||||||
// lower
|
// lower
|
||||||
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j);
|
ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j);
|
||||||
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
|
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
|
||||||
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
|
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
|
||||||
acc1 = acc1.add(prod16a);
|
acc1 = acc1.add(prod16a);
|
||||||
|
@ -490,9 +490,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||||
int innerLimit = Math.min(limit - i, 2048);
|
int innerLimit = Math.min(limit - i, 2048);
|
||||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
|
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
|
||||||
// packed
|
// packed
|
||||||
var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j);
|
var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j);
|
||||||
// unpacked
|
// unpacked
|
||||||
var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
|
var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
|
||||||
|
|
||||||
// upper
|
// upper
|
||||||
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
|
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
|
||||||
|
@ -500,7 +500,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||||
acc0 = acc0.add(prod16);
|
acc0 = acc0.add(prod16);
|
||||||
|
|
||||||
// lower
|
// lower
|
||||||
ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j);
|
ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j);
|
||||||
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
|
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
|
||||||
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
|
Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
|
||||||
acc1 = acc1.add(prod16a);
|
acc1 = acc1.add(prod16a);
|
||||||
|
@ -524,7 +524,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||||
int innerLimit = Math.min(limit - i, 1024);
|
int innerLimit = Math.min(limit - i, 1024);
|
||||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
|
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
|
||||||
// packed
|
// packed
|
||||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j);
|
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j);
|
||||||
// unpacked
|
// unpacked
|
||||||
ByteVector va8 =
|
ByteVector va8 =
|
||||||
ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
|
ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
|
||||||
|
@ -536,7 +536,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||||
acc0 = acc0.add(prod16.and((short) 0xFF));
|
acc0 = acc0.add(prod16.and((short) 0xFF));
|
||||||
|
|
||||||
// lower
|
// lower
|
||||||
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j);
|
va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j);
|
||||||
prod8 = vb8.lanewise(LSHR, 4).mul(va8);
|
prod8 = vb8.lanewise(LSHR, 4).mul(va8);
|
||||||
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
||||||
acc1 = acc1.add(prod16.and((short) 0xFF));
|
acc1 = acc1.add(prod16.and((short) 0xFF));
|
||||||
|
@ -558,15 +558,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||||
ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
|
ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
|
||||||
int innerLimit = Math.min(limit - i, 1024);
|
int innerLimit = Math.min(limit - i, 1024);
|
||||||
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
|
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
|
||||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
|
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j);
|
||||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
|
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j);
|
||||||
ByteVector prod8 = va8.mul(vb8);
|
ByteVector prod8 = va8.mul(vb8);
|
||||||
ShortVector prod16 =
|
ShortVector prod16 =
|
||||||
prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
||||||
acc0 = acc0.add(prod16.and((short) 0xFF));
|
acc0 = acc0.add(prod16.and((short) 0xFF));
|
||||||
|
|
||||||
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
|
va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8);
|
||||||
vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
|
vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8);
|
||||||
prod8 = va8.mul(vb8);
|
prod8 = va8.mul(vb8);
|
||||||
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
|
||||||
acc1 = acc1.add(prod16.and((short) 0xFF));
|
acc1 = acc1.add(prod16.and((short) 0xFF));
|
||||||
|
|
Loading…
Reference in New Issue