diff --git a/processing/src/main/java/org/apache/druid/query/CacheStrategy.java b/processing/src/main/java/org/apache/druid/query/CacheStrategy.java index 8b106a6f2b7..f93a3955aa6 100644 --- a/processing/src/main/java/org/apache/druid/query/CacheStrategy.java +++ b/processing/src/main/java/org/apache/druid/query/CacheStrategy.java @@ -22,8 +22,11 @@ package org.apache.druid.query; import com.fasterxml.jackson.core.type.TypeReference; import com.google.common.base.Function; import org.apache.druid.guice.annotations.ExtensionPoint; +import org.apache.druid.query.aggregation.AggregatorFactory; +import java.util.Iterator; import java.util.concurrent.ExecutorService; +import java.util.function.BiFunction; /** */ @@ -98,4 +101,31 @@ public interface CacheStrategy> { return pullFromCache(false); } + + /** + * Helper function used by TopN, GroupBy, Timeseries queries in {@link #pullFromCache(boolean)}. + * When using the result level cache, the agg values seen here are + * finalized values generated by AggregatorFactory.finalizeComputation(). + * These finalized values are deserialized from the cache as generic Objects, which will + * later be reserialized and returned to the user without further modification. + * Because the agg values are deserialized as generic Objects, the values are subject to the same + * type consistency issues handled by DimensionHandlerUtils.convertObjectToType() in the pullFromCache implementations + * for dimension values (e.g., a Float would become Double). + */ + static void fetchAggregatorsFromCache( + Iterator aggIter, + Iterator resultIter, + boolean isResultLevelCache, + BiFunction addToResultFunction + ) + { + while (aggIter.hasNext() && resultIter.hasNext()) { + final AggregatorFactory factory = aggIter.next(); + if (isResultLevelCache) { + addToResultFunction.apply(factory.getName(), resultIter.next()); + } else { + addToResultFunction.apply(factory.getName(), factory.deserialize(resultIter.next())); + } + } + } } diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java index c94d427b1bd..e7d1e27efb4 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java @@ -555,7 +555,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest event = Maps.newLinkedHashMap(); + final Map event = Maps.newLinkedHashMap(); Iterator dimsIter = dims.iterator(); while (dimsIter.hasNext() && results.hasNext()) { final DimensionSpec dimensionSpec = dimsIter.next(); @@ -566,12 +566,18 @@ public class GroupByQueryQueryToolChest extends QueryToolChest aggsIter = aggs.iterator(); - while (aggsIter.hasNext() && results.hasNext()) { - final AggregatorFactory factory = aggsIter.next(); - event.put(factory.getName(), factory.deserialize(results.next())); - } + + CacheStrategy.fetchAggregatorsFromCache( + aggsIter, + results, + isResultLevelCache, + (aggName, aggValueObject) -> { + event.put(aggName, aggValueObject); + return null; + } + ); + if (isResultLevelCache) { Iterator postItr = query.getPostAggregatorSpecs().iterator(); while (postItr.hasNext() && results.hasNext()) { diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java index f8f5aa0c4c7..d625c318ce2 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java @@ -327,17 +327,23 @@ public class TimeseriesQueryQueryToolChest extends QueryToolChest apply(@Nullable Object input) { List results = (List) input; - Map retVal = Maps.newLinkedHashMap(); + final Map retVal = Maps.newLinkedHashMap(); Iterator aggsIter = aggs.iterator(); Iterator resultIter = results.iterator(); DateTime timestamp = granularity.toDateTime(((Number) resultIter.next()).longValue()); - while (aggsIter.hasNext() && resultIter.hasNext()) { - final AggregatorFactory factory = aggsIter.next(); - retVal.put(factory.getName(), factory.deserialize(resultIter.next())); - } + CacheStrategy.fetchAggregatorsFromCache( + aggsIter, + resultIter, + isResultLevelCache, + (aggName, aggValueObject) -> { + retVal.put(aggName, aggValueObject); + return null; + } + ); + if (isResultLevelCache) { Iterator postItr = query.getPostAggregatorSpecs().iterator(); while (postItr.hasNext() && resultIter.hasNext()) { diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java index 2c3bd2b2f75..d87a178dc07 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java @@ -398,7 +398,7 @@ public class TopNQueryQueryToolChest extends QueryToolChest result = (List) inputIter.next(); - Map vals = Maps.newLinkedHashMap(); + final Map vals = Maps.newLinkedHashMap(); Iterator aggIter = aggs.iterator(); Iterator resultIter = result.iterator(); @@ -409,10 +409,15 @@ public class TopNQueryQueryToolChest extends QueryToolChest { + vals.put(aggName, aggValueObject); + return null; + } + ); for (PostAggregator postAgg : postAggs) { vals.put(postAgg.getName(), postAgg.compute(vals)); diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java index 2bad8f82f7f..94842e08925 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java @@ -19,15 +19,26 @@ package org.apache.druid.query.groupby; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import org.apache.druid.collections.SerializablePair; +import org.apache.druid.data.input.MapBasedRow; import org.apache.druid.data.input.Row; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.QueryRunnerTestHelper; +import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FloatSumAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.aggregation.SerializablePairLongString; +import org.apache.druid.query.aggregation.last.DoubleLastAggregatorFactory; +import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory; +import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory; +import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; +import org.apache.druid.query.aggregation.post.ConstantPostAggregator; import org.apache.druid.query.aggregation.post.ExpressionPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.expression.TestExprMacroTable; @@ -46,10 +57,14 @@ import org.apache.druid.query.groupby.having.OrHavingSpec; import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ValueType; import org.junit.Assert; import org.junit.Test; +import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class GroupByQueryQueryToolChestTest @@ -483,4 +498,143 @@ public class GroupByQueryQueryToolChestTest )); } + @Test + public void testCacheStrategy() throws Exception + { + doTestCacheStrategy(ValueType.STRING, "val1"); + doTestCacheStrategy(ValueType.FLOAT, 2.1f); + doTestCacheStrategy(ValueType.DOUBLE, 2.1d); + doTestCacheStrategy(ValueType.LONG, 2L); + } + + private AggregatorFactory getComplexAggregatorFactoryForValueType(final ValueType valueType) + { + switch (valueType) { + case LONG: + return new LongLastAggregatorFactory("complexMetric", "test"); + case DOUBLE: + return new DoubleLastAggregatorFactory("complexMetric", "test"); + case FLOAT: + return new FloatLastAggregatorFactory("complexMetric", "test"); + case STRING: + return new StringLastAggregatorFactory("complexMetric", "test", null); + default: + throw new IllegalArgumentException("bad valueType: " + valueType); + } + } + + private SerializablePair getIntermediateComplexValue(final ValueType valueType, final Object dimValue) + { + switch (valueType) { + case LONG: + case DOUBLE: + case FLOAT: + return new SerializablePair<>(123L, dimValue); + case STRING: + return new SerializablePairLongString(123L, (String) dimValue); + default: + throw new IllegalArgumentException("bad valueType: " + valueType); + } + } + + private void doTestCacheStrategy(final ValueType valueType, final Object dimValue) throws IOException + { + final GroupByQuery query1 = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird) + .setDimensions(Collections.singletonList( + new DefaultDimensionSpec("test", "test", valueType) + )) + .setAggregatorSpecs( + Arrays.asList( + QueryRunnerTestHelper.rowsCount, + getComplexAggregatorFactoryForValueType(valueType) + ) + ) + .setPostAggregatorSpecs( + ImmutableList.of(new ConstantPostAggregator("post", 10)) + ) + .setGranularity(QueryRunnerTestHelper.dayGran) + .build(); + + CacheStrategy strategy = + new GroupByQueryQueryToolChest(null, null).getCacheStrategy( + query1 + ); + + final Row result1 = new MapBasedRow( + // test timestamps that result in integer size millis + DateTimes.utc(123L), + ImmutableMap.of( + "test", dimValue, + "rows", 1, + "complexMetric", getIntermediateComplexValue(valueType, dimValue) + ) + ); + + Object preparedValue = strategy.prepareForSegmentLevelCache().apply( + result1 + ); + + ObjectMapper objectMapper = TestHelper.makeJsonMapper(); + Object fromCacheValue = objectMapper.readValue( + objectMapper.writeValueAsBytes(preparedValue), + strategy.getCacheObjectClazz() + ); + + Row fromCacheResult = strategy.pullFromSegmentLevelCache().apply(fromCacheValue); + + Assert.assertEquals(result1, fromCacheResult); + + final Row result2 = new MapBasedRow( + // test timestamps that result in integer size millis + DateTimes.utc(123L), + ImmutableMap.of( + "test", dimValue, + "rows", 1, + "complexMetric", dimValue, + "post", 10 + ) + ); + + // Please see the comments on aggregator serde and type handling in CacheStrategy.fetchAggregatorsFromCache() + final Row typeAdjustedResult2; + if (valueType == ValueType.FLOAT) { + typeAdjustedResult2 = new MapBasedRow( + DateTimes.utc(123L), + ImmutableMap.of( + "test", dimValue, + "rows", 1, + "complexMetric", 2.1d, + "post", 10 + ) + ); + } else if (valueType == ValueType.LONG) { + typeAdjustedResult2 = new MapBasedRow( + DateTimes.utc(123L), + ImmutableMap.of( + "test", dimValue, + "rows", 1, + "complexMetric", 2, + "post", 10 + ) + ); + } else { + typeAdjustedResult2 = result2; + } + + + Object preparedResultCacheValue = strategy.prepareForCache(true).apply( + result2 + ); + + Object fromResultCacheValue = objectMapper.readValue( + objectMapper.writeValueAsBytes(preparedResultCacheValue), + strategy.getCacheObjectClazz() + ); + + Row fromResultCacheResult = strategy.pullFromCache(true).apply(fromResultCacheValue); + Assert.assertEquals(typeAdjustedResult2, fromResultCacheResult); + } } diff --git a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChestTest.java b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChestTest.java index 6d07c59d269..304ac9b5ff1 100644 --- a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChestTest.java +++ b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChestTest.java @@ -32,6 +32,8 @@ import org.apache.druid.query.Result; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.aggregation.SerializablePairLongString; +import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.ConstantPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; @@ -77,7 +79,8 @@ public class TimeseriesQueryQueryToolChestTest Granularities.ALL, ImmutableList.of( new CountAggregatorFactory("metric1"), - new LongSumAggregatorFactory("metric0", "metric0") + new LongSumAggregatorFactory("metric0", "metric0"), + new StringLastAggregatorFactory("complexMetric", "test", null) ), ImmutableList.of(new ConstantPostAggregator("post", 10)), 0, @@ -89,7 +92,11 @@ public class TimeseriesQueryQueryToolChestTest // test timestamps that result in integer size millis DateTimes.utc(123L), new TimeseriesResultValue( - ImmutableMap.of("metric1", 2, "metric0", 3) + ImmutableMap.of( + "metric1", 2, + "metric0", 3, + "complexMetric", new SerializablePairLongString(123L, "val1") + ) ) ); @@ -109,7 +116,12 @@ public class TimeseriesQueryQueryToolChestTest // test timestamps that result in integer size millis DateTimes.utc(123L), new TimeseriesResultValue( - ImmutableMap.of("metric1", 2, "metric0", 3, "post", 10) + ImmutableMap.of( + "metric1", 2, + "metric0", 3, + "complexMetric", "val1", + "post", 10 + ) ) ); diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java index 81079df3c14..191cc557f09 100644 --- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java +++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.druid.collections.CloseableStupidPool; +import org.apache.druid.collections.SerializablePair; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.granularity.Granularities; @@ -35,8 +36,14 @@ import org.apache.druid.query.QueryRunnerTestHelper; import org.apache.druid.query.Result; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.TestQueryRunners; +import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.aggregation.SerializablePairLongString; +import org.apache.druid.query.aggregation.last.DoubleLastAggregatorFactory; +import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory; +import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory; +import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.ConstantPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; @@ -269,6 +276,36 @@ public class TopNQueryQueryToolChestTest } } + private AggregatorFactory getComplexAggregatorFactoryForValueType(final ValueType valueType) + { + switch (valueType) { + case LONG: + return new LongLastAggregatorFactory("complexMetric", "test"); + case DOUBLE: + return new DoubleLastAggregatorFactory("complexMetric", "test"); + case FLOAT: + return new FloatLastAggregatorFactory("complexMetric", "test"); + case STRING: + return new StringLastAggregatorFactory("complexMetric", "test", null); + default: + throw new IllegalArgumentException("bad valueType: " + valueType); + } + } + + private SerializablePair getIntermediateComplexValue(final ValueType valueType, final Object dimValue) + { + switch (valueType) { + case LONG: + case DOUBLE: + case FLOAT: + return new SerializablePair<>(123L, dimValue); + case STRING: + return new SerializablePairLongString(123L, (String) dimValue); + default: + throw new IllegalArgumentException("bad valueType: " + valueType); + } + } + private void doTestCacheStrategy(final ValueType valueType, final Object dimValue) throws IOException { CacheStrategy, Object, TopNQuery> strategy = @@ -282,7 +319,10 @@ public class TopNQueryQueryToolChestTest new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("2015-01-01/2015-01-02"))), null, Granularities.ALL, - ImmutableList.of(new CountAggregatorFactory("metric1")), + ImmutableList.of( + new CountAggregatorFactory("metric1"), + getComplexAggregatorFactoryForValueType(valueType) + ), ImmutableList.of(new ConstantPostAggregator("post", 10)), null ) @@ -295,7 +335,8 @@ public class TopNQueryQueryToolChestTest Collections.singletonList( ImmutableMap.of( "test", dimValue, - "metric1", 2 + "metric1", 2, + "complexMetric", getIntermediateComplexValue(valueType, dimValue) ) ) ) @@ -323,12 +364,48 @@ public class TopNQueryQueryToolChestTest ImmutableMap.of( "test", dimValue, "metric1", 2, + "complexMetric", dimValue, "post", 10 ) ) ) ); + // Please see the comments on aggregator serde and type handling in CacheStrategy.fetchAggregatorsFromCache() + final Result typeAdjustedResult2; + if (valueType == ValueType.FLOAT) { + typeAdjustedResult2 = new Result<>( + DateTimes.utc(123L), + new TopNResultValue( + Collections.singletonList( + ImmutableMap.of( + "test", dimValue, + "metric1", 2, + "complexMetric", 2.1d, + "post", 10 + ) + ) + ) + ); + } else if (valueType == ValueType.LONG) { + typeAdjustedResult2 = new Result<>( + DateTimes.utc(123L), + new TopNResultValue( + Collections.singletonList( + ImmutableMap.of( + "test", dimValue, + "metric1", 2, + "complexMetric", 2, + "post", 10 + ) + ) + ) + ); + } else { + typeAdjustedResult2 = result2; + } + + Object preparedResultCacheValue = strategy.prepareForCache(true).apply( result2 ); @@ -339,7 +416,7 @@ public class TopNQueryQueryToolChestTest ); Result fromResultCacheResult = strategy.pullFromCache(true).apply(fromResultCacheValue); - Assert.assertEquals(result2, fromResultCacheResult); + Assert.assertEquals(typeAdjustedResult2, fromResultCacheResult); } static class MockQueryRunner implements QueryRunner>