Avoid deadlocks in cache (#30461)

This commit avoids deadlocks in the cache by removing dangerous places
where we try to take the LRU lock while completing a future. Instead, we
block for the future to complete, and then execute the handling code
under the LRU lock (for example, eviction).
This commit is contained in:
Jason Tedor 2018-05-09 11:52:38 -04:00 committed by GitHub
parent 143df3a51d
commit 4defaa4f2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 44 deletions

View File

@ -206,34 +206,33 @@ public class Cache<K, V> {
*/
Entry<K, V> get(K key, long now, Predicate<Entry<K, V>> isExpired, Consumer<Entry<K, V>> onExpiration) {
CompletableFuture<Entry<K, V>> future;
Entry<K, V> entry = null;
try (ReleasableLock ignored = readLock.acquire()) {
future = map.get(key);
}
if (future != null) {
Entry<K, V> entry;
try {
entry = future.handle((ok, ex) -> {
if (ok != null && !isExpired.test(ok)) {
segmentStats.hit();
ok.accessTime = now;
return ok;
} else {
segmentStats.miss();
if (ok != null) {
assert isExpired.test(ok);
onExpiration.accept(ok);
}
return null;
}
}).get();
} catch (ExecutionException | InterruptedException e) {
entry = future.get();
} catch (ExecutionException e) {
assert future.isCompletedExceptionally();
segmentStats.miss();
return null;
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
}
else {
if (isExpired.test(entry)) {
segmentStats.miss();
onExpiration.accept(entry);
return null;
} else {
segmentStats.hit();
entry.accessTime = now;
return entry;
}
} else {
segmentStats.miss();
return null;
}
return entry;
}
/**
@ -269,30 +268,18 @@ public class Cache<K, V> {
/**
* remove an entry from the segment
*
* @param key the key of the entry to remove from the cache
* @return the removed entry if there was one, otherwise null
* @param key the key of the entry to remove from the cache
* @param onRemoval a callback for the removed entry
*/
Entry<K, V> remove(K key) {
void remove(K key, Consumer<CompletableFuture<Entry<K, V>>> onRemoval) {
CompletableFuture<Entry<K, V>> future;
Entry<K, V> entry = null;
try (ReleasableLock ignored = writeLock.acquire()) {
future = map.remove(key);
}
if (future != null) {
try {
entry = future.handle((ok, ex) -> {
if (ok != null) {
segmentStats.eviction();
return ok;
} else {
return null;
}
}).get();
} catch (ExecutionException | InterruptedException e) {
throw new IllegalStateException(e);
}
segmentStats.eviction();
onRemoval.accept(future);
}
return entry;
}
private static class SegmentStats {
@ -476,12 +463,18 @@ public class Cache<K, V> {
*/
public void invalidate(K key) {
CacheSegment<K, V> segment = getCacheSegment(key);
Entry<K, V> entry = segment.remove(key);
if (entry != null) {
try (ReleasableLock ignored = lruLock.acquire()) {
delete(entry, RemovalNotification.RemovalReason.INVALIDATED);
segment.remove(key, f -> {
try {
Entry<K, V> entry = f.get();
try (ReleasableLock ignored = lruLock.acquire()) {
delete(entry, RemovalNotification.RemovalReason.INVALIDATED);
}
} catch (ExecutionException e) {
// ok
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
}
});
}
/**
@ -632,7 +625,7 @@ public class Cache<K, V> {
Entry<K, V> entry = current;
if (entry != null) {
CacheSegment<K, V> segment = getCacheSegment(entry.key);
segment.remove(entry.key);
segment.remove(entry.key, f -> {});
try (ReleasableLock ignored = lruLock.acquire()) {
current = null;
delete(entry, RemovalNotification.RemovalReason.INVALIDATED);
@ -717,7 +710,7 @@ public class Cache<K, V> {
CacheSegment<K, V> segment = getCacheSegment(entry.key);
if (segment != null) {
segment.remove(entry.key);
segment.remove(entry.key, f -> {});
}
delete(entry, RemovalNotification.RemovalReason.EVICTED);
}

View File

@ -344,7 +344,6 @@ public class CacheTests extends ESTestCase {
assertEquals(numberOfEntries, cache.stats().getEvictions());
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/30428")
public void testComputeIfAbsentDeadlock() throws BrokenBarrierException, InterruptedException {
final int numberOfThreads = randomIntBetween(2, 32);
final Cache<Integer, String> cache =