From 3ec9c26d672262762f4213c827699bf735409eeb Mon Sep 17 00:00:00 2001 From: Robert Muir Date: Sat, 14 Oct 2023 16:28:26 -0400 Subject: [PATCH] speedup all binary functions on avx256, speedup binary square on avx512 --- .../PanamaVectorUtilSupport.java | 83 ++++++++++++++----- 1 file changed, 64 insertions(+), 19 deletions(-) diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 1b7157c4752..a7ca073c882 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -17,6 +17,7 @@ package org.apache.lucene.internal.vectorization; import static jdk.incubator.vector.VectorOperators.ADD; +import static jdk.incubator.vector.VectorOperators.B2I; import static jdk.incubator.vector.VectorOperators.B2S; import static jdk.incubator.vector.VectorOperators.S2I; @@ -305,7 +306,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // vectors (256-bit on intel to dodge performance landmines) if (a.length >= 16 && useIntegerVectors) { // compute vectorized dot product consistent with VPDPBUSD instruction - if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + if (INT_SPECIES_PREF_BIT_SIZE >= 512) { + i += BYTE_SPECIES.loopBound(a.length); + res += dotProductBody512(a, b, i); + } else if (INT_SPECIES_PREF_BIT_SIZE == 256) { i += BYTE_SPECIES.loopBound(a.length); res += dotProductBody256(a, b, i); } else { @@ -322,14 +326,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { return res; } - /** vectorized dot product body (256+ bit vectors) */ - private int dotProductBody256(byte[] a, byte[] b, int limit) { + /** vectorized dot product body (512 bit vectors) */ + private int dotProductBody512(byte[] a, byte[] b, int limit) { IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); - // 16-bit multiply + // 16-bit multiply: avoid AVX-512 heavy multiply on zmm Vector va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0); Vector prod16 = va16.mul(vb16); @@ -342,6 +346,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { return acc.reduceLanes(ADD); } + /** vectorized dot product body (256 bit vectors) */ + private int dotProductBody256(byte[] a, byte[] b, int limit) { + IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); + + // 32-bit multiply and add into accumulator + Vector va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0); + Vector vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0); + acc = acc.add(va32.mul(vb32)); + } + // reduce + return acc.reduceLanes(ADD); + } + /** vectorized dot product body (128 bit vectors) */ private int dotProductBody128(byte[] a, byte[] b, int limit) { IntVector acc = IntVector.zero(IntVector.SPECIES_128); @@ -374,7 +394,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // vectors (256-bit on intel to dodge performance landmines) if (a.length >= 16 && useIntegerVectors) { final float[] ret; - if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + if (INT_SPECIES_PREF_BIT_SIZE >= 512) { + i += BYTE_SPECIES.loopBound(a.length); + ret = cosineBody512(a, b, i); + } else if (INT_SPECIES_PREF_BIT_SIZE == 256) { i += BYTE_SPECIES.loopBound(a.length); ret = cosineBody256(a, b, i); } else { @@ -398,9 +421,8 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } - /** vectorized cosine body (256+ bit vectors) */ - private float[] cosineBody256(byte[] a, byte[] b, int limit) { - // optimized 256/512 bit implementation, processes 8/16 bytes at a time + /** vectorized cosine body (512 bit vectors) */ + private float[] cosineBody512(byte[] a, byte[] b, int limit) { IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED); IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED); IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED); @@ -408,20 +430,45 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); - // 16-bit multiply + // 16-bit multiply: avoid AVX-512 heavy multiply on zmm Vector va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0); - Vector prod16 = va16.mul(vb16); Vector norm1_16 = va16.mul(va16); Vector norm2_16 = vb16.mul(vb16); + Vector prod16 = va16.mul(vb16); // sum into accumulators: 32-bit add - Vector prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0); Vector norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0); Vector norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0); - accSum = accSum.add(prod32); + Vector prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0); accNorm1 = accNorm1.add(norm1_32); accNorm2 = accNorm2.add(norm2_32); + accSum = accSum.add(prod32); + } + // reduce + return new float[] { + accSum.reduceLanes(ADD), accNorm1.reduceLanes(ADD), accNorm2.reduceLanes(ADD) + }; + } + + /** vectorized cosine body (256 bit vectors) */ + private float[] cosineBody256(byte[] a, byte[] b, int limit) { + IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED); + IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED); + IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); + + // 16-bit multiply, and add into accumulators + Vector va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0); + Vector vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0); + Vector norm1_32 = va32.mul(va32); + Vector norm2_32 = vb32.mul(vb32); + Vector prod32 = va32.mul(vb32); + accNorm1 = accNorm1.add(norm1_32); + accNorm2 = accNorm2.add(norm2_32); + accSum = accSum.add(prod32); } // reduce return new float[] { @@ -488,13 +535,11 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); - // 16-bit sub - Vector va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); - Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0); - Vector diff16 = va16.sub(vb16); - - // 32-bit multiply and add into accumulators - Vector diff32 = diff16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0); + // 32-bit sub, multiply, and add into accumulators + // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512? + Vector va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0); + Vector vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0); + Vector diff32 = va32.sub(vb32); acc = acc.add(diff32.mul(diff32)); } // reduce