diff --git a/client/src/main/java/com/metamx/druid/client/CachingClusteredClient.java b/client/src/main/java/com/metamx/druid/client/CachingClusteredClient.java index 625aef4d0e1..d69f91e6e14 100644 --- a/client/src/main/java/com/metamx/druid/client/CachingClusteredClient.java +++ b/client/src/main/java/com/metamx/druid/client/CachingClusteredClient.java @@ -20,8 +20,6 @@ package com.metamx.druid.client; import com.google.common.base.Function; -import com.google.common.base.Predicate; -import com.google.common.base.Predicates; import com.google.common.base.Supplier; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; @@ -30,6 +28,7 @@ import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Ordering; +import com.google.common.collect.Sets; import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.metamx.common.ISE; import com.metamx.common.Pair; @@ -66,6 +65,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Executors; /** @@ -134,9 +134,8 @@ public class CachingClusteredClient implements QueryRunner return Sequences.empty(); } - Map, CacheBroker.NamedKey> segments = Maps.newLinkedHashMap(); - - final byte[] queryCacheKey = (strategy != null) ? strategy.computeCacheKey(query) : null; + // build set of segments to query + Set> segments = Sets.newLinkedHashSet(); for (Interval interval : rewrittenQuery.getIntervals()) { List> serversLookup = timeline.lookup(interval); @@ -148,43 +147,44 @@ public class CachingClusteredClient implements QueryRunner holder.getInterval(), holder.getVersion(), chunk.getChunkNumber() ); - segments.put( - Pair.of(selector, descriptor), - queryCacheKey == null ? null : - computeSegmentCacheKey(selector.getSegment().getIdentifier(), descriptor, queryCacheKey) - ); + segments.add(Pair.of(selector, descriptor)); } } } - Map cachedValues = cacheBroker.getBulk( - Iterables.filter(segments.values(), Predicates.notNull()) - ); + final byte[] queryCacheKey; + if(strategy != null) { + queryCacheKey = strategy.computeCacheKey(query); + } else { + queryCacheKey = null; + } - for(Map.Entry, CacheBroker.NamedKey> entry : segments.entrySet()) { - Pair segment = entry.getKey(); - CacheBroker.NamedKey segmentCacheKey = entry.getValue(); + // Pull cached segments from cache and remove from set of segments to query + if(useCache && queryCacheKey != null) { + Map, CacheBroker.NamedKey> cacheKeys = Maps.newHashMap(); + for(Pair e : segments) { + cacheKeys.put(e, computeSegmentCacheKey(e.lhs.getSegment().getIdentifier(), e.rhs, queryCacheKey)); + } - final ServerSelector selector = segment.lhs; - final SegmentDescriptor descriptor = segment.rhs; - final Interval segmentQueryInterval = descriptor.getInterval(); + Map cachedValues = cacheBroker.getBulk(cacheKeys.values()); - final byte[] cachedValue = segmentCacheKey == null ? null : cachedValues.get(segmentCacheKey); + for(Map.Entry, CacheBroker.NamedKey> entry : cacheKeys.entrySet()) { + Pair segment = entry.getKey(); + CacheBroker.NamedKey segmentCacheKey = entry.getValue(); - if (useCache && cachedValue != null) { - cachedResults.add(Pair.of(segmentQueryInterval.getStart(), cachedValue)); - } else { - final DruidServer server = selector.pick(); - List descriptors = serverSegments.get(server); + final ServerSelector selector = segment.lhs; + final SegmentDescriptor descriptor = segment.rhs; + final Interval segmentQueryInterval = descriptor.getInterval(); - if (descriptors == null) { - descriptors = Lists.newArrayList(); - serverSegments.put(server, descriptors); + final byte[] cachedValue = cachedValues.get(segmentCacheKey); + + if (cachedValue != null) { + cachedResults.add(Pair.of(segmentQueryInterval.getStart(), cachedValue)); + + // remove cached segment from set of segments to query + segments.remove(segment); } - - descriptors.add(descriptor); - - if(segmentCacheKey != null) { + else { final String segmentIdentifier = selector.getSegment().getIdentifier(); cachePopulatorMap.put( String.format("%s_%s", segmentIdentifier, segmentQueryInterval), @@ -194,6 +194,19 @@ public class CachingClusteredClient implements QueryRunner } } + // Compile list of all segments not pulled from cache + for(Pair segment : segments) { + final DruidServer server = segment.lhs.pick(); + List descriptors = serverSegments.get(server); + + if (descriptors == null) { + descriptors = Lists.newArrayList(); + serverSegments.put(server, descriptors); + } + + descriptors.add(segment.rhs); + } + return new LazySequence( new Supplier>()