mirror of https://github.com/apache/lucene.git
Speed up advancing within a block, take 2. (#13958)
PR #13692 tried to speed up advancing by using branchless binary search, but while that yielded a speedup on my machine, it caused a slowdown in nightly benchmarks. This PR tries a different approach using vectorization. Experimentation suggests that it speeds up queries that advance to the next few doc IDs, such as `AndHighHigh`.
This commit is contained in:
parent
9359cfd32f
commit
3041af7a94
lucene
CHANGES.txt
benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh
core/src
java/org/apache/lucene
codecs/lucene912
internal/vectorization
search
util
java21/org/apache/lucene/internal/vectorization
test/org/apache/lucene/util
|
@ -80,6 +80,8 @@ Optimizations
|
|||
* GITHUB#13963: Speed up nextDoc() implementations in Lucene912PostingsReader.
|
||||
(Adrien Grand)
|
||||
|
||||
* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.benchmark.jmh;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

/**
 * Benchmark for strategies that find the first index in a small sorted block of doc IDs whose
 * value is greater than or equal to a target, as done when advancing within a postings block.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(
    value = 3,
    jvmArgsAppend = {
      "-Xmx1g",
      "-Xms1g",
      "-XX:+AlwaysPreTouch",
      "--add-modules",
      "jdk.incubator.vector"
    })
public class AdvanceBenchmark {

  // 128 sorted doc IDs plus a NO_MORE_DOCS sentinel so searches never run off the end.
  private final long[] values = new long[129];
  private final int[] startIndexes = new int[1_000];
  private final long[] targets = new long[startIndexes.length];

  @Setup(Level.Trial)
  public void setup() throws Exception {
    for (int i = 0; i < 128; ++i) {
      values[i] = i;
    }
    values[128] = DocIdSetIterator.NO_MORE_DOCS;
    Random r = new Random(0);
    for (int i = 0; i < startIndexes.length; ++i) {
      startIndexes[i] = r.nextInt(64);
      // Bias towards small advances: the gap is 1 + a value drawn from an exponentially
      // distributed range, which models queries that advance to one of the next few doc IDs.
      targets[i] = startIndexes[i] + 1 + r.nextInt(1 << r.nextInt(7));
    }
  }

  @Benchmark
  public void binarySearch() {
    for (int i = 0; i < startIndexes.length; ++i) {
      binarySearch(values, targets[i], startIndexes[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int binarySearch(long[] values, long target, int startIndex) {
    // Standard binary search
    int i = Arrays.binarySearch(values, startIndex, values.length, target);
    if (i < 0) {
      i = -1 - i;
    }
    return i;
  }

  @Benchmark
  public void inlinedBranchlessBinarySearch() {
    // Note: this strategy always searches from index 0, so startIndexes is intentionally unused.
    for (int i = 0; i < targets.length; ++i) {
      inlinedBranchlessBinarySearch(values, targets[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int inlinedBranchlessBinarySearch(long[] values, long target) {
    // This compiles to cmov instructions.
    int start = 0;

    if (values[63] < target) {
      start += 64;
    }
    if (values[start + 31] < target) {
      start += 32;
    }
    if (values[start + 15] < target) {
      start += 16;
    }
    if (values[start + 7] < target) {
      start += 8;
    }
    if (values[start + 3] < target) {
      start += 4;
    }
    if (values[start + 1] < target) {
      start += 2;
    }
    if (values[start] < target) {
      start += 1;
    }

    return start;
  }

  @Benchmark
  public void linearSearch() {
    for (int i = 0; i < startIndexes.length; ++i) {
      linearSearch(values, targets[i], startIndexes[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int linearSearch(long[] values, long target, int startIndex) {
    // Naive linear search.
    for (int i = startIndex; i < values.length; ++i) {
      if (values[i] >= target) {
        return i;
      }
    }
    return values.length;
  }

  @Benchmark
  public void vectorUtilSearch() {
    for (int i = 0; i < startIndexes.length; ++i) {
      // Go through the DONT_INLINE wrapper, like the other benchmarks, so that all strategies
      // are measured under the same (non-inlined) call conditions.
      vectorUtilSearch(values, targets[i], startIndexes[i]);
    }
  }

  @CompilerControl(CompilerControl.Mode.DONT_INLINE)
  private static int vectorUtilSearch(long[] values, long target, int startIndex) {
    return VectorUtil.findNextGEQ(values, 128, target, startIndex);
  }

  private static void assertEquals(int expected, int actual) {
    if (expected != actual) {
      throw new AssertionError("Expected: " + expected + ", got " + actual);
    }
  }

  public static void main(String[] args) {
    // For testing purposes
    long[] values = new long[129];
    for (int i = 0; i < 128; ++i) {
      values[i] = i;
    }
    values[128] = DocIdSetIterator.NO_MORE_DOCS;
    for (int start = 0; start < 128; ++start) {
      for (int targetIndex = start; targetIndex < 128; ++targetIndex) {
        int actualIndex = binarySearch(values, values[targetIndex], start);
        assertEquals(targetIndex, actualIndex);
        // Searching from 0 still lands on targetIndex here because values[i] == i.
        actualIndex = inlinedBranchlessBinarySearch(values, values[targetIndex]);
        assertEquals(targetIndex, actualIndex);
        actualIndex = linearSearch(values, values[targetIndex], start);
        assertEquals(targetIndex, actualIndex);
        actualIndex = vectorUtilSearch(values, values[targetIndex], start);
        assertEquals(targetIndex, actualIndex);
      }
    }
  }
}
|
|
@ -46,6 +46,7 @@ import org.apache.lucene.index.PostingsEnum;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SlowImpactsEnum;
|
||||
import org.apache.lucene.internal.vectorization.PostingDecodingUtil;
|
||||
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
|
||||
import org.apache.lucene.internal.vectorization.VectorizationProvider;
|
||||
import org.apache.lucene.store.ByteArrayDataInput;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
|
@ -65,6 +66,8 @@ import org.apache.lucene.util.IOUtils;
|
|||
public final class Lucene912PostingsReader extends PostingsReaderBase {
|
||||
|
||||
static final VectorizationProvider VECTORIZATION_PROVIDER = VectorizationProvider.getInstance();
|
||||
private static final VectorUtilSupport VECTOR_SUPPORT =
|
||||
VECTORIZATION_PROVIDER.getVectorUtilSupport();
|
||||
// Dummy impacts, composed of the maximum possible term frequency and the lowest possible
|
||||
// (unsigned) norm value. This is typically used on tail blocks, which don't actually record
|
||||
// impacts as the storage overhead would not be worth any query evaluation speedup, since there's
|
||||
|
@ -215,15 +218,6 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
}
|
||||
}
|
||||
|
||||
static int findFirstGreater(long[] buffer, int target, int from) {
|
||||
for (int i = from; i < BLOCK_SIZE; ++i) {
|
||||
if (buffer[i] >= target) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return BLOCK_SIZE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BlockTermState newTermState() {
|
||||
return new IntBlockTermState();
|
||||
|
@ -357,6 +351,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
protected int docCountUpto; // number of docs in or before the current block
|
||||
protected long prevDocID; // last doc ID of the previous block
|
||||
|
||||
protected int docBufferSize;
|
||||
protected int docBufferUpto;
|
||||
|
||||
protected IndexInput docIn;
|
||||
|
@ -402,6 +397,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
level1DocEndFP = termState.docStartFP;
|
||||
}
|
||||
level1DocCountUpto = 0;
|
||||
docBufferSize = BLOCK_SIZE;
|
||||
docBufferUpto = BLOCK_SIZE;
|
||||
return this;
|
||||
}
|
||||
|
@ -487,7 +483,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
docCountUpto += BLOCK_SIZE;
|
||||
prevDocID = docBuffer[BLOCK_SIZE - 1];
|
||||
docBufferUpto = 0;
|
||||
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
|
||||
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
private void refillRemainder() throws IOException {
|
||||
|
@ -508,6 +504,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
docCountUpto += left;
|
||||
}
|
||||
docBufferUpto = 0;
|
||||
docBufferSize = left;
|
||||
freqFP = -1;
|
||||
}
|
||||
|
||||
|
@ -604,7 +601,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
}
|
||||
}
|
||||
|
||||
int next = findFirstGreater(docBuffer, target, docBufferUpto);
|
||||
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
|
||||
this.doc = (int) docBuffer[next];
|
||||
docBufferUpto = next + 1;
|
||||
return doc;
|
||||
|
@ -782,16 +779,18 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
freqBuffer[0] = totalTermFreq;
|
||||
docBuffer[1] = NO_MORE_DOCS;
|
||||
docCountUpto++;
|
||||
docBufferSize = 1;
|
||||
} else {
|
||||
// Read vInts:
|
||||
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
|
||||
prefixSum(docBuffer, left, prevDocID);
|
||||
docBuffer[left] = NO_MORE_DOCS;
|
||||
docCountUpto += left;
|
||||
docBufferSize = left;
|
||||
}
|
||||
prevDocID = docBuffer[BLOCK_SIZE - 1];
|
||||
docBufferUpto = 0;
|
||||
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
|
||||
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
private void skipLevel1To(int target) throws IOException {
|
||||
|
@ -951,7 +950,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
refillDocs();
|
||||
}
|
||||
|
||||
int next = findFirstGreater(docBuffer, target, docBufferUpto);
|
||||
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
|
||||
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
|
||||
this.freq = (int) freqBuffer[next];
|
||||
this.docBufferUpto = next + 1;
|
||||
|
@ -1155,6 +1154,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
protected int docCountUpto; // number of docs in or before the current block
|
||||
protected int doc = -1; // doc we last read
|
||||
protected long prevDocID = -1; // last doc ID of the previous block
|
||||
protected int docBufferSize = BLOCK_SIZE;
|
||||
protected int docBufferUpto = BLOCK_SIZE;
|
||||
|
||||
// true if we shallow-advanced to a new block that we have not decoded yet
|
||||
|
@ -1306,10 +1306,11 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
docBuffer[left] = NO_MORE_DOCS;
|
||||
freqFP = -1;
|
||||
docCountUpto += left;
|
||||
docBufferSize = left;
|
||||
}
|
||||
prevDocID = docBuffer[BLOCK_SIZE - 1];
|
||||
docBufferUpto = 0;
|
||||
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
|
||||
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
private void skipLevel1To(int target) throws IOException {
|
||||
|
@ -1437,7 +1438,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
needsRefilling = false;
|
||||
}
|
||||
|
||||
int next = findFirstGreater(docBuffer, target, docBufferUpto);
|
||||
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
|
||||
this.doc = (int) docBuffer[next];
|
||||
docBufferUpto = next + 1;
|
||||
return doc;
|
||||
|
@ -1535,10 +1536,11 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
prefixSum(docBuffer, left, prevDocID);
|
||||
docBuffer[left] = NO_MORE_DOCS;
|
||||
docCountUpto += left;
|
||||
docBufferSize = left;
|
||||
}
|
||||
prevDocID = docBuffer[BLOCK_SIZE - 1];
|
||||
docBufferUpto = 0;
|
||||
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
|
||||
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
private void skipLevel1To(int target) throws IOException {
|
||||
|
@ -1669,7 +1671,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
|
|||
needsRefilling = false;
|
||||
}
|
||||
|
||||
int next = findFirstGreater(docBuffer, target, docBufferUpto);
|
||||
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
|
||||
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
|
||||
freq = (int) freqBuffer[next];
|
||||
docBufferUpto = next + 1;
|
||||
|
|
|
@ -197,4 +197,14 @@ final class DefaultVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
return squareSum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int findNextGEQ(long[] buffer, int length, long target, int from) {
|
||||
for (int i = from; i < length; ++i) {
|
||||
if (buffer[i] >= target) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return length;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,4 +44,12 @@ public interface VectorUtilSupport {
|
|||
|
||||
/** Returns the sum of squared differences of the two byte vectors. */
|
||||
int squareDistance(byte[] a, byte[] b);
|
||||
|
||||
/**
|
||||
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
|
||||
* length} exclusive, find the first array index whose value is greater than or equal to {@code
|
||||
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
|
||||
* {@code length} is returned.
|
||||
*/
|
||||
int findNextGEQ(long[] buffer, int length, long target, int from);
|
||||
}
|
||||
|
|
|
@ -106,6 +106,10 @@ public final class ImpactsDISI extends DocIdSetIterator {
|
|||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
DocIdSetIterator in = this.in;
|
||||
if (in.docID() < upTo) {
|
||||
return in.nextDoc();
|
||||
}
|
||||
return advance(in.docID() + 1);
|
||||
}
|
||||
|
||||
|
|
|
@ -307,4 +307,14 @@ public final class VectorUtil {
|
|||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
/**
 * Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
 * length} exclusive, find the first array index whose value is greater than or equal to {@code
 * target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
 * {@code length} is returned.
 *
 * @param buffer sorted values to search
 * @param length exclusive end of the sorted prefix of {@code buffer} to consider
 * @param target value to search for
 * @param from inclusive index to start searching at
 * @return the first index in {@code [from, length)} whose value is {@code >= target}, or {@code
 *     length} if there is none
 */
public static int findNextGEQ(long[] buffer, int length, long target, int from) {
  // Delegate to the runtime-selected VectorUtilSupport implementation (IMPL).
  return IMPL.findNextGEQ(buffer, length, target, from);
}
|
||||
}
|
||||
|
|
|
@ -29,8 +29,11 @@ import java.lang.foreign.MemorySegment;
|
|||
import jdk.incubator.vector.ByteVector;
|
||||
import jdk.incubator.vector.FloatVector;
|
||||
import jdk.incubator.vector.IntVector;
|
||||
import jdk.incubator.vector.LongVector;
|
||||
import jdk.incubator.vector.ShortVector;
|
||||
import jdk.incubator.vector.Vector;
|
||||
import jdk.incubator.vector.VectorMask;
|
||||
import jdk.incubator.vector.VectorOperators;
|
||||
import jdk.incubator.vector.VectorShape;
|
||||
import jdk.incubator.vector.VectorSpecies;
|
||||
import org.apache.lucene.util.Constants;
|
||||
|
@ -56,6 +59,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
PanamaVectorConstants.PRERERRED_INT_SPECIES;
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES;
|
||||
private static final VectorSpecies<Short> SHORT_SPECIES;
|
||||
private static final VectorSpecies<Long> LONG_SPECIES;
|
||||
|
||||
static final int VECTOR_BITSIZE;
|
||||
|
||||
|
@ -71,6 +75,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
BYTE_SPECIES = null;
|
||||
SHORT_SPECIES = null;
|
||||
}
|
||||
LONG_SPECIES = PanamaVectorConstants.PRERERRED_LONG_SPECIES;
|
||||
}
|
||||
|
||||
// the way FMA should work! if available use it, otherwise fall back to mul/add
|
||||
|
@ -761,4 +766,27 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// reduce
|
||||
return acc1.add(acc2).reduceLanes(ADD);
|
||||
}
|
||||
|
||||
// Experiments suggest that we need at least 4 lanes so that the overhead of going with the vector
// approach and counting trues on vector masks pays off.
private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = LONG_SPECIES.length() >= 4;

@Override
public int findNextGEQ(long[] buffer, int length, long target, int from) {
  if (ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO) {
    // Advance in strides of LONG_SPECIES.length() + 1. Since the buffer is sorted (per the
    // findNextGEQ contract), if the value one full vector width ahead is still < target, then
    // every value in between is also < target and the whole window can be skipped. Otherwise the
    // answer lies within [from, from + length()]: load that window, count how many lanes are
    // still < target, and offset from by that count.
    for (; from + LONG_SPECIES.length() < length; from += LONG_SPECIES.length() + 1) {
      if (buffer[from + LONG_SPECIES.length()] >= target) {
        LongVector vector = LongVector.fromArray(LONG_SPECIES, buffer, from);
        VectorMask<Long> mask = vector.compare(VectorOperators.LT, target);
        return from + mask.trueCount();
      }
    }
  }
  // Scalar tail, also the fallback when the vector species is too narrow to pay off.
  for (int i = from; i < length; ++i) {
    if (buffer[i] >= target) {
      return i;
    }
  }
  return length;
}
|
||||
}
|
||||
|
|
|
@ -353,4 +353,35 @@ public class TestVectorUtil extends LuceneTestCase {
|
|||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
public void testFindNextGEQ() {
  // Random trailing padding checks that implementations do not read past `length`.
  int padding = TestUtil.nextInt(random(), 0, 5);
  long[] values = new long[128 + padding];
  long v = 0;
  // Strictly increasing values, built as a prefix sum of positive random deltas.
  for (int i = 0; i < 128; ++i) {
    v += TestUtil.nextInt(random(), 1, 1000);
    values[i] = v;
  }

  // Now duel with slowFindFirstGreater
  for (int iter = 0; iter < 1_000; ++iter) {
    int from = TestUtil.nextInt(random(), 0, 127);
    // Draw a target from [values[from], values[127]], then nudge it by -5..+4 so that values
    // slightly outside the searched range (including exact-miss targets) are also exercised.
    long target =
        TestUtil.nextLong(random(), values[from], Math.max(values[from], values[127]))
            + random().nextInt(10)
            - 5;
    assertEquals(
        slowFindNextGEQ(values, 128, target, from),
        VectorUtil.findNextGEQ(values, 128, target, from));
  }
}
|
||||
|
||||
private static int slowFindNextGEQ(long[] buffer, int length, long target, int from) {
|
||||
for (int i = from; i < length; ++i) {
|
||||
if (buffer[i] >= target) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return length;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue