diff --git a/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
index 6711c36f4b1..b9c8d3d3d46 100644
--- a/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
+++ b/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
@@ -93,15 +93,27 @@ public class QuantileSqlAggregator implements SqlAggregator
         project,
         aggregateCall.getArgList().get(1)
     );
-    final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
+    if (!probabilityArg.isA(SqlKind.LITERAL)) {
+      // Probability must be a literal in order to plan.
+      return null;
+    }
+
+    final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
     final int resolution;
+
     if (aggregateCall.getArgList().size() >= 3) {
       final RexNode resolutionArg = Expressions.fromFieldAccess(
           rowSignature,
           project,
           aggregateCall.getArgList().get(2)
       );
+
+      if (!resolutionArg.isA(SqlKind.LITERAL)) {
+        // Resolution must be a literal in order to plan.
+        return null;
+      }
+
       resolution = ((Number) RexLiteral.value(resolutionArg)).intValue();
     } else {
       resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;
     }
diff --git a/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java
index 95734ff0363..8fae927cef8 100644
--- a/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java
+++ b/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java
@@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import io.druid.java.util.common.granularity.Granularities;
 import io.druid.query.Druids;
+import io.druid.query.QueryDataSource;
 import io.druid.query.aggregation.CountAggregatorFactory;
 import io.druid.query.aggregation.DoubleSumAggregatorFactory;
 import io.druid.query.aggregation.FilteredAggregatorFactory;
@@ -32,9 +33,13 @@ import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactor
 import io.druid.query.aggregation.histogram.ApproximateHistogramDruidModule;
 import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
 import io.druid.query.aggregation.histogram.QuantilePostAggregator;
+import io.druid.query.aggregation.post.ArithmeticPostAggregator;
+import io.druid.query.aggregation.post.FieldAccessPostAggregator;
+import io.druid.query.dimension.DefaultDimensionSpec;
 import io.druid.query.expression.TestExprMacroTable;
 import io.druid.query.filter.NotDimFilter;
 import io.druid.query.filter.SelectorDimFilter;
+import io.druid.query.groupby.GroupByQuery;
 import io.druid.query.spec.MultipleIntervalSegmentSpec;
 import io.druid.segment.IndexBuilder;
 import io.druid.segment.QueryableIndex;
@@ -291,4 +296,69 @@ public class QuantileSqlAggregatorTest
       );
     }
   }
+
+  @Test
+  public void testQuantileOnInnerQuery() throws Exception
+  {
+    try (final DruidPlanner planner = plannerFactory.createPlanner(null)) {
+      final String sql = "SELECT AVG(x), APPROX_QUANTILE(x, 0.98)\n"
+                         + "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)";
+
+      final PlannerResult plannerResult = planner.plan(sql);
+
+      // Verify results
+      final List<Object[]> results = plannerResult.run().toList();
+      final List<Object[]> expectedResults = ImmutableList.of(
+          new Object[]{7.0, 8.26386833190918}
+      );
+      Assert.assertEquals(expectedResults.size(), results.size());
+      for (int i = 0; i < expectedResults.size(); i++) {
+        Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
+      }
+
+      // Verify query
+      Assert.assertEquals(
+          GroupByQuery.builder()
+                      .setDataSource(
+                          new QueryDataSource(
+                              GroupByQuery.builder()
+                                          .setDataSource(CalciteTests.DATASOURCE1)
+                                          .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+                                          .setGranularity(Granularities.ALL)
+                                          .setDimensions(ImmutableList.of(new DefaultDimensionSpec("dim2", "d0")))
+                                          .setAggregatorSpecs(
+                                              ImmutableList.of(
+                                                  new DoubleSumAggregatorFactory("a0", "m1")
+                                              )
+                                          )
+                                          .setContext(ImmutableMap.of())
+                                          .build()
+                          )
+                      )
+                      .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+                      .setGranularity(Granularities.ALL)
+                      .setAggregatorSpecs(ImmutableList.of(
+                          new DoubleSumAggregatorFactory("_a0:sum", "a0"),
+                          new CountAggregatorFactory("_a0:count"),
+                          new ApproximateHistogramAggregatorFactory("_a1:agg", "a0", null, null, null, null)
+                      ))
+                      .setPostAggregatorSpecs(
+                          ImmutableList.of(
+                              new ArithmeticPostAggregator(
+                                  "_a0",
+                                  "quotient",
+                                  ImmutableList.of(
+                                      new FieldAccessPostAggregator(null, "_a0:sum"),
+                                      new FieldAccessPostAggregator(null, "_a0:count")
+                                  )
+                              ),
+                              new QuantilePostAggregator("_a1", "_a1:agg", 0.98f)
+                          )
+                      )
+                      .setContext(ImmutableMap.of())
+                      .build(),
+          Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+      );
+    }
+  }
 }