From bbb61e638b391d298e338c17a02d65a6fd18b477 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 27 Mar 2017 15:22:36 -0700 Subject: [PATCH] SQL: Support for another form of filtered aggregator. (#4109) * SQL: Support for another form of filtered aggregator. * Fix comment, add test for MAX too. --- .../io/druid/sql/calcite/planner/Calcites.java | 15 +++++++++++++++ .../io/druid/sql/calcite/rule/GroupByRules.java | 15 ++++++++------- .../io/druid/sql/calcite/CalciteQueryTest.java | 14 ++++++++++++-- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java index 3cdbf05d478..59f8943994c 100644 --- a/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java @@ -25,6 +25,8 @@ import io.druid.segment.column.ValueType; import io.druid.sql.calcite.schema.DruidSchema; import io.druid.sql.calcite.schema.InformationSchema; import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.type.SqlTypeName; @@ -181,4 +183,17 @@ public class Calcites { return new DateTime(0L, DateTimeZone.UTC).plusDays(date).withZoneRetainFields(timeZone); } + + /** + * Checks if a RexNode is a literal int or not. If this returns true, then {@code RexLiteral.intValue(literal)} can be + * used to get the value of the literal. + * + * @param rexNode the node + * + * @return true if this is an int + */ + public static boolean isIntLiteral(final RexNode rexNode) + { + return rexNode instanceof RexLiteral && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName()); + } } diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java index 05acca45faf..c3455bf5252 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java @@ -807,7 +807,8 @@ public class GroupByRules input = foe; } else if (rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).getOperands().size() == 3) { // Possibly a CASE-style filtered aggregation. Styles supported: - // A: SUM(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null) + // A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null) + // A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM // B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0) // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null) // If the null and non-null args are switched, "flip" is set, which negates the filter. @@ -839,15 +840,15 @@ public class GroupByRules forceCount = true; input = null; } else if (call.getAggregation().getKind() == SqlKind.SUM - && arg1 instanceof RexLiteral - && ((Number) RexLiteral.value(arg1)).intValue() == 1 - && arg2 instanceof RexLiteral - && ((Number) RexLiteral.value(arg2)).intValue() == 0) { + && Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1 + && Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) { // Case B forceCount = true; input = null; - } else if (RexLiteral.isNullLiteral(arg2)) { - // Maybe case A + } else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */ + || (kind == SqlKind.SUM + && Calcites.isIntLiteral(arg2) + && RexLiteral.intValue(arg2) == 0) /* Case A2 */) { input = FieldOrExpression.fromRexNode(operatorTable, plannerContext, rowOrder, arg1); if (input == null) { return null; diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index 12b7ee2097c..54c123fbb2c 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -1432,7 +1432,9 @@ public class CalciteQueryTest + "COUNT(CASE WHEN dim1 <> '1' THEN 'dummy' END), " + "SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END), " + "SUM(cnt) filter(WHERE dim2 = 'a'), " - + "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a') " + + "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), " + + "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), " + + "MAX(CASE WHEN dim1 <> '1' THEN cnt END) " + "FROM druid.foo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -1474,13 +1476,21 @@ public class CalciteQueryTest SELECTOR("dim2", "a", null), NOT(SELECTOR("dim1", "1", null)) ) + ), + new FilteredAggregatorFactory( + new LongSumAggregatorFactory("a8", "cnt"), + NOT(SELECTOR("dim1", "1", null)) + ), + new FilteredAggregatorFactory( + new LongMaxAggregatorFactory("a9", "cnt"), + NOT(SELECTOR("dim1", "1", null)) ) )) .context(TIMESERIES_CONTEXT_DEFAULT) .build() ), ImmutableList.of( - new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L} + new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L, 5L, 1L} ) ); }