Fix quantization issue with 8 bits for L2 space type

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
This commit is contained in:
Naveen Tatikonda 2024-06-27 15:36:33 -05:00
parent f8ee339f64
commit f3903bc389
1 changed file with 12 additions and 3 deletions

View File

@ -115,7 +115,7 @@ public class ScalarQuantizer {
assert src.length == dest.length; assert src.length == dest.length;
float correction = 0; float correction = 0;
for (int i = 0; i < src.length; i++) { for (int i = 0; i < src.length; i++) {
correction += quantizeFloat(src[i], dest, i); correction += quantizeFloat(src[i], dest, i, similarityFunction);
} }
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) { if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0; return 0;
@ -123,7 +123,8 @@ public class ScalarQuantizer {
return correction; return correction;
} }
private float quantizeFloat(float v, byte[] dest, int destIndex) { private float quantizeFloat(
float v, byte[] dest, int destIndex, VectorSimilarityFunction similarityFunction) {
assert dest == null || destIndex < dest.length; assert dest == null || destIndex < dest.length;
// Make sure the value is within the quantile range, cutting off the tails // Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile - // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
@ -136,6 +137,14 @@ public class ScalarQuantizer {
// We multiply by `alpha` here to get the quantized value back into the original range // We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset // to aid in calculating the corrective offset
float dxq = Math.round(dxs) * alpha; float dxq = Math.round(dxs) * alpha;
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN) && bits == 8) {
// Uniformly shift each float value by 128 to bring it into byte range
dxs = Math.round(dxs - 128);
// Clip out-of-range values into the [-128, 127] range
dxs = Math.max(-128, Math.min(127, dxs));
}
if (dest != null) { if (dest != null) {
dest[destIndex] = (byte) Math.round(dxs); dest[destIndex] = (byte) Math.round(dxs);
} }
@ -166,7 +175,7 @@ public class ScalarQuantizer {
for (int i = 0; i < quantizedVector.length; i++) { for (int i = 0; i < quantizedVector.length; i++) {
// dequantize the old value in order to recalculate the corrective offset // dequantize the old value in order to recalculate the corrective offset
float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile; float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
correctiveOffset += quantizeFloat(v, null, 0); correctiveOffset += quantizeFloat(v, null, 0, similarityFunction);
} }
return correctiveOffset; return correctiveOffset;
} }