mirror of https://github.com/apache/lucene.git
Improve VectorUtil::xorBitCount perf on ARM (#13545)
This commit improves the performance of VectorUtil::xorBitCount on ARM by ~4x. This change is effectively a workaround for the lack of vectorization of Long::bitCount on ARM. On x64 there is no issue, the long variant of xorBitCount outperforms the int variant by ~15%.
This commit is contained in:
parent
9e04cb9c41
commit
3304b60c9c
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.benchmark.jmh;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||
import org.openjdk.jmh.annotations.Fork;
|
||||
import org.openjdk.jmh.annotations.Measurement;
|
||||
import org.openjdk.jmh.annotations.Mode;
|
||||
import org.openjdk.jmh.annotations.OutputTimeUnit;
|
||||
import org.openjdk.jmh.annotations.Param;
|
||||
import org.openjdk.jmh.annotations.Scope;
|
||||
import org.openjdk.jmh.annotations.Setup;
|
||||
import org.openjdk.jmh.annotations.State;
|
||||
import org.openjdk.jmh.annotations.Warmup;
|
||||
|
||||
@Fork(1)
|
||||
@Warmup(iterations = 3, time = 3)
|
||||
@Measurement(iterations = 5, time = 3)
|
||||
@BenchmarkMode(Mode.Throughput)
|
||||
@OutputTimeUnit(TimeUnit.SECONDS)
|
||||
@State(Scope.Benchmark)
|
||||
public class HammingDistanceBenchmark {
|
||||
@Param({"1000000"})
|
||||
int nb = 1_000_000;
|
||||
|
||||
@Param({"1024"})
|
||||
int dims = 1024;
|
||||
|
||||
byte[][] xb;
|
||||
byte[] xq;
|
||||
|
||||
@Setup
|
||||
public void setup() throws IOException {
|
||||
Random rand = new Random();
|
||||
this.xb = new byte[nb][dims / 8];
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int j = 0; j < dims / 8; j++) {
|
||||
xb[i][j] = (byte) rand.nextInt(0, 255);
|
||||
}
|
||||
}
|
||||
this.xq = new byte[dims / 8];
|
||||
for (int i = 0; i < xq.length; i++) {
|
||||
xq[i] = (byte) rand.nextInt(0, 255);
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int xorBitCount() {
|
||||
int tot = 0;
|
||||
for (int i = 0; i < nb; i++) {
|
||||
tot += VectorUtil.xorBitCount(xb[i], xq);
|
||||
}
|
||||
return tot;
|
||||
}
|
||||
}
|
|
@ -212,6 +212,14 @@ public final class VectorUtil {
|
|||
return IMPL.int4DotProduct(unpacked, false, packed, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
|
||||
* On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
|
||||
* compared to Integer::bitCount. While Long::bitCount is optimal on x64. TODO: include the
|
||||
* OpenJDK JIRA url
|
||||
*/
|
||||
static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64");
|
||||
|
||||
/**
|
||||
* XOR bit count computed over signed bytes.
|
||||
*
|
||||
|
@ -223,8 +231,32 @@ public final class VectorUtil {
|
|||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
if (XOR_BIT_COUNT_STRIDE_AS_INT) {
|
||||
return xorBitCountInt(a, b);
|
||||
} else {
|
||||
return xorBitCountLong(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
/** XOR bit count striding over 4 bytes at a time. */
|
||||
static int xorBitCountInt(byte[] a, byte[] b) {
|
||||
int distance = 0, i = 0;
|
||||
for (final int upperBound = a.length & ~(Long.BYTES - 1); i < upperBound; i += Long.BYTES) {
|
||||
for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
|
||||
distance +=
|
||||
Integer.bitCount(
|
||||
(int) BitUtil.VH_NATIVE_INT.get(a, i) ^ (int) BitUtil.VH_NATIVE_INT.get(b, i));
|
||||
}
|
||||
// tail:
|
||||
for (; i < a.length; i++) {
|
||||
distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
/** XOR bit count striding over 8 bytes at a time. */
|
||||
static int xorBitCountLong(byte[] a, byte[] b) {
|
||||
int distance = 0, i = 0;
|
||||
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
|
||||
distance +=
|
||||
Long.bitCount(
|
||||
(long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i));
|
||||
|
|
|
@ -276,4 +276,81 @@ public class TestVectorUtil extends LuceneTestCase {
|
|||
u[1] = -v[0];
|
||||
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
|
||||
}
|
||||
|
||||
interface ToIntBiFunction {
|
||||
int apply(byte[] a, byte[] b);
|
||||
}
|
||||
|
||||
public void testBasicXorBitCount() {
|
||||
testBasicXorBitCountImpl(VectorUtil::xorBitCount);
|
||||
testBasicXorBitCountImpl(VectorUtil::xorBitCountInt);
|
||||
testBasicXorBitCountImpl(VectorUtil::xorBitCountLong);
|
||||
// test sanity
|
||||
testBasicXorBitCountImpl(TestVectorUtil::xorBitCount);
|
||||
}
|
||||
|
||||
void testBasicXorBitCountImpl(ToIntBiFunction xorBitCount) {
|
||||
assertEquals(0, xorBitCount.apply(new byte[] {1}, new byte[] {1}));
|
||||
assertEquals(0, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {1, 2, 3}));
|
||||
assertEquals(1, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {0, 2, 3}));
|
||||
assertEquals(2, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {0, 6, 3}));
|
||||
assertEquals(3, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {0, 6, 7}));
|
||||
assertEquals(4, xorBitCount.apply(new byte[] {1, 2, 3}, new byte[] {2, 6, 7}));
|
||||
|
||||
// 32-bit / int boundary
|
||||
assertEquals(0, xorBitCount.apply(new byte[] {1, 2, 3, 4}, new byte[] {1, 2, 3, 4}));
|
||||
assertEquals(1, xorBitCount.apply(new byte[] {1, 2, 3, 4}, new byte[] {0, 2, 3, 4}));
|
||||
assertEquals(0, xorBitCount.apply(new byte[] {1, 2, 3, 4, 5}, new byte[] {1, 2, 3, 4, 5}));
|
||||
assertEquals(1, xorBitCount.apply(new byte[] {1, 2, 3, 4, 5}, new byte[] {0, 2, 3, 4, 5}));
|
||||
|
||||
// 64-bit / long boundary
|
||||
assertEquals(
|
||||
0,
|
||||
xorBitCount.apply(
|
||||
new byte[] {1, 2, 3, 4, 5, 6, 7, 8}, new byte[] {1, 2, 3, 4, 5, 6, 7, 8}));
|
||||
assertEquals(
|
||||
1,
|
||||
xorBitCount.apply(
|
||||
new byte[] {1, 2, 3, 4, 5, 6, 7, 8}, new byte[] {0, 2, 3, 4, 5, 6, 7, 8}));
|
||||
|
||||
assertEquals(
|
||||
0,
|
||||
xorBitCount.apply(
|
||||
new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}));
|
||||
assertEquals(
|
||||
1,
|
||||
xorBitCount.apply(
|
||||
new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new byte[] {0, 2, 3, 4, 5, 6, 7, 8, 9}));
|
||||
}
|
||||
|
||||
public void testXorBitCount() {
|
||||
int iterations = atLeast(100);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
int size = random().nextInt(1024);
|
||||
byte[] a = new byte[size];
|
||||
byte[] b = new byte[size];
|
||||
random().nextBytes(a);
|
||||
random().nextBytes(b);
|
||||
|
||||
int expected = xorBitCount(a, b);
|
||||
assertEquals(expected, VectorUtil.xorBitCount(a, b));
|
||||
assertEquals(expected, VectorUtil.xorBitCountInt(a, b));
|
||||
assertEquals(expected, VectorUtil.xorBitCountLong(a, b));
|
||||
}
|
||||
}
|
||||
|
||||
private static int xorBitCount(byte[] a, byte[] b) {
|
||||
int res = 0;
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
byte x = a[i];
|
||||
byte y = b[i];
|
||||
for (int j = 0; j < Byte.SIZE; j++) {
|
||||
if (x == y) break;
|
||||
if ((x & 0x01) != (y & 0x01)) res++;
|
||||
x = (byte) ((x & 0xFF) >> 1);
|
||||
y = (byte) ((y & 0xFF) >> 1);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue