diff --git a/core/src/main/java/org/elasticsearch/common/cache/Cache.java b/core/src/main/java/org/elasticsearch/common/cache/Cache.java index d2d6970fe9e..f16e7d8d25d 100644 --- a/core/src/main/java/org/elasticsearch/common/cache/Cache.java +++ b/core/src/main/java/org/elasticsearch/common/cache/Cache.java @@ -23,7 +23,10 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.util.concurrent.ReleasableLock; import java.util.*; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.atomic.LongAdder; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantLock; @@ -172,7 +175,8 @@ public class Cache { ReleasableLock readLock = new ReleasableLock(segmentLock.readLock()); ReleasableLock writeLock = new ReleasableLock(segmentLock.writeLock()); - Map> map = new HashMap<>(); + Map>> map = new HashMap<>(); + SegmentStats segmentStats = new SegmentStats(); /** @@ -183,13 +187,19 @@ public class Cache { * @return the entry if there was one, otherwise null */ Entry get(K key, long now) { - Entry entry; + Future> future; + Entry entry = null; try (ReleasableLock ignored = readLock.acquire()) { - entry = map.get(key); + future = map.get(key); } - if (entry != null) { + if (future != null) { segmentStats.hit(); - entry.accessTime = now; + try { + entry = future.get(); + entry.accessTime = now; + } catch (ExecutionException | InterruptedException e) { + throw new IllegalStateException("future should be a completedFuture for which get should not throw", e); + } } else { segmentStats.miss(); } @@ -208,7 +218,12 @@ public class Cache { Entry entry = new Entry<>(key, value, now); Entry existing; try (ReleasableLock ignored = writeLock.acquire()) { - existing = map.put(key, entry); + try { + Future> future = map.put(key, CompletableFuture.completedFuture(entry)); + existing = future != null ? future.get() : null; + } catch (ExecutionException | InterruptedException e) { + throw new IllegalStateException("future should be a completedFuture for which get should not throw", e); + } } return Tuple.tuple(entry, existing); } @@ -220,12 +235,18 @@ public class Cache { * @return the removed entry if there was one, otherwise null */ Entry remove(K key) { - Entry entry; + Future> future; + Entry entry = null; try (ReleasableLock ignored = writeLock.acquire()) { - entry = map.remove(key); + future = map.remove(key); } - if (entry != null) { + if (future != null) { segmentStats.eviction(); + try { + entry = future.get(); + } catch (ExecutionException | InterruptedException e) { + throw new IllegalStateException("future should be a completedFuture for which get should not throw", e); + } } return entry; } @@ -287,7 +308,8 @@ public class Cache { /** * If the specified key is not already associated with a value (or is mapped to null), attempts to compute its - * value using the given mapping function and enters it into this map unless null. + * value using the given mapping function and enters it into this map unless null. The load method for a given key + * will be invoked at most once. * * @param key the key whose associated value is to be returned or computed for if non-existant * @param loader the function to compute a value given a key @@ -299,25 +321,35 @@ public class Cache { long now = now(); V value = get(key, now); if (value == null) { + // we need to synchronize loading of a value for a given key; however, holding the segment lock while + // invoking load can lead to deadlock against another thread due to dependent key loading; therefore, we + // need a mechanism to ensure that load is invoked at most once, but we are not invoking load while holding + // the segment lock; to do this, we atomically put a future in the map that can load the value, and then + // get the value from this future on the thread that won the race to place the future into the segment map CacheSegment segment = getCacheSegment(key); - // we synchronize against the segment lock; this is to avoid a scenario where another thread is inserting - // a value for the same key via put which would not be observed on this thread without a mechanism - // synchronizing the two threads; it is possible that the segment lock will be too expensive here (it blocks - // readers too!) so consider this as a possible place to optimize should contention be observed + Future> future; + FutureTask> task = new FutureTask<>(() -> new Entry<>(key, loader.load(key), now)); try (ReleasableLock ignored = segment.writeLock.acquire()) { - value = get(key, now); - if (value == null) { - try { - value = loader.load(key); - } catch (Exception e) { - throw new ExecutionException(e); - } - if (value == null) { - throw new ExecutionException(new NullPointerException("loader returned a null value")); - } - put(key, value, now); - } + future = segment.map.putIfAbsent(key, task); } + if (future == null) { + future = task; + task.run(); + } + + Entry entry; + try { + entry = future.get(); + } catch (InterruptedException e) { + throw new ExecutionException(e); + } + if (entry.value == null) { + throw new ExecutionException(new NullPointerException("loader returned a null value")); + } + try (ReleasableLock ignored = lruLock.acquire()) { + promote(entry, now); + } + value = entry.value; } return value; } diff --git a/core/src/test/java/org/elasticsearch/common/cache/CacheTests.java b/core/src/test/java/org/elasticsearch/common/cache/CacheTests.java index d1481a5ad5b..4f64f0baca7 100644 --- a/core/src/test/java/org/elasticsearch/common/cache/CacheTests.java +++ b/core/src/test/java/org/elasticsearch/common/cache/CacheTests.java @@ -394,12 +394,12 @@ public class CacheTests extends ESTestCase { // randomly replace some entries, increasing the weight by 1 for each replacement, then count that the cache size // is correct public void testReplaceRecomputesSize() { - class Key { - private int key; + class Value { + private String value; private long weight; - public Key(int key, long weight) { - this.key = key; + public Value(String value, long weight) { + this.value = value; this.weight = weight; } @@ -408,20 +408,20 @@ public class CacheTests extends ESTestCase { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - Key key1 = (Key) o; + Value that = (Value) o; - return key == key1.key; + return value == that.value; } @Override public int hashCode() { - return key; + return value.hashCode(); } } - Cache cache = CacheBuilder.builder().weigher((k, s) -> k.weight).build(); + Cache cache = CacheBuilder.builder().weigher((k, s) -> s.weight).build(); for (int i = 0; i < numberOfEntries; i++) { - cache.put(new Key(i, 1), Integer.toString(i)); + cache.put(i, new Value(Integer.toString(i), 1)); } assertEquals(numberOfEntries, cache.count()); assertEquals(numberOfEntries, cache.weight()); @@ -429,7 +429,7 @@ public class CacheTests extends ESTestCase { for (int i = 0; i < numberOfEntries; i++) { if (rarely()) { replaced++; - cache.put(new Key(i, 2), Integer.toString(i)); + cache.put(i, new Value(Integer.toString(i), 2)); } } assertEquals(numberOfEntries, cache.count());