diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java index d87a178dc07..53af0908c6d 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java @@ -294,8 +294,7 @@ public class TopNQueryQueryToolChest extends QueryToolChest aggs = Lists.newArrayList(query.getAggregatorSpecs()); private final List postAggs = AggregatorUtil.pruneDependentPostAgg( query.getPostAggregatorSpecs(), - query.getTopNMetricSpec() - .getMetricName(query.getDimensionSpec()) + query.getTopNMetricSpec().getMetricName(query.getDimensionSpec()) ); @Override @@ -419,14 +418,15 @@ public class TopNQueryQueryToolChest extends QueryToolChest postItr = query.getPostAggregatorSpecs().iterator(); while (postItr.hasNext() && resultIter.hasNext()) { vals.put(postItr.next().getName(), resultIter.next()); } + } else { + for (PostAggregator postAgg : postAggs) { + vals.put(postAgg.getName(), postAgg.compute(vals)); + } } retVal.add(vals); } diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java index 191cc557f09..f9080e783d2 100644 --- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java +++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.druid.collections.CloseableStupidPool; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.hll.HyperLogLogCollector; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.granularity.Granularities; @@ -40,6 +41,8 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.aggregation.SerializablePairLongString; +import org.apache.druid.query.aggregation.cardinality.CardinalityAggregator; +import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory; import org.apache.druid.query.aggregation.last.DoubleLastAggregatorFactory; import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory; import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory; @@ -47,6 +50,7 @@ import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.ConstantPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; +import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.segment.IncrementalIndexSegment; @@ -79,6 +83,15 @@ public class TopNQueryQueryToolChestTest doTestCacheStrategy(ValueType.LONG, 2L); } + @Test + public void testCacheStrategyOrderByPostAggs() throws Exception + { + doTestCacheStrategyOrderByPost(ValueType.STRING, "val1"); + doTestCacheStrategyOrderByPost(ValueType.FLOAT, 2.1f); + doTestCacheStrategyOrderByPost(ValueType.DOUBLE, 2.1d); + doTestCacheStrategyOrderByPost(ValueType.LONG, 2L); + } + @Test public void testComputeCacheKeyWithDifferentPostAgg() { @@ -306,6 +319,28 @@ public class TopNQueryQueryToolChestTest } } + private HyperLogLogCollector getIntermediateHllCollector(final ValueType valueType, final Object dimValue) + { + HyperLogLogCollector collector = HyperLogLogCollector.makeLatestCollector(); + switch (valueType) { + case LONG: + collector.add(CardinalityAggregator.hashFn.hashLong((Long) dimValue).asBytes()); + break; + case DOUBLE: + collector.add(CardinalityAggregator.hashFn.hashLong(Double.doubleToLongBits((Double) dimValue)).asBytes()); + break; + case FLOAT: + collector.add(CardinalityAggregator.hashFn.hashInt(Float.floatToIntBits((Float) dimValue)).asBytes()); + break; + case STRING: + collector.add(CardinalityAggregator.hashFn.hashUnencodedChars((String) dimValue).asBytes()); + break; + default: + throw new IllegalArgumentException("bad valueType: " + valueType); + } + return collector; + } + private void doTestCacheStrategy(final ValueType valueType, final Object dimValue) throws IOException { CacheStrategy, Object, TopNQuery> strategy = @@ -419,6 +454,102 @@ public class TopNQueryQueryToolChestTest Assert.assertEquals(typeAdjustedResult2, fromResultCacheResult); } + private void doTestCacheStrategyOrderByPost(final ValueType valueType, final Object dimValue) throws IOException + { + CacheStrategy, Object, TopNQuery> strategy = + new TopNQueryQueryToolChest(null, null).getCacheStrategy( + new TopNQuery( + new TableDataSource("dummy"), + VirtualColumns.EMPTY, + new DefaultDimensionSpec("test", "test", valueType), + new NumericTopNMetricSpec("post"), + 3, + new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("2015-01-01/2015-01-02"))), + null, + Granularities.ALL, + ImmutableList.of( + new HyperUniquesAggregatorFactory("metric1", "test", false, false), + new CountAggregatorFactory("metric2") + ), + ImmutableList.of( + new ArithmeticPostAggregator( + "post", + "+", + ImmutableList.of( + new FinalizingFieldAccessPostAggregator( + "metric1", + "metric1" + ), + new FieldAccessPostAggregator( + "metric2", + "metric2" + ) + ) + ) + ), + null + ) + ); + + HyperLogLogCollector collector = getIntermediateHllCollector(valueType, dimValue); + + final Result result1 = new Result<>( + // test timestamps that result in integer size millis + DateTimes.utc(123L), + new TopNResultValue( + Collections.singletonList( + ImmutableMap.of( + "test", dimValue, + "metric1", collector, + "metric2", 2, + "post", collector.estimateCardinality() + 2 + ) + ) + ) + ); + + Object preparedValue = strategy.prepareForSegmentLevelCache().apply( + result1 + ); + + ObjectMapper objectMapper = TestHelper.makeJsonMapper(); + Object fromCacheValue = objectMapper.readValue( + objectMapper.writeValueAsBytes(preparedValue), + strategy.getCacheObjectClazz() + ); + + Result fromCacheResult = strategy.pullFromSegmentLevelCache().apply(fromCacheValue); + + Assert.assertEquals(result1, fromCacheResult); + + final Result resultLevelCacheResult = new Result<>( + // test timestamps that result in integer size millis + DateTimes.utc(123L), + new TopNResultValue( + Collections.singletonList( + ImmutableMap.of( + "test", dimValue, + "metric1", collector.estimateCardinality(), + "metric2", 2, + "post", collector.estimateCardinality() + 2 + ) + ) + ) + ); + + Object preparedResultCacheValue = strategy.prepareForCache(true).apply( + resultLevelCacheResult + ); + + Object fromResultCacheValue = objectMapper.readValue( + objectMapper.writeValueAsBytes(preparedResultCacheValue), + strategy.getCacheObjectClazz() + ); + + Result fromResultCacheResult = strategy.pullFromCache(true).apply(fromResultCacheValue); + Assert.assertEquals(resultLevelCacheResult, fromResultCacheResult); + } + static class MockQueryRunner implements QueryRunner> { private final QueryRunner> runner;