Use radix sort to speed up the sorting of deleted terms (#12573)

gf2121 2023-09-22 00:33:01 -05:00 committed by GitHub
parent fb1f4dd412
commit 8b84f6c096
7 changed files with 246 additions and 61 deletions
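
The core idea of the change, as reflected in the diff below: instead of buffering deleted terms in a HashMap<Term, Integer> and comparison-sorting Term objects at flush time, BufferedUpdates now stores each field's term bytes in a BytesRefHash plus a parallel int[] of docIDUpto values, and BytesRefHash.sort() orders the bytes with Lucene's radix-style string sorter. A minimal sketch of that building block follows, using the same utility APIs the diff itself uses; the class name DeletedTermsSortSketch and the sample ids are illustrative and not part of the commit.

import org.apache.lucene.util.ByteBlockPool;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.Counter;

public class DeletedTermsSortSketch {
  public static void main(String[] args) {
    Counter bytesUsed = Counter.newCounter();
    ByteBlockPool pool =
        new ByteBlockPool(new ByteBlockPool.DirectTrackingAllocator(bytesUsed));
    BytesRefHash hash =
        new BytesRefHash(
            pool,
            BytesRefHash.DEFAULT_CAPACITY,
            new BytesRefHash.DirectBytesStartArray(BytesRefHash.DEFAULT_CAPACITY, bytesUsed));

    // Buffer a few deleted-term byte sequences; add() returns the term's id,
    // or -(id + 1) if the bytes were already present.
    for (String id : new String[] {"doc-42", "doc-7", "doc-13"}) {
      hash.add(new BytesRef(id));
    }

    // sort() puts the buffered bytes in order and returns the ids sorted by
    // their bytes; trailing slots are -1 and must be skipped. The call is
    // destructive: the hash can no longer be used for lookups afterwards.
    BytesRef scratch = new BytesRef();
    for (int id : hash.sort()) {
      if (id != -1) {
        hash.get(id, scratch);
        System.out.println(scratch.utf8ToString());
      }
    }
  }
}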

View File

@@ -150,6 +150,8 @@ Optimizations
---------------------
* GITHUB#12183: Make TermStates#build concurrent. (Shubham Chaudhary)
* GITHUB#12573: Use radix sort to speed up the sorting of deleted terms. (Guo Feng)
Changes in runtime behavior
---------------------

View File

@@ -16,13 +16,22 @@
*/
package org.apache.lucene.index;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.lucene.index.DocValuesUpdate.BinaryDocValuesUpdate;
import org.apache.lucene.index.DocValuesUpdate.NumericDocValuesUpdate;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ByteBlockPool;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.Counter;
import org.apache.lucene.util.RamUsageEstimator;
@@ -39,21 +48,6 @@ import org.apache.lucene.util.RamUsageEstimator;
class BufferedUpdates implements Accountable {
/* Rough logic: HashMap has an array[Entry] w/ varying
load factor (say 2 * POINTER). Entry is object w/ Term
key, Integer val, int hash, Entry next
(OBJ_HEADER + 3*POINTER + INT). Term is object w/
String field and String text (OBJ_HEADER + 2*POINTER).
Term's field is String (OBJ_HEADER + 4*INT + POINTER +
OBJ_HEADER + string.length*CHAR).
Term's text is String (OBJ_HEADER + 4*INT + POINTER +
OBJ_HEADER + string.length*CHAR). Integer is
OBJ_HEADER + INT. */
static final int BYTES_PER_DEL_TERM =
9 * RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ 7 * RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ 10 * Integer.BYTES;
/* Rough logic: HashMap has an array[Entry] w/ varying
load factor (say 2 * POINTER). Entry is object w/
Query key, Integer val, int hash, Entry next
@@ -67,8 +61,7 @@ class BufferedUpdates implements Accountable {
final AtomicInteger numTermDeletes = new AtomicInteger();
final AtomicInteger numFieldUpdates = new AtomicInteger();
final Map<Term, Integer> deleteTerms =
new HashMap<>(); // TODO cut this over to FieldUpdatesBuffer
final DeletedTerms deleteTerms = new DeletedTerms();
final Map<Query, Integer> deleteQueries = new HashMap<>();
final Map<String, FieldUpdatesBuffer> fieldUpdates = new HashMap<>();
@@ -77,7 +70,6 @@ class BufferedUpdates implements Accountable {
private final Counter bytesUsed = Counter.newCounter(true);
final Counter fieldUpdatesBytesUsed = Counter.newCounter(true);
private final Counter termsBytesUsed = Counter.newCounter(true);
private static final boolean VERBOSE_DELETES = false;
@@ -127,8 +119,8 @@ class BufferedUpdates implements Accountable {
}
public void addTerm(Term term, int docIDUpto) {
Integer current = deleteTerms.get(term);
if (current != null && docIDUpto < current) {
int current = deleteTerms.get(term);
if (current != -1 && docIDUpto < current) {
// Only record the new number if it's greater than the
// current one. This is important because if multiple
// threads are replacing the same doc at nearly the
@@ -139,15 +131,11 @@ class BufferedUpdates implements Accountable {
return;
}
deleteTerms.put(term, Integer.valueOf(docIDUpto));
// note that if current != null then it means there's already a buffered
deleteTerms.put(term, docIDUpto);
// note that if current != -1 then it means there's already a buffered
// delete on that term, therefore we seem to over-count. this over-counting
// is done to respect IndexWriterConfig.setMaxBufferedDeleteTerms.
numTermDeletes.incrementAndGet();
if (current == null) {
termsBytesUsed.addAndGet(
BYTES_PER_DEL_TERM + term.bytes.length + (Character.BYTES * term.field().length()));
}
}
void addNumericUpdate(NumericDocValuesUpdate update, int docIDUpto) {
@@ -176,7 +164,6 @@ class BufferedUpdates implements Accountable {
void clearDeleteTerms() {
numTermDeletes.set(0);
termsBytesUsed.addAndGet(-termsBytesUsed.get());
deleteTerms.clear();
}
@@ -188,7 +175,6 @@ class BufferedUpdates implements Accountable {
fieldUpdates.clear();
bytesUsed.addAndGet(-bytesUsed.get());
fieldUpdatesBytesUsed.addAndGet(-fieldUpdatesBytesUsed.get());
termsBytesUsed.addAndGet(-termsBytesUsed.get());
}
boolean any() {
@@ -197,6 +183,164 @@ class BufferedUpdates implements Accountable {
@Override
public long ramBytesUsed() {
return bytesUsed.get() + fieldUpdatesBytesUsed.get() + termsBytesUsed.get();
return bytesUsed.get() + fieldUpdatesBytesUsed.get() + deleteTerms.ramBytesUsed();
}
static class DeletedTerms implements Accountable {
private final Counter bytesUsed = Counter.newCounter();
private final ByteBlockPool pool =
new ByteBlockPool(new ByteBlockPool.DirectTrackingAllocator(bytesUsed));
private final Map<String, BytesRefIntMap> deleteTerms = new HashMap<>();
private int termsSize = 0;
DeletedTerms() {}
/**
* Get the newest doc id of the deleted term.
*
* @param term The deleted term.
* @return The newest doc id of this deleted term.
*/
int get(Term term) {
BytesRefIntMap hash = deleteTerms.get(term.field);
if (hash == null) {
return -1;
}
return hash.get(term.bytes);
}
/**
* Put the newest doc id of the deleted term.
*
* @param term The deleted term.
* @param value The newest doc id of the deleted term.
*/
void put(Term term, int value) {
BytesRefIntMap hash =
deleteTerms.computeIfAbsent(
term.field,
k -> {
bytesUsed.addAndGet(RamUsageEstimator.sizeOf(term.field));
return new BytesRefIntMap(pool, bytesUsed);
});
if (hash.put(term.bytes, value)) {
termsSize++;
}
}
void clear() {
bytesUsed.addAndGet(-bytesUsed.get());
deleteTerms.clear();
termsSize = 0;
}
int size() {
return termsSize;
}
boolean isEmpty() {
return termsSize == 0;
}
/** Just for test, not efficient. */
Set<Term> keySet() {
return deleteTerms.entrySet().stream()
.flatMap(
entry -> entry.getValue().keySet().stream().map(b -> new Term(entry.getKey(), b)))
.collect(Collectors.toSet());
}
interface DeletedTermConsumer<E extends Exception> {
void accept(Term term, int docId) throws E;
}
/**
* Consume all terms in a sorted order.
*
* <p>Note: This is a destructive operation as it calls {@link BytesRefHash#sort()}.
*
* @see BytesRefHash#sort
*/
<E extends Exception> void forEachOrdered(DeletedTermConsumer<E> consumer) throws E {
List<Map.Entry<String, BytesRefIntMap>> deleteFields =
new ArrayList<>(deleteTerms.entrySet());
deleteFields.sort(Map.Entry.comparingByKey());
Term scratch = new Term("", new BytesRef());
for (Map.Entry<String, BufferedUpdates.BytesRefIntMap> deleteFieldEntry : deleteFields) {
scratch.field = deleteFieldEntry.getKey();
BufferedUpdates.BytesRefIntMap terms = deleteFieldEntry.getValue();
int[] indices = terms.bytesRefHash.sort();
for (int index : indices) {
if (index != -1) {
terms.bytesRefHash.get(index, scratch.bytes);
consumer.accept(scratch, terms.values[index]);
}
}
}
}
@Override
public long ramBytesUsed() {
return bytesUsed.get();
}
}
private static class BytesRefIntMap {
private static final long INIT_RAM_BYTES =
RamUsageEstimator.shallowSizeOf(BytesRefIntMap.class)
+ RamUsageEstimator.shallowSizeOf(BytesRefHash.class)
+ RamUsageEstimator.sizeOf(new int[BytesRefHash.DEFAULT_CAPACITY]);
private final Counter counter;
private final BytesRefHash bytesRefHash;
private int[] values;
private BytesRefIntMap(ByteBlockPool pool, Counter counter) {
this.counter = counter;
this.bytesRefHash =
new BytesRefHash(
pool,
BytesRefHash.DEFAULT_CAPACITY,
new BytesRefHash.DirectBytesStartArray(BytesRefHash.DEFAULT_CAPACITY, counter));
this.values = new int[BytesRefHash.DEFAULT_CAPACITY];
counter.addAndGet(INIT_RAM_BYTES);
}
private Set<BytesRef> keySet() {
BytesRef scratch = new BytesRef();
Set<BytesRef> set = new HashSet<>();
for (int i = 0; i < bytesRefHash.size(); i++) {
bytesRefHash.get(i, scratch);
set.add(BytesRef.deepCopyOf(scratch));
}
return set;
}
private boolean put(BytesRef key, int value) {
assert value >= 0;
int e = bytesRefHash.add(key);
if (e < 0) {
values[-e - 1] = value;
return false;
} else {
if (e >= values.length) {
int originLength = values.length;
values = ArrayUtil.grow(values, e + 1);
counter.addAndGet((long) (values.length - originLength) * Integer.BYTES);
}
values[e] = value;
return true;
}
}
private int get(BytesRef key) {
int e = bytesRefHash.find(key);
if (e == -1) {
return -1;
}
return values[e];
}
}
}
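
For reference, a minimal usage sketch of the new DeletedTerms buffer (not part of the commit). The demo class is hypothetical and assumed to sit in org.apache.lucene.index, since both BufferedUpdates and DeletedTerms are package-private; the field names and doc ids are made up.

package org.apache.lucene.index;

import org.apache.lucene.util.BytesRef;

class DeletedTermsUsageSketch {
  static void demo() {
    BufferedUpdates.DeletedTerms deletes = new BufferedUpdates.DeletedTerms();
    deletes.put(new Term("id", new BytesRef("9")), 3);
    deletes.put(new Term("id", new BytesRef("2")), 5);
    deletes.put(new Term("title", new BytesRef("b")), 7);

    // get() returns the buffered docIDUpto, or -1 for a term that was never added.
    int upTo = deletes.get(new Term("id", new BytesRef("2")));
    System.out.println("id:2 -> " + upTo); // prints 5

    // forEachOrdered() sorts each field's BytesRefHash (a destructive step) and
    // visits the terms grouped by field name, in byte order within each field.
    // The Term handed to the consumer is a reused scratch object, so copy it if
    // it must outlive the callback.
    deletes.forEachOrdered(
        (term, docIdUpto) ->
            System.out.println(
                term.field() + ":" + term.bytes().utf8ToString() + " -> " + docIdUpto));
  }
}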

View File

@@ -18,7 +18,6 @@ package org.apache.lucene.index;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.FieldsConsumer;
@@ -55,29 +54,29 @@ final class FreqProxTermsWriter extends TermsHash {
// Process any pending Term deletes for this newly
// flushed segment:
if (state.segUpdates != null && state.segUpdates.deleteTerms.size() > 0) {
Map<Term, Integer> segDeletes = state.segUpdates.deleteTerms;
List<Term> deleteTerms = new ArrayList<>(segDeletes.keySet());
Collections.sort(deleteTerms);
BufferedUpdates.DeletedTerms segDeletes = state.segUpdates.deleteTerms;
FrozenBufferedUpdates.TermDocsIterator iterator =
new FrozenBufferedUpdates.TermDocsIterator(fields, true);
for (Term deleteTerm : deleteTerms) {
DocIdSetIterator postings = iterator.nextTerm(deleteTerm.field(), deleteTerm.bytes());
if (postings != null) {
int delDocLimit = segDeletes.get(deleteTerm);
assert delDocLimit < PostingsEnum.NO_MORE_DOCS;
int doc;
while ((doc = postings.nextDoc()) < delDocLimit) {
if (state.liveDocs == null) {
state.liveDocs = new FixedBitSet(state.segmentInfo.maxDoc());
state.liveDocs.set(0, state.segmentInfo.maxDoc());
segDeletes.forEachOrdered(
(term, docId) -> {
DocIdSetIterator postings = iterator.nextTerm(term.field(), term.bytes());
if (postings != null) {
assert docId < PostingsEnum.NO_MORE_DOCS;
int doc;
while ((doc = postings.nextDoc()) < docId) {
if (state.liveDocs == null) {
state.liveDocs = new FixedBitSet(state.segmentInfo.maxDoc());
state.liveDocs.set(0, state.segmentInfo.maxDoc());
}
if (state.liveDocs.get(doc)) {
state.delCountOnFlush++;
state.liveDocs.clear(doc);
}
}
}
if (state.liveDocs.get(doc)) {
state.delCountOnFlush++;
state.liveDocs.clear(doc);
}
}
}
}
});
}
}

View File

@@ -31,7 +31,6 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InfoStream;
@@ -86,12 +85,9 @@ final class FrozenBufferedUpdates {
this.privateSegment = privateSegment;
assert privateSegment == null || updates.deleteTerms.isEmpty()
: "segment private packet should only have del queries";
Term[] termsArray = updates.deleteTerms.keySet().toArray(new Term[updates.deleteTerms.size()]);
ArrayUtil.timSort(termsArray);
PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
for (Term term : termsArray) {
builder.add(term);
}
updates.deleteTerms.forEachOrdered((term, doc) -> builder.add(term));
deleteTerms = builder.finish();
deleteQueries = new Query[updates.deleteQueries.size()];

View File

@@ -16,8 +16,13 @@
*/
package org.apache.lucene.index;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
/** Unit test for {@link BufferedUpdates} */
public class TestBufferedUpdates extends LuceneTestCase {
@@ -28,14 +33,14 @@ public class TestBufferedUpdates extends LuceneTestCase {
assertFalse(bu.any());
int queries = atLeast(1);
for (int i = 0; i < queries; i++) {
final int docIDUpto = random().nextBoolean() ? Integer.MAX_VALUE : random().nextInt();
final int docIDUpto = random().nextBoolean() ? Integer.MAX_VALUE : random().nextInt(100000);
final Term term = new Term("id", Integer.toString(random().nextInt(100)));
bu.addQuery(new TermQuery(term), docIDUpto);
}
int terms = atLeast(1);
for (int i = 0; i < terms; i++) {
final int docIDUpto = random().nextBoolean() ? Integer.MAX_VALUE : random().nextInt();
final int docIDUpto = random().nextBoolean() ? Integer.MAX_VALUE : random().nextInt(100000);
final Term term = new Term("id", Integer.toString(random().nextInt(100)));
bu.addTerm(term, docIDUpto);
}
@@ -52,4 +57,44 @@ public class TestBufferedUpdates extends LuceneTestCase {
assertFalse(bu.any());
assertEquals(bu.ramBytesUsed(), 0L);
}
public void testDeletedTerms() {
int iters = atLeast(10);
String[] fields = new String[] {"a", "b", "c"};
for (int iter = 0; iter < iters; iter++) {
Map<Term, Integer> expected = new HashMap<>();
BufferedUpdates.DeletedTerms actual = new BufferedUpdates.DeletedTerms();
assertTrue(actual.isEmpty());
int termCount = atLeast(5000);
int maxBytesNum = random().nextInt(3) + 1;
for (int i = 0; i < termCount; i++) {
int byteNum = random().nextInt(maxBytesNum) + 1;
byte[] bytes = new byte[byteNum];
random().nextBytes(bytes);
Term term = new Term(fields[random().nextInt(fields.length)], new BytesRef(bytes));
int value = random().nextInt(10000000);
expected.put(term, value);
actual.put(term, value);
}
assertEquals(expected.size(), actual.size());
for (Map.Entry<Term, Integer> entry : expected.entrySet()) {
assertEquals(entry.getValue(), Integer.valueOf(actual.get(entry.getKey())));
}
List<Map.Entry<Term, Integer>> expectedSorted =
expected.entrySet().stream().sorted(Map.Entry.comparingByKey()).toList();
List<Map.Entry<Term, Integer>> actualSorted = new ArrayList<>();
actual.forEachOrdered(
((term, docId) -> {
Term copy = new Term(term.field, BytesRef.deepCopyOf(term.bytes));
actualSorted.add(Map.entry(copy, docId));
}));
assertEquals(expectedSorted, actualSorted);
}
}
}

View File

@@ -84,8 +84,7 @@ public class TestDocumentsWriterDeleteQueue extends LuceneTestCase {
private void assertAllBetween(int start, int end, BufferedUpdates deletes, Integer[] ids) {
for (int i = start; i <= end; i++) {
assertEquals(
Integer.valueOf(end), deletes.deleteTerms.get(new Term("id", ids[i].toString())));
assertEquals(end, deletes.deleteTerms.get(new Term("id", ids[i].toString())));
}
}

View File

@@ -1158,7 +1158,7 @@ public class TestIndexWriterDelete extends LuceneTestCase {
new IndexWriter(
dir,
newIndexWriterConfig(new MockAnalyzer(random()))
.setRAMBufferSizeMB(0.1f)
.setRAMBufferSizeMB(0.5f)
.setMaxBufferedDocs(1000)
.setMergePolicy(NoMergePolicy.INSTANCE)
.setReaderPooling(false));