mirror of https://github.com/apache/druid.git
SQL: Support CASE-style filtered count distinct. (#5047)
i.e., aggregations like COUNT(DISTINCT CASE WHEN x THEN y END). This patch also changes complex columns to report as nullable, which is required for them to type-check properly when used in these kinds of filtered aggregations.
This commit is contained in:
parent
9ac150c23a
commit
6c0c858913
|
@ -164,7 +164,9 @@ public class Expressions
|
|||
}
|
||||
} else if (kind == SqlKind.LITERAL) {
|
||||
// Translate literal.
|
||||
if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
|
||||
if (RexLiteral.isNullLiteral(rexNode)) {
|
||||
return DruidExpression.fromExpression(DruidExpression.nullLiteral());
|
||||
} else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
|
||||
return DruidExpression.fromExpression(DruidExpression.numberLiteral((Number) RexLiteral.value(rexNode)));
|
||||
} else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) {
|
||||
// Calcite represents DAY-TIME intervals in milliseconds.
|
||||
|
|
|
@ -70,7 +70,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
|||
}
|
||||
|
||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||
if (isNonDistinctOneArgAggregateCall(aggregateCall)
|
||||
if (isOneArgAggregateCall(aggregateCall)
|
||||
&& isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) {
|
||||
return true;
|
||||
}
|
||||
|
@ -97,21 +97,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
|||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||
AggregateCall newCall = null;
|
||||
|
||||
if (isNonDistinctOneArgAggregateCall(aggregateCall)) {
|
||||
if (isOneArgAggregateCall(aggregateCall)) {
|
||||
final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList()));
|
||||
|
||||
// Styles supported:
|
||||
//
|
||||
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
|
||||
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
|
||||
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0)
|
||||
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
|
||||
//
|
||||
// If the null and non-null args are switched, "flip" is set, which negates the filter.
|
||||
|
||||
if (isThreeArgCase(rexNode)) {
|
||||
final RexCall caseCall = (RexCall) rexNode;
|
||||
|
||||
// If one arg is null and the other is not, reverse them and set "flip", which negates the filter.
|
||||
final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1))
|
||||
&& !RexLiteral.isNullLiteral(caseCall.getOperands().get(2));
|
||||
final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1);
|
||||
|
@ -126,6 +118,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
|||
ImmutableList.of(caseCall.getOperands().get(0))
|
||||
);
|
||||
|
||||
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the latter is present.
|
||||
if (aggregateCall.filterArg >= 0) {
|
||||
filter = rexBuilder.makeCall(
|
||||
booleanType,
|
||||
|
@ -136,8 +129,32 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
|||
filter = filterFromCase;
|
||||
}
|
||||
|
||||
if (aggregateCall.isDistinct()) {
|
||||
// Just one style supported:
|
||||
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) => COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
|
||||
|
||||
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
|
||||
newProjects.add(arg1);
|
||||
newProjects.add(filter);
|
||||
newCall = AggregateCall.create(
|
||||
SqlStdOperatorTable.COUNT,
|
||||
true,
|
||||
ImmutableList.of(newProjects.size() - 2),
|
||||
newProjects.size() - 1,
|
||||
aggregateCall.getType(),
|
||||
aggregateCall.getName()
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Four styles supported:
|
||||
//
|
||||
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
|
||||
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
|
||||
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0); must be SUM
|
||||
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
|
||||
|
||||
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
|
||||
&& arg1 instanceof RexLiteral
|
||||
&& arg1.isA(SqlKind.LITERAL)
|
||||
&& !RexLiteral.isNullLiteral(arg1)
|
||||
&& RexLiteral.isNullLiteral(arg2)) {
|
||||
// Case C
|
||||
|
@ -180,6 +197,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newCalls.add(newCall == null ? aggregateCall : newCall);
|
||||
|
||||
|
@ -211,9 +229,9 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
|||
}
|
||||
}
|
||||
|
||||
private static boolean isNonDistinctOneArgAggregateCall(final AggregateCall aggregateCall)
|
||||
private static boolean isOneArgAggregateCall(final AggregateCall aggregateCall)
|
||||
{
|
||||
return aggregateCall.getArgList().size() == 1 && !aggregateCall.isDistinct();
|
||||
return aggregateCall.getArgList().size() == 1;
|
||||
}
|
||||
|
||||
private static boolean isThreeArgCase(final RexNode rexNode)
|
||||
|
|
|
@ -32,8 +32,6 @@ import org.apache.calcite.rel.core.AggregateCall;
|
|||
import org.apache.calcite.rel.core.Project;
|
||||
import org.apache.calcite.rex.RexBuilder;
|
||||
import org.apache.calcite.rex.RexNode;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -63,8 +61,6 @@ public class GroupByRules
|
|||
)
|
||||
{
|
||||
final DimFilter filter;
|
||||
final SqlKind kind = call.getAggregation().getKind();
|
||||
final SqlTypeName outputType = call.getType().getSqlTypeName();
|
||||
|
||||
if (call.filterArg >= 0) {
|
||||
// AGG(xxx) FILTER(WHERE yyy)
|
||||
|
|
|
@ -146,7 +146,7 @@ public class RowSignature
|
|||
*/
|
||||
public RelDataType getRelDataType(final RelDataTypeFactory typeFactory)
|
||||
{
|
||||
final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder();
|
||||
final RelDataTypeFactory.Builder builder = typeFactory.builder();
|
||||
for (final String columnName : columnNames) {
|
||||
final ValueType columnType = getColumnType(columnName);
|
||||
final RelDataType type;
|
||||
|
@ -177,7 +177,10 @@ public class RowSignature
|
|||
break;
|
||||
case COMPLEX:
|
||||
// Loses information about exactly what kind of complex column this is.
|
||||
type = typeFactory.createSqlType(SqlTypeName.OTHER);
|
||||
type = typeFactory.createTypeWithNullability(
|
||||
typeFactory.createSqlType(SqlTypeName.OTHER),
|
||||
true
|
||||
);
|
||||
break;
|
||||
default:
|
||||
throw new ISE("WTF?! valueType[%s] not translatable?", columnType);
|
||||
|
|
|
@ -446,7 +446,7 @@ public class DruidAvaticaHandlerTest
|
|||
Pair.of("COLUMN_NAME", "unique_dim1"),
|
||||
Pair.of("DATA_TYPE", Types.OTHER),
|
||||
Pair.of("TYPE_NAME", "OTHER"),
|
||||
Pair.of("IS_NULLABLE", "NO")
|
||||
Pair.of("IS_NULLABLE", "YES")
|
||||
)
|
||||
),
|
||||
getRows(
|
||||
|
@ -529,7 +529,7 @@ public class DruidAvaticaHandlerTest
|
|||
Pair.of("COLUMN_NAME", "unique_dim1"),
|
||||
Pair.of("DATA_TYPE", Types.OTHER),
|
||||
Pair.of("TYPE_NAME", "OTHER"),
|
||||
Pair.of("IS_NULLABLE", "NO")
|
||||
Pair.of("IS_NULLABLE", "YES")
|
||||
)
|
||||
),
|
||||
getRows(
|
||||
|
|
|
@ -407,7 +407,7 @@ public class CalciteQueryTest
|
|||
new Object[]{"dim2", "VARCHAR", "YES"},
|
||||
new Object[]{"m1", "FLOAT", "NO"},
|
||||
new Object[]{"m2", "DOUBLE", "NO"},
|
||||
new Object[]{"unique_dim1", "OTHER", "NO"}
|
||||
new Object[]{"unique_dim1", "OTHER", "YES"}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
@ -437,12 +437,11 @@ public class CalciteQueryTest
|
|||
new Object[]{"dim2", "VARCHAR", "YES"},
|
||||
new Object[]{"m1", "FLOAT", "NO"},
|
||||
new Object[]{"m2", "DOUBLE", "NO"},
|
||||
new Object[]{"unique_dim1", "OTHER", "NO"}
|
||||
new Object[]{"unique_dim1", "OTHER", "YES"}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testInformationSchemaColumnsOnView() throws Exception
|
||||
{
|
||||
|
@ -2327,9 +2326,10 @@ public class CalciteQueryTest
|
|||
+ "SUM(cnt) filter(WHERE dim2 = 'a'), "
|
||||
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), "
|
||||
+ "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), "
|
||||
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END) "
|
||||
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END), "
|
||||
+ "COUNT(DISTINCT CASE WHEN dim1 <> '1' THEN m1 END) "
|
||||
+ "FROM druid.foo",
|
||||
ImmutableList.<Query>of(
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(QSS(Filtration.eternity()))
|
||||
|
@ -2380,13 +2380,23 @@ public class CalciteQueryTest
|
|||
new FilteredAggregatorFactory(
|
||||
new LongMaxAggregatorFactory("a9", "cnt"),
|
||||
NOT(SELECTOR("dim1", "1", null))
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new CardinalityAggregatorFactory(
|
||||
"a10",
|
||||
null,
|
||||
DIMS(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
|
||||
false,
|
||||
true
|
||||
),
|
||||
NOT(SELECTOR("dim1", "1", null))
|
||||
)
|
||||
))
|
||||
.context(TIMESERIES_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L}
|
||||
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L, 5L}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
@ -3455,6 +3465,57 @@ public class CalciteQueryTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCountDistinctOfCaseWhen() throws Exception
|
||||
{
|
||||
testQuery(
|
||||
"SELECT\n"
|
||||
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN m1 END),\n"
|
||||
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN dim1 END),\n"
|
||||
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN unique_dim1 END)\n"
|
||||
+ "FROM druid.foo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(QSS(Filtration.eternity()))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
AGGS(
|
||||
new FilteredAggregatorFactory(
|
||||
new CardinalityAggregatorFactory(
|
||||
"a0",
|
||||
null,
|
||||
ImmutableList.of(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
|
||||
false,
|
||||
true
|
||||
),
|
||||
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new CardinalityAggregatorFactory(
|
||||
"a1",
|
||||
null,
|
||||
ImmutableList.of(new DefaultDimensionSpec("dim1", "dim1", ValueType.STRING)),
|
||||
false,
|
||||
true
|
||||
),
|
||||
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new HyperUniquesAggregatorFactory("a2", "unique_dim1", false, true),
|
||||
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
|
||||
)
|
||||
)
|
||||
)
|
||||
.context(TIMESERIES_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{3L, 3L, 3L}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExactCountDistinct() throws Exception
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue