From a337d14b21c3882c38da1b494730aa8f7f12827d Mon Sep 17 00:00:00 2001 From: Adrien Grand Date: Thu, 19 Dec 2024 15:05:14 +0100 Subject: [PATCH] Use the new `loadIntoBitSet` API to speed up dense conjunctions. (#14080) Now that loading doc IDs into a bit set is much more efficient thanks to auto-vectorization, it has become tempting to evaluate dense conjunctions by and-ing bit sets. --- lucene/CHANGES.txt | 3 + .../lucene/search/BooleanScorerSupplier.java | 38 +++- .../search/DenseConjunctionBulkScorer.java | 193 +++++++++++++++++ .../search/DisjunctionDISIApproximation.java | 19 ++ .../org/apache/lucene/util/FixedBitSet.java | 11 +- .../TestDenseConjunctionBulkScorer.java | 204 ++++++++++++++++++ .../apache/lucene/util/TestFixedBitSet.java | 26 +++ 7 files changed, 484 insertions(+), 10 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestDenseConjunctionBulkScorer.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 3e34678e241..99347d8984c 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -55,6 +55,9 @@ Optimizations * GITHUB#14052: Speed up DisjunctionDISIApproximation#advance. (Adrien Grand) +* GITHUB#14080: Use the `DocIdSetIterator#loadIntoBitSet` API to speed up dense + conjunctions. 
(Adrien Grand) + Bug Fixes --------------------- (No changes) diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java index b7c613d06a7..395e52b8849 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java @@ -304,9 +304,9 @@ final class BooleanScorerSupplier extends ScorerSupplier { BulkScorer filteredOptionalBulkScorer() throws IOException { if (subs.get(Occur.MUST).isEmpty() == false || subs.get(Occur.FILTER).isEmpty() - || scoreMode != ScoreMode.TOP_SCORES + || (scoreMode.needsScores() && scoreMode != ScoreMode.TOP_SCORES) || subs.get(Occur.SHOULD).size() <= 1 - || minShouldMatch > 1) { + || minShouldMatch != 1) { return null; } long cost = cost(); @@ -318,13 +318,28 @@ final class BooleanScorerSupplier extends ScorerSupplier { for (ScorerSupplier ss : subs.get(Occur.FILTER)) { filters.add(ss.get(cost)); } - Scorer filterScorer; - if (filters.size() == 1) { - filterScorer = filters.iterator().next(); + if (scoreMode == ScoreMode.TOP_SCORES) { + Scorer filterScorer; + if (filters.size() == 1) { + filterScorer = filters.iterator().next(); + } else { + filterScorer = new ConjunctionScorer(filters, Collections.emptySet()); + } + return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer); } else { - filterScorer = new ConjunctionScorer(filters, Collections.emptySet()); + // In the beginning of this method, we exited early if the score mode is not either TOP_SCORES + // or a score mode that doesn't need scores. 
+ assert scoreMode.needsScores() == false; + filters.add(new DisjunctionSumScorer(optionalScorers, scoreMode, cost)); + + if (filters.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull) + && maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE + && cost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) { + return new DenseConjunctionBulkScorer(filters.stream().map(Scorer::iterator).toList()); + } + + return new DefaultBulkScorer(new ConjunctionScorer(filters, Collections.emptyList())); } - return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer); } // Return a BulkScorer for the required clauses only @@ -378,7 +393,14 @@ final class BooleanScorerSupplier extends ScorerSupplier { && requiredScoring.size() + requiredNoScoring.size() >= 2 && requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull) && requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) { - return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring); + if (requiredScoring.isEmpty() + && maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE + && leadCost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) { + return new DenseConjunctionBulkScorer( + requiredNoScoring.stream().map(Scorer::iterator).toList()); + } else { + return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring); + } } if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) { requiredScoring = Collections.singletonList(new BlockMaxConjunctionScorer(requiredScoring)); diff --git a/lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java new file mode 100644 index 00000000000..78fc7c5c1c2 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; + +/** + * BulkScorer implementation of {@link ConjunctionScorer} that is specialized for dense clauses. + * Whenever sensible, it intersects clauses by loading their matches into a bit set and computing + * the intersection of clauses by and-ing these bit sets. + */ +final class DenseConjunctionBulkScorer extends BulkScorer { + + // Use a small-ish window size to make sure that we can take advantage of gaps in the postings of + // clauses that are not leading iteration. + static final int WINDOW_SIZE = 4096; + // Only use bit sets to compute the intersection if more than 1/32nd of the docs are expected to + // match. Experiments suggested that values that are a bit higher than this would work better, but + // we're erring on the conservative side. 
+ static final int DENSITY_THRESHOLD_INVERSE = Long.SIZE / 2; + + private final DocIdSetIterator lead; + private final List<DocIdSetIterator> others; + + private final FixedBitSet windowMatches = new FixedBitSet(WINDOW_SIZE); + private final FixedBitSet clauseWindowMatches = new FixedBitSet(WINDOW_SIZE); + private final DocIdStreamView docIdStreamView = new DocIdStreamView(); + + DenseConjunctionBulkScorer(List<DocIdSetIterator> iterators) { + if (iterators.size() <= 1) { + throw new IllegalArgumentException("Expected 2 or more clauses, got " + iterators.size()); + } + iterators = new ArrayList<>(iterators); + iterators.sort(Comparator.comparingLong(DocIdSetIterator::cost)); + lead = iterators.get(0); + others = List.copyOf(iterators.subList(1, iterators.size())); + } + + @Override + public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { + for (DocIdSetIterator it : others) { + min = Math.max(min, it.docID()); + } + + if (lead.docID() < min) { + lead.advance(min); + } + + if (lead.docID() >= max) { + return lead.docID(); + } + + List<DocIdSetIterator> otherIterators = this.others; + DocIdSetIterator collectorIterator = collector.competitiveIterator(); + if (collectorIterator != null) { + otherIterators = new ArrayList<>(otherIterators); + otherIterators.add(collectorIterator); + } + + final DocIdSetIterator[] others = otherIterators.toArray(DocIdSetIterator[]::new); + + int windowMax; + do { + windowMax = (int) Math.min(max, (long) lead.docID() + WINDOW_SIZE); + scoreWindowUsingBitSet(collector, acceptDocs, others, windowMax); + } while (windowMax < max); + + return lead.docID(); + } + + private static int advance(FixedBitSet set, int i) { + if (i >= WINDOW_SIZE) { + return DocIdSetIterator.NO_MORE_DOCS; + } else { + return set.nextSetBit(i); + } + } + + private void scoreWindowUsingBitSet( + LeafCollector collector, Bits acceptDocs, DocIdSetIterator[] others, int max) + throws IOException { + assert windowMatches.scanIsEmpty(); + assert clauseWindowMatches.scanIsEmpty(); + + 
int offset = lead.docID(); + lead.intoBitSet(acceptDocs, max, windowMatches, offset); + + int upTo = 0; + for (; + upTo < others.length + && windowMatches.cardinality() >= WINDOW_SIZE / DENSITY_THRESHOLD_INVERSE; + upTo++) { + DocIdSetIterator other = others[upTo]; + if (other.docID() < offset) { + other.advance(offset); + } + // No need to apply acceptDocs on other clauses since we already applied live docs on the + // leading clause. + other.intoBitSet(null, max, clauseWindowMatches, offset); + windowMatches.and(clauseWindowMatches); + clauseWindowMatches.clear(); + } + + if (upTo < others.length) { + // If the leading clause is sparse on this doc ID range or if the intersection became sparse + // after applying a few clauses, we finish evaluating the intersection using the traditional + // leap-frog approach. This proved important with a query such as "+secretary +of +state" on + // wikibigall, where the intersection becomes sparse after intersecting "secretary" and + // "state". + advanceHead: + for (int windowMatch = windowMatches.nextSetBit(0); + windowMatch != DocIdSetIterator.NO_MORE_DOCS; ) { + int doc = offset + windowMatch; + for (int i = upTo; i < others.length; ++i) { + DocIdSetIterator other = others[i]; + int otherDoc = other.docID(); + if (otherDoc < doc) { + otherDoc = other.advance(doc); + } + if (doc != otherDoc) { + int clearUpTo = Math.min(WINDOW_SIZE, otherDoc - offset); + windowMatches.clear(windowMatch, clearUpTo); + windowMatch = advance(windowMatches, clearUpTo); + continue advanceHead; + } + } + windowMatch = advance(windowMatches, windowMatch + 1); + } + } + + docIdStreamView.offset = offset; + collector.collect(docIdStreamView); + windowMatches.clear(); + + // If another clause is more advanced than lead then advance lead; it's important to take + // advantage of large gaps in the postings lists of other clauses. 
+ int maxOtherDocID = -1; + for (DocIdSetIterator other : others) { + maxOtherDocID = Math.max(maxOtherDocID, other.docID()); + } + if (lead.docID() < maxOtherDocID) { + lead.advance(maxOtherDocID); + } + } + + @Override + public long cost() { + return lead.cost(); + } + + final class DocIdStreamView extends DocIdStream { + + int offset; + + @Override + public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException { + int offset = this.offset; + long[] bitArray = windowMatches.getBits(); + for (int idx = 0; idx < bitArray.length; idx++) { + long bits = bitArray[idx]; + while (bits != 0L) { + int ntz = Long.numberOfTrailingZeros(bits); + consumer.accept(offset + ((idx << 6) | ntz)); + bits ^= 1L << ntz; + } + } + } + + @Override + public int count() throws IOException { + return windowMatches.cardinality(); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java index 3b7e2b1014c..8dc551e3c95 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java @@ -21,6 +21,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; import java.util.List; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; /** * A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided @@ -141,6 +143,23 @@ public final class DisjunctionDISIApproximation extends DocIdSetIterator { return Math.min(leadTop.doc, minOtherDoc); } + @Override + public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset) + throws IOException { + while (leadTop.doc < upTo) { + leadTop.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset); + leadTop.doc = leadTop.approximation.docID(); + leadTop = leadIterators.updateTop(); + } + + minOtherDoc = 
Integer.MAX_VALUE; + for (DisiWrapper w : otherIterators) { + w.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset); + w.doc = w.approximation.docID(); + minOtherDoc = Math.min(minOtherDoc, w.doc); + } + } + /** Return the linked list of iterators positioned on the current doc. */ public DisiWrapper topList() { if (leadTop.doc < minOtherDoc) { diff --git a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java index 4101b6ff1e2..1a79025c882 100644 --- a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java +++ b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java @@ -33,6 +33,10 @@ public final class FixedBitSet extends BitSet { private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FixedBitSet.class); + // An array that is small enough to use reasonable amounts of RAM and large enough to allow + // Arrays#equals to use SIMD instructions and multiple registers under the hood. + private static long[] ZEROES = new long[32]; + private final long[] bits; // Array of longs holding the bits private final int numBits; // The number of bits in use private final int numWords; // The exact number of longs needed to hold numBits (<= bits.length) @@ -470,8 +474,11 @@ public final class FixedBitSet extends BitSet { // Depends on the ghost bits being clear! 
final int count = numWords; - for (int i = 0; i < count; i++) { - if (bits[i] != 0) return false; + for (int i = 0; i < count; i += ZEROES.length) { + int cmpLen = Math.min(ZEROES.length, bits.length - i); + if (Arrays.equals(bits, i, i + cmpLen, ZEROES, 0, cmpLen) == false) { + return false; + } } return true; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestDenseConjunctionBulkScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestDenseConjunctionBulkScorer.java new file mode 100644 index 00000000000..db317775c82 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestDenseConjunctionBulkScorer.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.FixedBitSet; + +public class TestDenseConjunctionBulkScorer extends LuceneTestCase { + + public void testSameMatches() throws IOException { + int maxDoc = 100_000; + FixedBitSet clause1 = new FixedBitSet(maxDoc); + FixedBitSet clause2 = new FixedBitSet(maxDoc); + FixedBitSet clause3 = new FixedBitSet(maxDoc); + for (int i = 0; i < maxDoc; i += 2) { + clause1.set(i); + clause2.set(i); + clause3.set(i); + } + DenseConjunctionBulkScorer scorer = + new DenseConjunctionBulkScorer( + Arrays.asList( + new BitSetIterator(clause1, clause1.approximateCardinality()), + new BitSetIterator(clause2, clause2.approximateCardinality()), + new BitSetIterator(clause3, clause3.approximateCardinality()))); + FixedBitSet result = new FixedBitSet(maxDoc); + scorer.score( + new LeafCollector() { + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + result.set(doc); + } + }, + null, + 0, + DocIdSetIterator.NO_MORE_DOCS); + + assertEquals(clause1, result); + } + + public void testApplyAcceptDocs() throws IOException { + int maxDoc = 100_000; + FixedBitSet clause1 = new FixedBitSet(maxDoc); + FixedBitSet clause2 = new FixedBitSet(maxDoc); + clause1.set(0, maxDoc); + clause2.set(0, maxDoc); + FixedBitSet acceptDocs = new FixedBitSet(maxDoc); + for (int i = 0; i < maxDoc; i += 2) { + acceptDocs.set(i); + } + DenseConjunctionBulkScorer scorer = + new DenseConjunctionBulkScorer( + Arrays.asList( + new BitSetIterator(clause1, clause1.approximateCardinality()), + new BitSetIterator(clause2, clause2.approximateCardinality()))); + FixedBitSet result = new FixedBitSet(maxDoc); + scorer.score( + new LeafCollector() { + @Override + public void setScorer(Scorable scorer) throws IOException {} 
+ + @Override + public void collect(int doc) throws IOException { + result.set(doc); + } + }, + acceptDocs, + 0, + DocIdSetIterator.NO_MORE_DOCS); + + assertEquals(acceptDocs, result); + } + + public void testEmptyIntersection() throws IOException { + int maxDoc = 100_000; + FixedBitSet clause1 = new FixedBitSet(maxDoc); + FixedBitSet clause2 = new FixedBitSet(maxDoc); + for (int i = 0; i < maxDoc - 1; i += 2) { + clause1.set(i); + clause2.set(i + 1); + } + DenseConjunctionBulkScorer scorer = + new DenseConjunctionBulkScorer( + Arrays.asList( + new BitSetIterator(clause1, clause1.approximateCardinality()), + new BitSetIterator(clause2, clause2.approximateCardinality()))); + FixedBitSet result = new FixedBitSet(maxDoc); + scorer.score( + new LeafCollector() { + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + result.set(doc); + } + }, + null, + 0, + DocIdSetIterator.NO_MORE_DOCS); + + assertTrue(result.scanIsEmpty()); + } + + public void testClustered() throws IOException { + int maxDoc = 100_000; + FixedBitSet clause1 = new FixedBitSet(maxDoc); + FixedBitSet clause2 = new FixedBitSet(maxDoc); + FixedBitSet clause3 = new FixedBitSet(maxDoc); + clause1.set(10_000, 90_000); + clause2.set(0, 80_000); + clause3.set(20_000, 100_000); + DenseConjunctionBulkScorer scorer = + new DenseConjunctionBulkScorer( + Arrays.asList( + new BitSetIterator(clause1, clause1.approximateCardinality()), + new BitSetIterator(clause2, clause2.approximateCardinality()), + new BitSetIterator(clause3, clause3.approximateCardinality()))); + FixedBitSet result = new FixedBitSet(maxDoc); + scorer.score( + new LeafCollector() { + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + result.set(doc); + } + }, + null, + 0, + DocIdSetIterator.NO_MORE_DOCS); + + FixedBitSet expected = new FixedBitSet(maxDoc); + 
expected.set(20_000, 80_000); + assertArrayEquals(expected.getBits(), result.getBits()); + assertEquals(expected, result); + } + + public void testSparseAfter2ndClause() throws IOException { + int maxDoc = 100_000; + FixedBitSet clause1 = new FixedBitSet(maxDoc); + FixedBitSet clause2 = new FixedBitSet(maxDoc); + FixedBitSet clause3 = new FixedBitSet(maxDoc); + // 13 and 17 are primes, so their only intersection is on multiples of both 13 and 17 + // Likewise, 19 is prime, so the only intersection of the conjunction is on multiples of 13, 17 + // and 19 + for (int i = 0; i < maxDoc; i += 13) { + clause1.set(i); + } + for (int i = 0; i < maxDoc; i += 17) { + clause2.set(i); + } + for (int i = 0; i < maxDoc; i += 19) { + clause3.set(i); + } + DenseConjunctionBulkScorer scorer = + new DenseConjunctionBulkScorer( + Arrays.asList( + new BitSetIterator(clause1, clause1.approximateCardinality()), + new BitSetIterator(clause2, clause2.approximateCardinality()), + new BitSetIterator(clause3, clause3.approximateCardinality()))); + FixedBitSet result = new FixedBitSet(maxDoc); + scorer.score( + new LeafCollector() { + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + result.set(doc); + } + }, + null, + 0, + DocIdSetIterator.NO_MORE_DOCS); + + FixedBitSet expected = new FixedBitSet(maxDoc); + for (int i = 0; i < maxDoc; i += 13 * 17 * 19) { + expected.set(i); + } + assertEquals(expected, result); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java b/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java index 67f0918f46b..ff2bb4fd0d6 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java @@ -616,4 +616,30 @@ public class TestFixedBitSet extends BaseBitSetTestCase { set.set(5); assertTrue(bits.get(5)); } + + public void testScanIsEmpty() { + FixedBitSet 
set = new FixedBitSet(0); + assertTrue(set.scanIsEmpty()); + + set = new FixedBitSet(13); + assertTrue(set.scanIsEmpty()); + set.set(10); + assertFalse(set.scanIsEmpty()); + + set = new FixedBitSet(1024); + assertTrue(set.scanIsEmpty()); + set.set(3); + assertFalse(set.scanIsEmpty()); + set.clear(3); + set.set(1020); + assertFalse(set.scanIsEmpty()); + + set = new FixedBitSet(1030); + assertTrue(set.scanIsEmpty()); + set.set(3); + assertFalse(set.scanIsEmpty()); + set.clear(3); + set.set(1028); + assertFalse(set.scanIsEmpty()); + } }