SQL: Support CASE-style filtered count distinct. (#5047)

i.e., aggregations like COUNT(DISTINCT CASE WHEN x THEN y END). This
patch also changes complex columns to report as nullable, which
is required for them to type-check properly when used in these kinds
of filtered aggregations.
This commit is contained in:
Gian Merlino 2017-11-13 20:23:54 -08:00 committed by Fangjin Yang
parent 9ac150c23a
commit 6c0c858913
6 changed files with 149 additions and 69 deletions

View File

@ -164,7 +164,9 @@ public class Expressions
}
} else if (kind == SqlKind.LITERAL) {
// Translate literal.
if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
if (RexLiteral.isNullLiteral(rexNode)) {
return DruidExpression.fromExpression(DruidExpression.nullLiteral());
} else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
return DruidExpression.fromExpression(DruidExpression.numberLiteral((Number) RexLiteral.value(rexNode)));
} else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) {
// Calcite represents DAY-TIME intervals in milliseconds.

View File

@ -70,7 +70,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule
}
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (isNonDistinctOneArgAggregateCall(aggregateCall)
if (isOneArgAggregateCall(aggregateCall)
&& isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) {
return true;
}
@ -97,21 +97,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
AggregateCall newCall = null;
if (isNonDistinctOneArgAggregateCall(aggregateCall)) {
if (isOneArgAggregateCall(aggregateCall)) {
final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList()));
// Styles supported:
//
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0)
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
//
// If the null and non-null args are switched, "flip" is set, which negates the filter.
if (isThreeArgCase(rexNode)) {
final RexCall caseCall = (RexCall) rexNode;
// If one arg is null and the other is not, reverse them and set "flip", which negates the filter.
final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1))
&& !RexLiteral.isNullLiteral(caseCall.getOperands().get(2));
final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1);
@ -126,6 +118,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule
ImmutableList.of(caseCall.getOperands().get(0))
);
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the latter is present.
if (aggregateCall.filterArg >= 0) {
filter = rexBuilder.makeCall(
booleanType,
@ -136,47 +129,72 @@ public class CaseFilteredAggregatorRule extends RelOptRule
filter = filterFromCase;
}
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
&& arg1 instanceof RexLiteral
&& !RexLiteral.isNullLiteral(arg1)
&& RexLiteral.isNullLiteral(arg2)) {
// Case C
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
// Case B
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
typeFactory.createSqlType(SqlTypeName.BIGINT),
aggregateCall.getName()
);
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
aggregateCall.getAggregation(),
false,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
if (aggregateCall.isDistinct()) {
// Just one style supported:
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) => COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
true,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
}
} else {
// Four styles supported:
//
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0); must be SUM
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
&& arg1.isA(SqlKind.LITERAL)
&& !RexLiteral.isNullLiteral(arg1)
&& RexLiteral.isNullLiteral(arg2)) {
// Case C
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
// Case B
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
typeFactory.createSqlType(SqlTypeName.BIGINT),
aggregateCall.getName()
);
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
aggregateCall.getAggregation(),
false,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
}
}
}
}
@ -211,9 +229,9 @@ public class CaseFilteredAggregatorRule extends RelOptRule
}
}
private static boolean isNonDistinctOneArgAggregateCall(final AggregateCall aggregateCall)
private static boolean isOneArgAggregateCall(final AggregateCall aggregateCall)
{
return aggregateCall.getArgList().size() == 1 && !aggregateCall.isDistinct();
return aggregateCall.getArgList().size() == 1;
}
private static boolean isThreeArgCase(final RexNode rexNode)

View File

@ -32,8 +32,6 @@ import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import java.util.ArrayList;
import java.util.List;
@ -63,8 +61,6 @@ public class GroupByRules
)
{
final DimFilter filter;
final SqlKind kind = call.getAggregation().getKind();
final SqlTypeName outputType = call.getType().getSqlTypeName();
if (call.filterArg >= 0) {
// AGG(xxx) FILTER(WHERE yyy)

View File

@ -146,7 +146,7 @@ public class RowSignature
*/
public RelDataType getRelDataType(final RelDataTypeFactory typeFactory)
{
final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder();
final RelDataTypeFactory.Builder builder = typeFactory.builder();
for (final String columnName : columnNames) {
final ValueType columnType = getColumnType(columnName);
final RelDataType type;
@ -177,7 +177,10 @@ public class RowSignature
break;
case COMPLEX:
// Loses information about exactly what kind of complex column this is.
type = typeFactory.createSqlType(SqlTypeName.OTHER);
type = typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.OTHER),
true
);
break;
default:
throw new ISE("WTF?! valueType[%s] not translatable?", columnType);

View File

@ -446,7 +446,7 @@ public class DruidAvaticaHandlerTest
Pair.of("COLUMN_NAME", "unique_dim1"),
Pair.of("DATA_TYPE", Types.OTHER),
Pair.of("TYPE_NAME", "OTHER"),
Pair.of("IS_NULLABLE", "NO")
Pair.of("IS_NULLABLE", "YES")
)
),
getRows(
@ -529,7 +529,7 @@ public class DruidAvaticaHandlerTest
Pair.of("COLUMN_NAME", "unique_dim1"),
Pair.of("DATA_TYPE", Types.OTHER),
Pair.of("TYPE_NAME", "OTHER"),
Pair.of("IS_NULLABLE", "NO")
Pair.of("IS_NULLABLE", "YES")
)
),
getRows(

View File

@ -407,7 +407,7 @@ public class CalciteQueryTest
new Object[]{"dim2", "VARCHAR", "YES"},
new Object[]{"m1", "FLOAT", "NO"},
new Object[]{"m2", "DOUBLE", "NO"},
new Object[]{"unique_dim1", "OTHER", "NO"}
new Object[]{"unique_dim1", "OTHER", "YES"}
)
);
}
@ -437,12 +437,11 @@ public class CalciteQueryTest
new Object[]{"dim2", "VARCHAR", "YES"},
new Object[]{"m1", "FLOAT", "NO"},
new Object[]{"m2", "DOUBLE", "NO"},
new Object[]{"unique_dim1", "OTHER", "NO"}
new Object[]{"unique_dim1", "OTHER", "YES"}
)
);
}
@Test
public void testInformationSchemaColumnsOnView() throws Exception
{
@ -2327,9 +2326,10 @@ public class CalciteQueryTest
+ "SUM(cnt) filter(WHERE dim2 = 'a'), "
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), "
+ "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), "
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END) "
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END), "
+ "COUNT(DISTINCT CASE WHEN dim1 <> '1' THEN m1 END) "
+ "FROM druid.foo",
ImmutableList.<Query>of(
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(QSS(Filtration.eternity()))
@ -2380,13 +2380,23 @@ public class CalciteQueryTest
new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"),
NOT(SELECTOR("dim1", "1", null))
),
new FilteredAggregatorFactory(
new CardinalityAggregatorFactory(
"a10",
null,
DIMS(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
false,
true
),
NOT(SELECTOR("dim1", "1", null))
)
))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L}
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L, 5L}
)
);
}
@ -3455,6 +3465,57 @@ public class CalciteQueryTest
);
}
@Test
public void testCountDistinctOfCaseWhen() throws Exception
{
testQuery(
"SELECT\n"
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN m1 END),\n"
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN dim1 END),\n"
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN unique_dim1 END)\n"
+ "FROM druid.foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(QSS(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(
AGGS(
new FilteredAggregatorFactory(
new CardinalityAggregatorFactory(
"a0",
null,
ImmutableList.of(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
false,
true
),
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
),
new FilteredAggregatorFactory(
new CardinalityAggregatorFactory(
"a1",
null,
ImmutableList.of(new DefaultDimensionSpec("dim1", "dim1", ValueType.STRING)),
false,
true
),
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
),
new FilteredAggregatorFactory(
new HyperUniquesAggregatorFactory("a2", "unique_dim1", false, true),
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
)
)
)
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{3L, 3L, 3L}
)
);
}
@Test
public void testExactCountDistinct() throws Exception
{