From 32ce6022137782cb0eb9d67e628626e761d797c2 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Mon, 17 Jun 2024 11:19:55 -0400
Subject: [PATCH] iter

---
Review notes (commentary only; this section is not part of the commit message):
 * score() now uses the local `raw`; the earlier revision referenced an undeclared
   `dotProduct` variable in the assert and in the adjusted-distance computation.
 * constMultiplier is assigned in the base constructor, and offsetCorrection is stored
   in a field by each scorer that reads it in score().
 * MaxInnerProductScorer is constructed with, and forwards to super(...), the
   constMultiplier argument so that every call matches the declared signatures.
 * ByteVector.fromMemorySegment takes a ByteOrder as its fourth argument; it is passed
   fully qualified so no new import is needed. `packed.length` becomes
   `packed.byteSize()` because a MemorySegment has no `length` field.
 * Still open, none of it visible in the hunks below -- please verify before merging:
   the PanamaVectorUtilSupport method signatures for these loops must accept
   MemorySegment arguments (including a 4-argument int4DotProduct overload), the
   factory method must define `constMultiplier` and `offsetCorrection`, the
   CosineScorer constructor must accept the two new arguments, and `values` must be
   reachable as RandomAccessQuantizedByteVectorValues for getScoreCorrectionConstant.

 ...orySegmentScalarQuantizedVectorScorer.java | 66 ++++++++++++-------
 .../PanamaVectorUtilSupport.java              | 26 ++++----
 2 files changed, 57 insertions(+), 35 deletions(-)

diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java
index 0fdd236057b..ef1c96e4d75 100644
--- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java
+++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java
@@ -32,6 +32,7 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
   final int vectorByteSize;
   final MemorySegmentAccessInput input;
   final MemorySegment query;
+  final float constMultiplier;
   byte[] scratch;
 
   /**
@@ -56,16 +57,17 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
     }
     checkInvariants(values.size(), values.getVectorByteLength(), input);
     return switch (type) {
-      case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector));
-      case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector));
-      case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector));
+      case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
+      case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
+      case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector, constMultiplier));
       case MAXIMUM_INNER_PRODUCT -> Optional.of(
-          new MaxInnerProductScorer(msInput, values, queryVector));
+          new MaxInnerProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection));
     };
   }
 
   Lucene99MemorySegmentScalarQuantizedVectorScorer(
-      MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector) {
+      MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector, float constMultiplier) {
     super(values);
     this.input = input;
+    this.constMultiplier = constMultiplier;
     this.vectorByteLength = values.getVectorByteLength();
@@ -101,50 +103,66 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
   static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
+    private final float offsetCorrection;
+
     DotProductScorer(
-        MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
-      super(input, values, query);
+        MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
+      super(input, values, query, constMultiplier);
+      this.offsetCorrection = offsetCorrection;
     }
 
     @Override
     public float score(int node) throws IOException {
       checkOrdinal(node);
-      // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
       float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
-      return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
+      float vectorOffset = values.getScoreCorrectionConstant(node);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert raw >= 0;
+      float adjustedDistance = raw * constMultiplier + offsetCorrection + vectorOffset;
+      return Math.max((1 + adjustedDistance) / 2, 0);
     }
   }
 
   static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
+    private final float offsetCorrection;
+
     Int4DotProductScorer(
-        MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
-      super(input, values, query);
+        MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
+      super(input, values, query, constMultiplier);
+      this.offsetCorrection = offsetCorrection;
     }
 
     @Override
     public float score(int node) throws IOException {
       checkOrdinal(node);
-      // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
-      float raw = PanamaVectorUtilSupport.int4DotProduct(query, getSegment(node));
-      return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
+      float raw = PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), false);
+      float vectorOffset = values.getScoreCorrectionConstant(node);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert raw >= 0;
+      float adjustedDistance = raw * constMultiplier + offsetCorrection + vectorOffset;
+      return Math.max((1 + adjustedDistance) / 2, 0);
     }
   }
 
   static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
-    EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
-      super(input, values, query);
+    EuclideanScorer(MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier) {
+      super(input, values, query, constMultiplier);
     }
 
     @Override
     public float score(int node) throws IOException {
       checkOrdinal(node);
       float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
-      return 1 / (1f + raw);
+      float adjustedDistance = raw * constMultiplier;
+      return 1 / (1f + adjustedDistance);
     }
   }
 
   static final class MaxInnerProductScorer
       extends Lucene99MemorySegmentScalarQuantizedVectorScorer {
+    private final float offsetCorrection;
+
     MaxInnerProductScorer(
-        MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
-      super(input, values, query);
+        MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) {
+      super(input, values, query, constMultiplier);
+      this.offsetCorrection = offsetCorrection;
     }
 
@@ -152,10 +170,14 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer
     public float score(int node) throws IOException {
       checkOrdinal(node);
       float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
-      if (raw < 0) {
-        return 1 / (1 + -1 * raw);
+      float vectorOffset = values.getScoreCorrectionConstant(node);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert raw >= 0;
+      float adjustedDistance = raw * constMultiplier + offsetCorrection + vectorOffset;
+      if (adjustedDistance < 0) {
+        return 1 / (1 + -1 * adjustedDistance);
       }
-      return raw + 1;
+      return adjustedDistance + 1;
     }
   }
 }
diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index ec8186f7160..b6ac4892ec5 100644
--- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -457,9 +457,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
       int innerLimit = Math.min(limit - i, 4096);
       for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
         // packed
-        var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j);
+        var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
         // unpacked
-        var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
+        var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j + packed.byteSize(), java.nio.ByteOrder.LITTLE_ENDIAN);
 
         // upper
         ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
@@ -467,7 +467,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
         acc0 = acc0.add(prod16);
 
         // lower
-        ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j);
+        ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
         ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
         Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
         acc1 = acc1.add(prod16a);
@@ -490,9 +490,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
       int innerLimit = Math.min(limit - i, 2048);
       for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
         // packed
-        var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j);
+        var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
         // unpacked
-        var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
+        var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j + packed.byteSize(), java.nio.ByteOrder.LITTLE_ENDIAN);
 
         // upper
         ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
@@ -500,7 +500,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
         acc0 = acc0.add(prod16);
 
         // lower
-        ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j);
+        ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
         ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
         Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
         acc1 = acc1.add(prod16a);
@@ -524,7 +524,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
       int innerLimit = Math.min(limit - i, 1024);
       for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
         // packed
-        ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j);
+        ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
         // unpacked
         ByteVector va8 =
-            ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
+            ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j + packed.byteSize(), java.nio.ByteOrder.LITTLE_ENDIAN);
@@ -536,7 +536,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
         acc0 = acc0.add(prod16.and((short) 0xFF));
 
         // lower
-        va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j);
+        va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
         prod8 = vb8.lanewise(LSHR, 4).mul(va8);
         prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
         acc1 = acc1.add(prod16.and((short) 0xFF));
@@ -558,15 +558,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
       ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
       int innerLimit = Math.min(limit - i, 1024);
       for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
-        ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
-        ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
+        ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
+        ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j, java.nio.ByteOrder.LITTLE_ENDIAN);
         ByteVector prod8 = va8.mul(vb8);
         ShortVector prod16 =
             prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
         acc0 = acc0.add(prod16.and((short) 0xFF));
 
-        va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
-        vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
+        va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8, java.nio.ByteOrder.LITTLE_ENDIAN);
+        vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8, java.nio.ByteOrder.LITTLE_ENDIAN);
         prod8 = va8.mul(vb8);
         prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
         acc1 = acc1.add(prod16.and((short) 0xFF));