diff --git a/server/src/main/java/io/druid/client/cache/ByteCountingLRUMap.java b/server/src/main/java/io/druid/client/cache/ByteCountingLRUMap.java index b351dfc6837..ba91b290d16 100644 --- a/server/src/main/java/io/druid/client/cache/ByteCountingLRUMap.java +++ b/server/src/main/java/io/druid/client/cache/ByteCountingLRUMap.java @@ -24,6 +24,7 @@ import com.metamx.common.logger.Logger; import java.nio.ByteBuffer; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; /** */ @@ -74,8 +75,7 @@ class ByteCountingLRUMap extends LinkedHashMap public byte[] put(ByteBuffer key, byte[] value) { numBytes += key.remaining() + value.length; - byte[] retVal = super.put(key, value); - return retVal; + return super.put(key, value); } @Override @@ -98,4 +98,31 @@ class ByteCountingLRUMap extends LinkedHashMap } return false; } + + @Override + public byte[] remove(Object key) + { + byte[] value = super.remove(key); + if(value != null) { + numBytes -= ((ByteBuffer)key).remaining() + value.length; + } + return value; + } + + /** + * We want keySet().iterator().remove() to account for object removal + * The underlying Map calls this.remove(key) so we do not need to override this + */ + @Override + public Set keySet() + { + return super.keySet(); + } + + @Override + public void clear() + { + numBytes = 0; + super.clear(); + } } diff --git a/server/src/test/java/io/druid/client/cache/ByteCountingLRUMapTest.java b/server/src/test/java/io/druid/client/cache/ByteCountingLRUMapTest.java index 5edf6a669e6..4a7143d3efa 100644 --- a/server/src/test/java/io/druid/client/cache/ByteCountingLRUMapTest.java +++ b/server/src/test/java/io/druid/client/cache/ByteCountingLRUMapTest.java @@ -24,6 +24,7 @@ import org.junit.Before; import org.junit.Test; import java.nio.ByteBuffer; +import java.util.Iterator; /** */ @@ -65,6 +66,18 @@ public class ByteCountingLRUMapTest assertMapValues(2, 101, 2); Assert.assertEquals(ByteBuffer.wrap(eightyEightVal), ByteBuffer.wrap(map.get(tenKey))); Assert.assertEquals(oneByte, ByteBuffer.wrap(map.get(twoByte))); + + Iterator it = map.keySet().iterator(); + while(it.hasNext()) { + ByteBuffer buf = it.next(); + if(buf.remaining() == 10) { + it.remove(); + } + } + assertMapValues(1, 3, 2); + + map.remove(twoByte); + assertMapValues(0, 0, 2); } private void assertMapValues(final int size, final int numBytes, final int evictionCount)