From 57397f0cabaff5eea98e9ec975290d2800e99627 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Wed, 1 Feb 2023 10:48:05 +0100 Subject: [PATCH] Adjust return type for VectorUtil methods (#12122) Two of the methods (squareDistance and dotProduct) that take byte arrays return a float while the variable used to store the value is an int. They can just return an int. --- .../org/apache/lucene/index/VectorSimilarityFunction.java | 2 +- lucene/core/src/java/org/apache/lucene/util/VectorUtil.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java index d5ebcabb7a3..3646cf65584 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -37,7 +37,7 @@ public enum VectorSimilarityFunction { @Override public float compare(byte[] v1, byte[] v2) { - return 1 / (1 + squareDistance(v1, v2)); + return 1 / (1f + squareDistance(v1, v2)); } }, diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 8800d434ba6..2a08436ec0b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -181,7 +181,7 @@ public final class VectorUtil { } /** Returns the sum of squared differences of the two vectors. */ - public static float squareDistance(byte[] a, byte[] b) { + public static int squareDistance(byte[] a, byte[] b) { // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. int squareSum = 0; for (int i = 0; i < a.length; i++) { @@ -249,7 +249,7 @@ public final class VectorUtil { * @param b bytes containing another vector, of the same dimension * @return the value of the dot product of the two vectors */ - public static float dotProduct(byte[] a, byte[] b) { + public static int dotProduct(byte[] a, byte[] b) { assert a.length == b.length; int total = 0; for (int i = 0; i < a.length; i++) {