Speed up advancing within a block, take 2. (#13958)

An earlier PR tried to speed up advancing by using a branchless binary search. While that yielded a speedup on my machine, it caused a slowdown on the nightly benchmarks.

This PR tries a different approach using vectorization. Experimentation suggests that it speeds up queries that advance to the next few doc IDs, such as `AndHighHigh`.
This commit is contained in:
Adrien Grand 2024-10-30 12:51:36 +01:00 committed by GitHub
parent 9359cfd32f
commit 3041af7a94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 292 additions and 17 deletions
lucene
CHANGES.txt
benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh
core/src
java/org/apache/lucene
java21/org/apache/lucene/internal/vectorization
test/org/apache/lucene/util

View File

@ -80,6 +80,8 @@ Optimizations
* GITHUB#13963: Speed up nextDoc() implementations in Lucene912PostingsReader.
(Adrien Grand)
* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)
Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
/**
 * JMH benchmark comparing several ways of finding, in a sorted block of 128 values followed by a
 * {@link DocIdSetIterator#NO_MORE_DOCS} sentinel, the first index whose value is greater than or
 * equal to a target.
 *
 * <p>Every {@code @Benchmark} method runs the same 1,000 pre-generated (start index, target) pairs
 * through a private helper annotated with {@code CompilerControl.Mode.DONT_INLINE}, so that all
 * strategies are measured through the same kind of non-inlined call.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(
    value = 3,
    jvmArgsAppend = {
      "-Xmx1g",
      "-Xms1g",
      "-XX:+AlwaysPreTouch",
      "--add-modules",
      "jdk.incubator.vector"
    })
public class AdvanceBenchmark {

  // 128 sorted values plus one trailing sentinel slot.
  private final long[] values = new long[129];
  // Index from which each search starts.
  private final int[] startIndexes = new int[1_000];
  // Value searched for; targets[i] is paired with startIndexes[i].
  private final long[] targets = new long[startIndexes.length];

  /** Fills the block with 0..127 plus the sentinel and generates the search workload. */
  @Setup(Level.Trial)
  public void setup() throws Exception {
    for (int i = 0; i < 128; ++i) {
      values[i] = i;
    }
    values[128] = DocIdSetIterator.NO_MORE_DOCS;
    // Fixed seed so that every trial searches for the same targets.
    Random r = new Random(0);
    for (int i = 0; i < startIndexes.length; ++i) {
      startIndexes[i] = r.nextInt(64);
      // The target is 1 to 64 positions past the start index; the random shift makes short
      // distances more frequent than long ones.
      targets[i] = startIndexes[i] + 1 + r.nextInt(1 << r.nextInt(7));
    }
  }

  @Benchmark
  public void binarySearch() {
    for (int i = 0; i < startIndexes.length; ++i) {
      binarySearch(values, targets[i], startIndexes[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int binarySearch(long[] values, long target, int startIndex) {
    // Standard binary search
    int i = Arrays.binarySearch(values, startIndex, values.length, target);
    if (i < 0) {
      // Not found: convert the encoded insertion point (-(insertionPoint) - 1) back to an index.
      i = -1 - i;
    }
    return i;
  }

  @Benchmark
  public void inlinedBranchlessBinarySearch() {
    for (int i = 0; i < targets.length; ++i) {
      // This strategy always searches the whole block, so it takes no start index.
      inlinedBranchlessBinarySearch(values, targets[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int inlinedBranchlessBinarySearch(long[] values, long target) {
    // This compiles to cmov instructions.
    // Fully unrolled binary search over indexes 0..127: each step halves the remaining range.
    int start = 0;
    if (values[63] < target) {
      start += 64;
    }
    if (values[start + 31] < target) {
      start += 32;
    }
    if (values[start + 15] < target) {
      start += 16;
    }
    if (values[start + 7] < target) {
      start += 8;
    }
    if (values[start + 3] < target) {
      start += 4;
    }
    if (values[start + 1] < target) {
      start += 2;
    }
    if (values[start] < target) {
      start += 1;
    }
    return start;
  }

  @Benchmark
  public void linearSearch() {
    for (int i = 0; i < startIndexes.length; ++i) {
      linearSearch(values, targets[i], startIndexes[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int linearSearch(long[] values, long target, int startIndex) {
    // Naive linear search.
    for (int i = startIndex; i < values.length; ++i) {
      if (values[i] >= target) {
        return i;
      }
    }
    return values.length;
  }

  @Benchmark
  public void vectorUtilSearch() {
    for (int i = 0; i < startIndexes.length; ++i) {
      // Go through the DONT_INLINE helper like every other benchmark in this class, rather than
      // calling VectorUtil directly, so that this strategy is measured under the same conditions
      // and the discarded result does not invite the JIT to remove the work.
      vectorUtilSearch(values, targets[i], startIndexes[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int vectorUtilSearch(long[] values, long target, int startIndex) {
    // Only the first 128 entries are searched; the sentinel slot is excluded.
    return VectorUtil.findNextGEQ(values, 128, target, startIndex);
  }

  private static void assertEquals(int expected, int actual) {
    if (expected != actual) {
      throw new AssertionError("Expected: " + expected + ", got " + actual);
    }
  }

  /** Sanity check that all strategies agree on every (start, target) pair of the block. */
  public static void main(String[] args) {
    // For testing purposes
    long[] values = new long[129];
    for (int i = 0; i < 128; ++i) {
      values[i] = i;
    }
    values[128] = DocIdSetIterator.NO_MORE_DOCS;
    for (int start = 0; start < 128; ++start) {
      for (int targetIndex = start; targetIndex < 128; ++targetIndex) {
        int actualIndex = binarySearch(values, values[targetIndex], start);
        assertEquals(targetIndex, actualIndex);
        actualIndex = inlinedBranchlessBinarySearch(values, values[targetIndex]);
        assertEquals(targetIndex, actualIndex);
        actualIndex = linearSearch(values, values[targetIndex], start);
        assertEquals(targetIndex, actualIndex);
        actualIndex = vectorUtilSearch(values, values[targetIndex], start);
        assertEquals(targetIndex, actualIndex);
      }
    }
  }
}

View File

@ -46,6 +46,7 @@ import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.internal.vectorization.PostingDecodingUtil;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ChecksumIndexInput;
@ -65,6 +66,8 @@ import org.apache.lucene.util.IOUtils;
public final class Lucene912PostingsReader extends PostingsReaderBase {
static final VectorizationProvider VECTORIZATION_PROVIDER = VectorizationProvider.getInstance();
private static final VectorUtilSupport VECTOR_SUPPORT =
VECTORIZATION_PROVIDER.getVectorUtilSupport();
// Dummy impacts, composed of the maximum possible term frequency and the lowest possible
// (unsigned) norm value. This is typically used on tail blocks, which don't actually record
// impacts as the storage overhead would not be worth any query evaluation speedup, since there's
@ -215,15 +218,6 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
}
}
/**
 * Scans {@code buffer} linearly from index {@code from} (inclusive) up to {@code BLOCK_SIZE}
 * (exclusive) and returns the first index whose value is greater than or equal to {@code target}.
 * Note that, despite the method name, the comparison is {@code >=} and not {@code >}. Returns
 * {@code BLOCK_SIZE} when no such index exists.
 */
static int findFirstGreater(long[] buffer, int target, int from) {
for (int i = from; i < BLOCK_SIZE; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return BLOCK_SIZE;
}
@Override
public BlockTermState newTermState() {
return new IntBlockTermState();
@ -357,6 +351,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
protected int docCountUpto; // number of docs in or before the current block
protected long prevDocID; // last doc ID of the previous block
protected int docBufferSize;
protected int docBufferUpto;
protected IndexInput docIn;
@ -402,6 +397,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
level1DocEndFP = termState.docStartFP;
}
level1DocCountUpto = 0;
docBufferSize = BLOCK_SIZE;
docBufferUpto = BLOCK_SIZE;
return this;
}
@ -487,7 +483,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
docCountUpto += BLOCK_SIZE;
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}
private void refillRemainder() throws IOException {
@ -508,6 +504,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
docCountUpto += left;
}
docBufferUpto = 0;
docBufferSize = left;
freqFP = -1;
}
@ -604,7 +601,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
}
}
int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
@ -782,16 +779,18 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
freqBuffer[0] = totalTermFreq;
docBuffer[1] = NO_MORE_DOCS;
docCountUpto++;
docBufferSize = 1;
} else {
// Read vInts:
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
prefixSum(docBuffer, left, prevDocID);
docBuffer[left] = NO_MORE_DOCS;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}
private void skipLevel1To(int target) throws IOException {
@ -951,7 +950,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
refillDocs();
}
int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
this.freq = (int) freqBuffer[next];
this.docBufferUpto = next + 1;
@ -1155,6 +1154,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
protected int docCountUpto; // number of docs in or before the current block
protected int doc = -1; // doc we last read
protected long prevDocID = -1; // last doc ID of the previous block
protected int docBufferSize = BLOCK_SIZE;
protected int docBufferUpto = BLOCK_SIZE;
// true if we shallow-advanced to a new block that we have not decoded yet
@ -1306,10 +1306,11 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
docBuffer[left] = NO_MORE_DOCS;
freqFP = -1;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}
private void skipLevel1To(int target) throws IOException {
@ -1437,7 +1438,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
needsRefilling = false;
}
int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
@ -1535,10 +1536,11 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
prefixSum(docBuffer, left, prevDocID);
docBuffer[left] = NO_MORE_DOCS;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}
private void skipLevel1To(int target) throws IOException {
@ -1669,7 +1671,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
needsRefilling = false;
}
int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
freq = (int) freqBuffer[next];
docBufferUpto = next + 1;

View File

@ -197,4 +197,14 @@ final class DefaultVectorUtilSupport implements VectorUtilSupport {
}
return squareSum;
}
/**
 * Scalar implementation: walks {@code buffer} forward from {@code from} until it reaches an entry
 * that is greater than or equal to {@code target}, or until {@code length} is reached.
 */
@Override
public int findNextGEQ(long[] buffer, int length, long target, int from) {
  int index = from;
  while (index < length && buffer[index] < target) {
    ++index;
  }
  // No match (or a start index already past the end) is reported as `length`.
  return Math.min(index, length);
}
}

View File

@ -44,4 +44,12 @@ public interface VectorUtilSupport {
/** Returns the sum of squared differences of the two byte vectors. */
int squareDistance(byte[] a, byte[] b);
/**
 * Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
 * length} exclusive, find the first array index whose value is greater than or equal to {@code
 * target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
 * {@code length} is returned.
 *
 * @param buffer the array to search, sorted over {@code [0, length)}
 * @param length the exclusive upper bound of the searched range
 * @param target the value to search for
 * @param from the index at which the search starts
 * @return the first index at or after {@code from} whose value is greater than or equal to {@code
 *     target}, or {@code length} if there is none
 */
int findNextGEQ(long[] buffer, int length, long target, int from);
}

View File

@ -106,6 +106,10 @@ public final class ImpactsDISI extends DocIdSetIterator {
/**
 * Moves to the next document. While the wrapped iterator is positioned before {@code upTo}, this
 * steps it directly with {@code nextDoc()}; otherwise it delegates to {@link #advance(int)} with
 * the doc ID that follows the current one.
 */
@Override
public int nextDoc() throws IOException {
  final DocIdSetIterator delegate = this.in;
  if (delegate.docID() >= upTo) {
    return advance(delegate.docID() + 1);
  }
  return delegate.nextDoc();
}

View File

@ -307,4 +307,14 @@ public final class VectorUtil {
}
return v;
}
/**
 * Finds, in an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
 * length} exclusive, the first index at or after {@code from} whose value is greater than or
 * equal to {@code target}. Returns {@code length} if there is no such index. The search itself is
 * delegated to the {@code IMPL} instance.
 */
public static int findNextGEQ(long[] buffer, int length, long target, int from) {
  final int index = IMPL.findNextGEQ(buffer, length, target, from);
  return index;
}
}

View File

@ -29,8 +29,11 @@ import java.lang.foreign.MemorySegment;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.util.Constants;
@ -56,6 +59,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
PanamaVectorConstants.PRERERRED_INT_SPECIES;
private static final VectorSpecies<Byte> BYTE_SPECIES;
private static final VectorSpecies<Short> SHORT_SPECIES;
private static final VectorSpecies<Long> LONG_SPECIES;
static final int VECTOR_BITSIZE;
@ -71,6 +75,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
BYTE_SPECIES = null;
SHORT_SPECIES = null;
}
LONG_SPECIES = PanamaVectorConstants.PRERERRED_LONG_SPECIES;
}
// the way FMA should work! if available use it, otherwise fall back to mul/add
@ -761,4 +766,27 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// reduce
return acc1.add(acc2).reduceLanes(ADD);
}
// Experiments suggest that we need at least 4 lanes so that the overhead of going with the vector
// approach and counting trues on vector masks pays off.
private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = LONG_SPECIES.length() >= 4;
/**
 * Finds the first index at or after {@code from} whose value is greater than or equal to {@code
 * target}, returning {@code length} if there is none. When the vector path is enabled, the buffer
 * is consumed in steps of {@code LONG_SPECIES.length() + 1} entries; whatever remains is handled
 * by a plain linear scan. The lane-counting step relies on {@code buffer} being sorted.
 */
@Override
public int findNextGEQ(long[] buffer, int length, long target, int from) {
if (ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO) {
// Each iteration covers the window [from, from + LONG_SPECIES.length()], i.e. one full vector
// plus the single entry right after it. The loop condition keeps that last entry within
// `length`.
for (; from + LONG_SPECIES.length() < length; from += LONG_SPECIES.length() + 1) {
// Probe the entry just past the vector first: if it is still < target, the whole window can
// be skipped without loading the vector.
if (buffer[from + LONG_SPECIES.length()] >= target) {
LongVector vector = LongVector.fromArray(LONG_SPECIES, buffer, from);
VectorMask<Long> mask = vector.compare(VectorOperators.LT, target);
// On sorted input the lanes that are < target form a prefix, so their count is the offset
// of the first lane that is >= target. If every lane is < target, the count equals the
// vector length and the result is the probed entry.
return from + mask.trueCount();
}
}
}
// Scalar scan: the tail left over by the vector loop, or the whole range when the vector path
// is disabled.
for (int i = from; i < length; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return length;
}
}

View File

@ -353,4 +353,35 @@ public class TestVectorUtil extends LuceneTestCase {
}
return res;
}
/**
 * Checks {@code VectorUtil.findNextGEQ} against the naive {@code slowFindNextGEQ} reference on
 * random start indexes and targets over a strictly increasing block of 128 values.
 */
public void testFindNextGEQ() {
  // The array may be up to 5 entries longer than the 128 entries that are searched.
  final int padding = TestUtil.nextInt(random(), 0, 5);
  final long[] values = new long[128 + padding];
  long runningSum = 0;
  for (int i = 0; i < 128; ++i) {
    // Increments are at least 1, so the values are strictly increasing.
    runningSum += TestUtil.nextInt(random(), 1, 1000);
    values[i] = runningSum;
  }

  // Now duel with slowFindNextGEQ
  for (int iter = 0; iter < 1_000; ++iter) {
    final int from = TestUtil.nextInt(random(), 0, 127);
    final long base =
        TestUtil.nextLong(random(), values[from], Math.max(values[from], values[127]));
    // Shift by -5..4 so that targets may also fall slightly outside of the remaining range.
    final long target = base + random().nextInt(10) - 5;
    final int expected = slowFindNextGEQ(values, 128, target, from);
    final int actual = VectorUtil.findNextGEQ(values, 128, target, from);
    assertEquals(expected, actual);
  }
}
/**
 * Reference implementation: linear scan from {@code from} for the first index below {@code
 * length} whose value is greater than or equal to {@code target}; returns {@code length} if there
 * is none.
 */
private static int slowFindNextGEQ(long[] buffer, int length, long target, int from) {
  int candidate = from;
  while (candidate < length) {
    if (target <= buffer[candidate]) {
      return candidate;
    }
    candidate++;
  }
  return length;
}
}