LUCENE-9791 Allow calling BytesRefHash#find concurrently (#8)

Removes `scratch1` field in `BytesRefHash` by accessing underlying bytes pool directly
in `equals` method. As a result it is now possible to call `BytesRefHash#find`
concurrently as long as there are no concurrent modifications to BytesRefHash instance
and it is correctly published.

This addresses the concurrency issue with Monitor (aka Luwak) since it
is using `BytesRefHash#find` concurrently without additional synchronization.
This commit is contained in:
pawel-bugalski-dynatrace 2021-03-11 14:06:03 +01:00 committed by GitHub
parent dcb52acd7d
commit 6367cd1b74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 6 deletions

View File

@ -40,9 +40,6 @@ import org.apache.lucene.util.ByteBlockPool.DirectAllocator;
public final class BytesRefHash implements Accountable {
private static final long BASE_RAM_BYTES =
RamUsageEstimator.shallowSizeOfInstance(BytesRefHash.class)
+
// size of scratch1
RamUsageEstimator.shallowSizeOfInstance(BytesRef.class)
+
// size of Counter
RamUsageEstimator.primitiveSizes.get(long.class);
@ -54,7 +51,6 @@ public final class BytesRefHash implements Accountable {
final ByteBlockPool pool;
int[] bytesStart;
private final BytesRef scratch1 = new BytesRef();
private int hashSize;
private int hashHalfSize;
private int hashMask;
@ -174,8 +170,21 @@ public final class BytesRefHash implements Accountable {
}
private boolean equals(int id, BytesRef b) {
pool.setBytesRef(scratch1, bytesStart[id]);
return scratch1.bytesEquals(b);
final int textStart = bytesStart[id];
final byte[] bytes = pool.buffers[textStart >> BYTE_BLOCK_SHIFT];
int pos = textStart & BYTE_BLOCK_MASK;
final int length;
final int offset;
if ((bytes[pos] & 0x80) == 0) {
// length is 1 byte
length = bytes[pos];
offset = pos + 1;
} else {
// length is 2 bytes
length = (bytes[pos] & 0x7f) + ((bytes[pos + 1] & 0xff) << 7);
offset = pos + 2;
}
return Arrays.equals(bytes, offset, offset + length, b.bytes, b.offset, b.offset + b.length);
}
private boolean shrink(int targetSize) {

View File

@ -16,14 +16,18 @@
*/
package org.apache.lucene.util;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.util.BytesRefHash.MaxBytesLengthExceededException;
import org.junit.Before;
import org.junit.Test;
@ -267,6 +271,69 @@ public class TestBytesRefHash extends LuceneTestCase {
}
}
@Test
public void testConcurrentAccessToBytesRefHash() throws Exception {
int num = atLeast(2);
for (int j = 0; j < num; j++) {
int numStrings = 797;
List<String> strings = new ArrayList<>(numStrings);
for (int i = 0; i < numStrings; i++) {
final String str = TestUtil.randomRealisticUnicodeString(random(), 1, 1000);
hash.add(new BytesRef(str));
assertTrue(strings.add(str));
}
int hashSize = hash.size();
AtomicInteger notFound = new AtomicInteger();
AtomicInteger notEquals = new AtomicInteger();
AtomicInteger wrongSize = new AtomicInteger();
int numThreads = atLeast(3);
CountDownLatch latch = new CountDownLatch(numThreads);
Thread[] threads = new Thread[numThreads];
for (int i = 0; i < threads.length; i++) {
int loops = atLeast(100);
threads[i] =
new Thread(
() -> {
BytesRef scratch = new BytesRef();
latch.countDown();
try {
latch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
for (int k = 0; k < loops; k++) {
BytesRef find = new BytesRef(strings.get(k % strings.size()));
int id = hash.find(find);
if (id < 0) {
notFound.incrementAndGet();
} else {
BytesRef get = hash.get(id, scratch);
if (!get.bytesEquals(find)) {
notEquals.incrementAndGet();
}
}
if (hash.size() != hashSize) {
wrongSize.incrementAndGet();
}
}
},
"t" + i);
}
for (Thread t : threads) t.start();
for (Thread t : threads) t.join();
assertEquals(0, notFound.get());
assertEquals(0, notEquals.get());
assertEquals(0, wrongSize.get());
hash.clear();
assertEquals(0, hash.size());
hash.reinit();
}
}
@Test(expected = MaxBytesLengthExceededException.class)
public void testLargeValue() {
int[] sizes =