diff --git a/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java b/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java index 3261336b356..60d8d7139b2 100644 --- a/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java +++ b/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java @@ -164,7 +164,9 @@ public class Expressions } } else if (kind == SqlKind.LITERAL) { // Translate literal. - if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) { + if (RexLiteral.isNullLiteral(rexNode)) { + return DruidExpression.fromExpression(DruidExpression.nullLiteral()); + } else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) { return DruidExpression.fromExpression(DruidExpression.numberLiteral((Number) RexLiteral.value(rexNode))); } else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) { // Calcite represents DAY-TIME intervals in milliseconds. diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java b/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java index d51fee71d0d..4635addcbd3 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java @@ -70,7 +70,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule } for (AggregateCall aggregateCall : aggregate.getAggCallList()) { - if (isNonDistinctOneArgAggregateCall(aggregateCall) + if (isOneArgAggregateCall(aggregateCall) && isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) { return true; } @@ -97,21 +97,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule for (AggregateCall aggregateCall : aggregate.getAggCallList()) { AggregateCall newCall = null; - if (isNonDistinctOneArgAggregateCall(aggregateCall)) { + if (isOneArgAggregateCall(aggregateCall)) { final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())); - // Styles supported: - // - // A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null) - // A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM - // B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0) - // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null) - // - // If the null and non-null args are switched, "flip" is set, which negates the filter. - if (isThreeArgCase(rexNode)) { final RexCall caseCall = (RexCall) rexNode; + // If one arg is null and the other is not, reverse them and set "flip", which negates the filter. final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1)) && !RexLiteral.isNullLiteral(caseCall.getOperands().get(2)); final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1); @@ -126,6 +118,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule ImmutableList.of(caseCall.getOperands().get(0)) ); + // Combine the CASE filter with an honest-to-goodness SQL FILTER, if the latter is present. if (aggregateCall.filterArg >= 0) { filter = rexBuilder.makeCall( booleanType, @@ -136,47 +129,72 @@ public class CaseFilteredAggregatorRule extends RelOptRule filter = filterFromCase; } - if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT - && arg1 instanceof RexLiteral - && !RexLiteral.isNullLiteral(arg1) - && RexLiteral.isNullLiteral(arg2)) { - // Case C - newProjects.add(filter); - newCall = AggregateCall.create( - SqlStdOperatorTable.COUNT, - false, - ImmutableList.of(), - newProjects.size() - 1, - aggregateCall.getType(), - aggregateCall.getName() - ); - } else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM - && Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1 - && Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) { - // Case B - newProjects.add(filter); - newCall = AggregateCall.create( - SqlStdOperatorTable.COUNT, - false, - ImmutableList.of(), - newProjects.size() - 1, - typeFactory.createSqlType(SqlTypeName.BIGINT), - aggregateCall.getName() - ); - } else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */ - || (aggregateCall.getAggregation().getKind() == SqlKind.SUM - && Calcites.isIntLiteral(arg2) - && RexLiteral.intValue(arg2) == 0) /* Case A2 */) { - newProjects.add(arg1); - newProjects.add(filter); - newCall = AggregateCall.create( - aggregateCall.getAggregation(), - false, - ImmutableList.of(newProjects.size() - 2), - newProjects.size() - 1, - aggregateCall.getType(), - aggregateCall.getName() - ); + if (aggregateCall.isDistinct()) { + // Just one style supported: + // COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) => COUNT(DISTINCT y) FILTER(WHERE x = 'foo') + + if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) { + newProjects.add(arg1); + newProjects.add(filter); + newCall = AggregateCall.create( + SqlStdOperatorTable.COUNT, + true, + ImmutableList.of(newProjects.size() - 2), + newProjects.size() - 1, + aggregateCall.getType(), + aggregateCall.getName() + ); + } + } else { + // Four styles supported: + // + // A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null) + // A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM + // B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0); must be SUM + // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null) + + if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT + && arg1.isA(SqlKind.LITERAL) + && !RexLiteral.isNullLiteral(arg1) + && RexLiteral.isNullLiteral(arg2)) { + // Case C + newProjects.add(filter); + newCall = AggregateCall.create( + SqlStdOperatorTable.COUNT, + false, + ImmutableList.of(), + newProjects.size() - 1, + aggregateCall.getType(), + aggregateCall.getName() + ); + } else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM + && Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1 + && Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) { + // Case B + newProjects.add(filter); + newCall = AggregateCall.create( + SqlStdOperatorTable.COUNT, + false, + ImmutableList.of(), + newProjects.size() - 1, + typeFactory.createSqlType(SqlTypeName.BIGINT), + aggregateCall.getName() + ); + } else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */ + || (aggregateCall.getAggregation().getKind() == SqlKind.SUM + && Calcites.isIntLiteral(arg2) + && RexLiteral.intValue(arg2) == 0) /* Case A2 */) { + newProjects.add(arg1); + newProjects.add(filter); + newCall = AggregateCall.create( + aggregateCall.getAggregation(), + false, + ImmutableList.of(newProjects.size() - 2), + newProjects.size() - 1, + aggregateCall.getType(), + aggregateCall.getName() + ); + } } } } @@ -211,9 +229,9 @@ public class CaseFilteredAggregatorRule extends RelOptRule } } - private static boolean isNonDistinctOneArgAggregateCall(final AggregateCall aggregateCall) + private static boolean isOneArgAggregateCall(final AggregateCall aggregateCall) { - return aggregateCall.getArgList().size() == 1 && !aggregateCall.isDistinct(); + return aggregateCall.getArgList().size() == 1; } private static boolean isThreeArgCase(final RexNode rexNode) diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java index 72b8bf97614..13089816a4a 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java @@ -32,8 +32,6 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.SqlTypeName; import java.util.ArrayList; import java.util.List; @@ -63,8 +61,6 @@ public class GroupByRules ) { final DimFilter filter; - final SqlKind kind = call.getAggregation().getKind(); - final SqlTypeName outputType = call.getType().getSqlTypeName(); if (call.filterArg >= 0) { // AGG(xxx) FILTER(WHERE yyy) diff --git a/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java b/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java index a07ad23ab28..5f0a8e071ed 100644 --- a/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java +++ b/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java @@ -146,7 +146,7 @@ public class RowSignature */ public RelDataType getRelDataType(final RelDataTypeFactory typeFactory) { - final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder(); + final RelDataTypeFactory.Builder builder = typeFactory.builder(); for (final String columnName : columnNames) { final ValueType columnType = getColumnType(columnName); final RelDataType type; @@ -177,7 +177,10 @@ public class RowSignature break; case COMPLEX: // Loses information about exactly what kind of complex column this is. - type = typeFactory.createSqlType(SqlTypeName.OTHER); + type = typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.OTHER), + true + ); break; default: throw new ISE("WTF?! valueType[%s] not translatable?", columnType); diff --git a/sql/src/test/java/io/druid/sql/avatica/DruidAvaticaHandlerTest.java b/sql/src/test/java/io/druid/sql/avatica/DruidAvaticaHandlerTest.java index bce883d1c53..4e5b88f581c 100644 --- a/sql/src/test/java/io/druid/sql/avatica/DruidAvaticaHandlerTest.java +++ b/sql/src/test/java/io/druid/sql/avatica/DruidAvaticaHandlerTest.java @@ -446,7 +446,7 @@ public class DruidAvaticaHandlerTest Pair.of("COLUMN_NAME", "unique_dim1"), Pair.of("DATA_TYPE", Types.OTHER), Pair.of("TYPE_NAME", "OTHER"), - Pair.of("IS_NULLABLE", "NO") + Pair.of("IS_NULLABLE", "YES") ) ), getRows( @@ -529,7 +529,7 @@ public class DruidAvaticaHandlerTest Pair.of("COLUMN_NAME", "unique_dim1"), Pair.of("DATA_TYPE", Types.OTHER), Pair.of("TYPE_NAME", "OTHER"), - Pair.of("IS_NULLABLE", "NO") + Pair.of("IS_NULLABLE", "YES") ) ), getRows( diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index 5fed009c6db..7590a1dbaf0 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -407,7 +407,7 @@ public class CalciteQueryTest new Object[]{"dim2", "VARCHAR", "YES"}, new Object[]{"m1", "FLOAT", "NO"}, new Object[]{"m2", "DOUBLE", "NO"}, - new Object[]{"unique_dim1", "OTHER", "NO"} + new Object[]{"unique_dim1", "OTHER", "YES"} ) ); } @@ -437,12 +437,11 @@ public class CalciteQueryTest new Object[]{"dim2", "VARCHAR", "YES"}, new Object[]{"m1", "FLOAT", "NO"}, new Object[]{"m2", "DOUBLE", "NO"}, - new Object[]{"unique_dim1", "OTHER", "NO"} + new Object[]{"unique_dim1", "OTHER", "YES"} ) ); } - @Test public void testInformationSchemaColumnsOnView() throws Exception { @@ -2327,9 +2326,10 @@ public class CalciteQueryTest + "SUM(cnt) filter(WHERE dim2 = 'a'), " + "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), " + "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), " - + "MAX(CASE WHEN dim1 <> '1' THEN cnt END) " + + "MAX(CASE WHEN dim1 <> '1' THEN cnt END), " + + "COUNT(DISTINCT CASE WHEN dim1 <> '1' THEN m1 END) " + "FROM druid.foo", - ImmutableList.of( + ImmutableList.of( Druids.newTimeseriesQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) .intervals(QSS(Filtration.eternity())) @@ -2380,13 +2380,23 @@ public class CalciteQueryTest new FilteredAggregatorFactory( new LongMaxAggregatorFactory("a9", "cnt"), NOT(SELECTOR("dim1", "1", null)) + ), + new FilteredAggregatorFactory( + new CardinalityAggregatorFactory( + "a10", + null, + DIMS(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)), + false, + true + ), + NOT(SELECTOR("dim1", "1", null)) ) )) .context(TIMESERIES_CONTEXT_DEFAULT) .build() ), ImmutableList.of( - new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L} + new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L, 5L} ) ); } @@ -3455,6 +3465,57 @@ public class CalciteQueryTest ); } + @Test + public void testCountDistinctOfCaseWhen() throws Exception + { + testQuery( + "SELECT\n" + + "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN m1 END),\n" + + "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN dim1 END),\n" + + "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN unique_dim1 END)\n" + + "FROM druid.foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + AGGS( + new FilteredAggregatorFactory( + new CardinalityAggregatorFactory( + "a0", + null, + ImmutableList.of(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)), + false, + true + ), + BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC) + ), + new FilteredAggregatorFactory( + new CardinalityAggregatorFactory( + "a1", + null, + ImmutableList.of(new DefaultDimensionSpec("dim1", "dim1", ValueType.STRING)), + false, + true + ), + BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC) + ), + new FilteredAggregatorFactory( + new HyperUniquesAggregatorFactory("a2", "unique_dim1", false, true), + BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC) + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{3L, 3L, 3L} + ) + ); + } + @Test public void testExactCountDistinct() throws Exception {