From 9d732380aefd41fdbe152eab47eca83d8de4f2af Mon Sep 17 00:00:00 2001 From: Adrien Grand Date: Sat, 5 Mar 2022 18:31:56 +0100 Subject: [PATCH] LUCENE-10453: Speed up euclidean distances. (#725) --- lucene/CHANGES.txt | 3 +++ .../org/apache/lucene/util/VectorUtil.java | 25 ++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index c072829fde6..772986a3f8b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -252,6 +252,9 @@ Optimizations * LUCENE-10450: IndexSortSortedNumericDocValuesRangeQuery could be rewrite to MatchAllDocsQuery. (Lu Xugang) +* LUCENE-10453: Indexing and search speedup with KNN vectors when using + euclidean distance. (Adrien Grand) + Changes in runtime behavior --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 1e5729aa85d..7df7a99fa73 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -133,13 +133,36 @@ public final class VectorUtil { } float squareSum = 0.0f; int dim = v1.length; - for (int i = 0; i < dim; i++) { + int i; + for (i = 0; i + 8 <= dim; i += 8) { + squareSum += squareDistanceUnrolled8(v1, v2, i); + } + for (; i < dim; i++) { float diff = v1[i] - v2[i]; squareSum += diff * diff; } return squareSum; } + private static float squareDistanceUnrolled8(float[] v1, float[] v2, int index) { + float diff0 = v1[index + 0] - v2[index + 0]; + float diff1 = v1[index + 1] - v2[index + 1]; + float diff2 = v1[index + 2] - v2[index + 2]; + float diff3 = v1[index + 3] - v2[index + 3]; + float diff4 = v1[index + 4] - v2[index + 4]; + float diff5 = v1[index + 5] - v2[index + 5]; + float diff6 = v1[index + 6] - v2[index + 6]; + float diff7 = v1[index + 7] - v2[index + 7]; + return diff0 * diff0 + + diff1 * diff1 + + diff2 * diff2 + + diff3 * diff3 + + diff4 * diff4 + + diff5 * diff5 + + diff6 * diff6 + + diff7 * diff7; + } + /** * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is * thrown for zero vectors.