Improve VectorUtil::xorBitCount perf on ARM (#13545)

This commit improves the performance of VectorUtil::xorBitCount on ARM by ~4x.

This change is effectively a workaround for the lack of vectorization of Long::bitCount on ARM.

On x64 there is no issue, the long variant of xorBitCount outperforms the int variant by ~15%.
This commit is contained in:
Chris Hegarty 2024-07-08 17:30:45 +01:00 committed by GitHub
parent 9e04cb9c41
commit 3304b60c9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 185 additions and 1 deletions

View File

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;
import java.io.IOException;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
@Fork(1)
@Warmup(iterations = 3, time = 3)
@Measurement(iterations = 5, time = 3)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
public class HammingDistanceBenchmark {
@Param({"1000000"})
int nb = 1_000_000;
@Param({"1024"})
int dims = 1024;
byte[][] xb;
byte[] xq;
@Setup
public void setup() throws IOException {
Random rand = new Random();
this.xb = new byte[nb][dims / 8];
for (int i = 0; i < nb; i++) {
for (int j = 0; j < dims / 8; j++) {
xb[i][j] = (byte) rand.nextInt(0, 255);
}
}
this.xq = new byte[dims / 8];
for (int i = 0; i < xq.length; i++) {
xq[i] = (byte) rand.nextInt(0, 255);
}
}
@Benchmark
public int xorBitCount() {
int tot = 0;
for (int i = 0; i < nb; i++) {
tot += VectorUtil.xorBitCount(xb[i], xq);
}
return tot;
}
}

View File

@ -212,6 +212,14 @@ public final class VectorUtil {
return IMPL.int4DotProduct(unpacked, false, packed, true);
}
/**
* For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
* On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
* compared to Integer::bitCount. While Long::bitCount is optimal on x64. TODO: include the
* OpenJDK JIRA url
*/
static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64");
/**
* XOR bit count computed over signed bytes.
*
@ -223,8 +231,32 @@ public final class VectorUtil {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
if (XOR_BIT_COUNT_STRIDE_AS_INT) {
return xorBitCountInt(a, b);
} else {
return xorBitCountLong(a, b);
}
}
/** XOR bit count striding over 4 bytes at a time. */
static int xorBitCountInt(byte[] a, byte[] b) {
int distance = 0, i = 0;
for (final int upperBound = a.length & ~(Long.BYTES - 1); i < upperBound; i += Long.BYTES) {
for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
distance +=
Integer.bitCount(
(int) BitUtil.VH_NATIVE_INT.get(a, i) ^ (int) BitUtil.VH_NATIVE_INT.get(b, i));
}
// tail:
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
}
return distance;
}
/** XOR bit count striding over 8 bytes at a time. */
static int xorBitCountLong(byte[] a, byte[] b) {
int distance = 0, i = 0;
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
distance +=
Long.bitCount(
(long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i));

View File

@ -276,4 +276,81 @@ public class TestVectorUtil extends LuceneTestCase {
u[1] = -v[0];
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
}
interface ToIntBiFunction {
int apply(byte[] a, byte[] b);
}
public void testBasicXorBitCount() {
testBasicXorBitCountImpl(VectorUtil::xorBitCount);
testBasicXorBitCountImpl(VectorUtil::xorBitCountInt);
testBasicXorBitCountImpl(VectorUtil::xorBitCountLong);
// test sanity
testBasicXorBitCountImpl(TestVectorUtil::xorBitCount);
}
void testBasicXorBitCountImpl(ToIntBiFunction xorBitCount) {
assertEquals(0, xorBitCount.apply(new byte[] {1}, new byte[] {1}));
assertEquals(0, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {1, 2, 3}));
assertEquals(1, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {0, 2, 3}));
assertEquals(2, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {0, 6, 3}));
assertEquals(3, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {0, 6, 7}));
assertEquals(4, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {2, 6, 7}));
// 32-bit / int boundary
assertEquals(0, xorBitCount.apply(new byte[] {1, 2, 3, 4}, new byte[] {1, 2, 3, 4}));
assertEquals(1, xorBitCount.apply(new byte[] {1, 2, 3, 4}, new byte[] {0, 2, 3, 4}));
assertEquals(0, xorBitCount.apply(new byte[] {1, 2, 3, 4, 5}, new byte[] {1, 2, 3, 4, 5}));
assertEquals(1, xorBitCount.apply(new byte[] {1, 2, 3, 4, 5}, new byte[] {0, 2, 3, 4, 5}));
// 64-bit / long boundary
assertEquals(
0,
xorBitCount.apply(
new byte[] {1, 2, 3, 4, 5, 6, 7, 8}, new byte[] {1, 2, 3, 4, 5, 6, 7, 8}));
assertEquals(
1,
xorBitCount.apply(
new byte[] {1, 2, 3, 4, 5, 6, 7, 8}, new byte[] {0, 2, 3, 4, 5, 6, 7, 8}));
assertEquals(
0,
xorBitCount.apply(
new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}));
assertEquals(
1,
xorBitCount.apply(
new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new byte[] {0, 2, 3, 4, 5, 6, 7, 8, 9}));
}
public void testXorBitCount() {
int iterations = atLeast(100);
for (int i = 0; i < iterations; i++) {
int size = random().nextInt(1024);
byte[] a = new byte[size];
byte[] b = new byte[size];
random().nextBytes(a);
random().nextBytes(b);
int expected = xorBitCount(a, b);
assertEquals(expected, VectorUtil.xorBitCount(a, b));
assertEquals(expected, VectorUtil.xorBitCountInt(a, b));
assertEquals(expected, VectorUtil.xorBitCountLong(a, b));
}
}
private static int xorBitCount(byte[] a, byte[] b) {
int res = 0;
for (int i = 0; i < a.length; i++) {
byte x = a[i];
byte y = b[i];
for (int j = 0; j < Byte.SIZE; j++) {
if (x == y) break;
if ((x & 0x01) != (y & 0x01)) res++;
x = (byte) ((x & 0xFF) >> 1);
y = (byte) ((y & 0xFF) >> 1);
}
}
return res;
}
}