mirror of https://github.com/apache/druid.git
SQL: Support CASE-style filtered count distinct. (#5047)
i.e., aggregations like COUNT(DISTINCT CASE WHEN x THEN y END). This patch also changes complex columns to report as nullable, which is required for them to type-check properly when used in these kinds of filtered aggregations.
This commit is contained in:
parent
9ac150c23a
commit
6c0c858913
|
@ -164,7 +164,9 @@ public class Expressions
|
||||||
}
|
}
|
||||||
} else if (kind == SqlKind.LITERAL) {
|
} else if (kind == SqlKind.LITERAL) {
|
||||||
// Translate literal.
|
// Translate literal.
|
||||||
if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
|
if (RexLiteral.isNullLiteral(rexNode)) {
|
||||||
|
return DruidExpression.fromExpression(DruidExpression.nullLiteral());
|
||||||
|
} else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
|
||||||
return DruidExpression.fromExpression(DruidExpression.numberLiteral((Number) RexLiteral.value(rexNode)));
|
return DruidExpression.fromExpression(DruidExpression.numberLiteral((Number) RexLiteral.value(rexNode)));
|
||||||
} else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) {
|
} else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) {
|
||||||
// Calcite represents DAY-TIME intervals in milliseconds.
|
// Calcite represents DAY-TIME intervals in milliseconds.
|
||||||
|
|
|
@ -70,7 +70,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
||||||
}
|
}
|
||||||
|
|
||||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||||
if (isNonDistinctOneArgAggregateCall(aggregateCall)
|
if (isOneArgAggregateCall(aggregateCall)
|
||||||
&& isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) {
|
&& isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -97,21 +97,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
||||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||||
AggregateCall newCall = null;
|
AggregateCall newCall = null;
|
||||||
|
|
||||||
if (isNonDistinctOneArgAggregateCall(aggregateCall)) {
|
if (isOneArgAggregateCall(aggregateCall)) {
|
||||||
final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList()));
|
final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList()));
|
||||||
|
|
||||||
// Styles supported:
|
|
||||||
//
|
|
||||||
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
|
|
||||||
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
|
|
||||||
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0)
|
|
||||||
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
|
|
||||||
//
|
|
||||||
// If the null and non-null args are switched, "flip" is set, which negates the filter.
|
|
||||||
|
|
||||||
if (isThreeArgCase(rexNode)) {
|
if (isThreeArgCase(rexNode)) {
|
||||||
final RexCall caseCall = (RexCall) rexNode;
|
final RexCall caseCall = (RexCall) rexNode;
|
||||||
|
|
||||||
|
// If one arg is null and the other is not, reverse them and set "flip", which negates the filter.
|
||||||
final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1))
|
final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1))
|
||||||
&& !RexLiteral.isNullLiteral(caseCall.getOperands().get(2));
|
&& !RexLiteral.isNullLiteral(caseCall.getOperands().get(2));
|
||||||
final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1);
|
final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1);
|
||||||
|
@ -126,6 +118,7 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
||||||
ImmutableList.of(caseCall.getOperands().get(0))
|
ImmutableList.of(caseCall.getOperands().get(0))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the latter is present.
|
||||||
if (aggregateCall.filterArg >= 0) {
|
if (aggregateCall.filterArg >= 0) {
|
||||||
filter = rexBuilder.makeCall(
|
filter = rexBuilder.makeCall(
|
||||||
booleanType,
|
booleanType,
|
||||||
|
@ -136,47 +129,72 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
||||||
filter = filterFromCase;
|
filter = filterFromCase;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
|
if (aggregateCall.isDistinct()) {
|
||||||
&& arg1 instanceof RexLiteral
|
// Just one style supported:
|
||||||
&& !RexLiteral.isNullLiteral(arg1)
|
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) => COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
|
||||||
&& RexLiteral.isNullLiteral(arg2)) {
|
|
||||||
// Case C
|
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
|
||||||
newProjects.add(filter);
|
newProjects.add(arg1);
|
||||||
newCall = AggregateCall.create(
|
newProjects.add(filter);
|
||||||
SqlStdOperatorTable.COUNT,
|
newCall = AggregateCall.create(
|
||||||
false,
|
SqlStdOperatorTable.COUNT,
|
||||||
ImmutableList.of(),
|
true,
|
||||||
newProjects.size() - 1,
|
ImmutableList.of(newProjects.size() - 2),
|
||||||
aggregateCall.getType(),
|
newProjects.size() - 1,
|
||||||
aggregateCall.getName()
|
aggregateCall.getType(),
|
||||||
);
|
aggregateCall.getName()
|
||||||
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
|
);
|
||||||
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
|
}
|
||||||
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
|
} else {
|
||||||
// Case B
|
// Four styles supported:
|
||||||
newProjects.add(filter);
|
//
|
||||||
newCall = AggregateCall.create(
|
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
|
||||||
SqlStdOperatorTable.COUNT,
|
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
|
||||||
false,
|
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0); must be SUM
|
||||||
ImmutableList.of(),
|
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
|
||||||
newProjects.size() - 1,
|
|
||||||
typeFactory.createSqlType(SqlTypeName.BIGINT),
|
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
|
||||||
aggregateCall.getName()
|
&& arg1.isA(SqlKind.LITERAL)
|
||||||
);
|
&& !RexLiteral.isNullLiteral(arg1)
|
||||||
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|
&& RexLiteral.isNullLiteral(arg2)) {
|
||||||
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
|
// Case C
|
||||||
&& Calcites.isIntLiteral(arg2)
|
newProjects.add(filter);
|
||||||
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
|
newCall = AggregateCall.create(
|
||||||
newProjects.add(arg1);
|
SqlStdOperatorTable.COUNT,
|
||||||
newProjects.add(filter);
|
false,
|
||||||
newCall = AggregateCall.create(
|
ImmutableList.of(),
|
||||||
aggregateCall.getAggregation(),
|
newProjects.size() - 1,
|
||||||
false,
|
aggregateCall.getType(),
|
||||||
ImmutableList.of(newProjects.size() - 2),
|
aggregateCall.getName()
|
||||||
newProjects.size() - 1,
|
);
|
||||||
aggregateCall.getType(),
|
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
|
||||||
aggregateCall.getName()
|
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
|
||||||
);
|
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
|
||||||
|
// Case B
|
||||||
|
newProjects.add(filter);
|
||||||
|
newCall = AggregateCall.create(
|
||||||
|
SqlStdOperatorTable.COUNT,
|
||||||
|
false,
|
||||||
|
ImmutableList.of(),
|
||||||
|
newProjects.size() - 1,
|
||||||
|
typeFactory.createSqlType(SqlTypeName.BIGINT),
|
||||||
|
aggregateCall.getName()
|
||||||
|
);
|
||||||
|
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|
||||||
|
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
|
||||||
|
&& Calcites.isIntLiteral(arg2)
|
||||||
|
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
|
||||||
|
newProjects.add(arg1);
|
||||||
|
newProjects.add(filter);
|
||||||
|
newCall = AggregateCall.create(
|
||||||
|
aggregateCall.getAggregation(),
|
||||||
|
false,
|
||||||
|
ImmutableList.of(newProjects.size() - 2),
|
||||||
|
newProjects.size() - 1,
|
||||||
|
aggregateCall.getType(),
|
||||||
|
aggregateCall.getName()
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -211,9 +229,9 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean isNonDistinctOneArgAggregateCall(final AggregateCall aggregateCall)
|
private static boolean isOneArgAggregateCall(final AggregateCall aggregateCall)
|
||||||
{
|
{
|
||||||
return aggregateCall.getArgList().size() == 1 && !aggregateCall.isDistinct();
|
return aggregateCall.getArgList().size() == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean isThreeArgCase(final RexNode rexNode)
|
private static boolean isThreeArgCase(final RexNode rexNode)
|
||||||
|
|
|
@ -32,8 +32,6 @@ import org.apache.calcite.rel.core.AggregateCall;
|
||||||
import org.apache.calcite.rel.core.Project;
|
import org.apache.calcite.rel.core.Project;
|
||||||
import org.apache.calcite.rex.RexBuilder;
|
import org.apache.calcite.rex.RexBuilder;
|
||||||
import org.apache.calcite.rex.RexNode;
|
import org.apache.calcite.rex.RexNode;
|
||||||
import org.apache.calcite.sql.SqlKind;
|
|
||||||
import org.apache.calcite.sql.type.SqlTypeName;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -63,8 +61,6 @@ public class GroupByRules
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
final DimFilter filter;
|
final DimFilter filter;
|
||||||
final SqlKind kind = call.getAggregation().getKind();
|
|
||||||
final SqlTypeName outputType = call.getType().getSqlTypeName();
|
|
||||||
|
|
||||||
if (call.filterArg >= 0) {
|
if (call.filterArg >= 0) {
|
||||||
// AGG(xxx) FILTER(WHERE yyy)
|
// AGG(xxx) FILTER(WHERE yyy)
|
||||||
|
|
|
@ -146,7 +146,7 @@ public class RowSignature
|
||||||
*/
|
*/
|
||||||
public RelDataType getRelDataType(final RelDataTypeFactory typeFactory)
|
public RelDataType getRelDataType(final RelDataTypeFactory typeFactory)
|
||||||
{
|
{
|
||||||
final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder();
|
final RelDataTypeFactory.Builder builder = typeFactory.builder();
|
||||||
for (final String columnName : columnNames) {
|
for (final String columnName : columnNames) {
|
||||||
final ValueType columnType = getColumnType(columnName);
|
final ValueType columnType = getColumnType(columnName);
|
||||||
final RelDataType type;
|
final RelDataType type;
|
||||||
|
@ -177,7 +177,10 @@ public class RowSignature
|
||||||
break;
|
break;
|
||||||
case COMPLEX:
|
case COMPLEX:
|
||||||
// Loses information about exactly what kind of complex column this is.
|
// Loses information about exactly what kind of complex column this is.
|
||||||
type = typeFactory.createSqlType(SqlTypeName.OTHER);
|
type = typeFactory.createTypeWithNullability(
|
||||||
|
typeFactory.createSqlType(SqlTypeName.OTHER),
|
||||||
|
true
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new ISE("WTF?! valueType[%s] not translatable?", columnType);
|
throw new ISE("WTF?! valueType[%s] not translatable?", columnType);
|
||||||
|
|
|
@ -446,7 +446,7 @@ public class DruidAvaticaHandlerTest
|
||||||
Pair.of("COLUMN_NAME", "unique_dim1"),
|
Pair.of("COLUMN_NAME", "unique_dim1"),
|
||||||
Pair.of("DATA_TYPE", Types.OTHER),
|
Pair.of("DATA_TYPE", Types.OTHER),
|
||||||
Pair.of("TYPE_NAME", "OTHER"),
|
Pair.of("TYPE_NAME", "OTHER"),
|
||||||
Pair.of("IS_NULLABLE", "NO")
|
Pair.of("IS_NULLABLE", "YES")
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
getRows(
|
getRows(
|
||||||
|
@ -529,7 +529,7 @@ public class DruidAvaticaHandlerTest
|
||||||
Pair.of("COLUMN_NAME", "unique_dim1"),
|
Pair.of("COLUMN_NAME", "unique_dim1"),
|
||||||
Pair.of("DATA_TYPE", Types.OTHER),
|
Pair.of("DATA_TYPE", Types.OTHER),
|
||||||
Pair.of("TYPE_NAME", "OTHER"),
|
Pair.of("TYPE_NAME", "OTHER"),
|
||||||
Pair.of("IS_NULLABLE", "NO")
|
Pair.of("IS_NULLABLE", "YES")
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
getRows(
|
getRows(
|
||||||
|
|
|
@ -407,7 +407,7 @@ public class CalciteQueryTest
|
||||||
new Object[]{"dim2", "VARCHAR", "YES"},
|
new Object[]{"dim2", "VARCHAR", "YES"},
|
||||||
new Object[]{"m1", "FLOAT", "NO"},
|
new Object[]{"m1", "FLOAT", "NO"},
|
||||||
new Object[]{"m2", "DOUBLE", "NO"},
|
new Object[]{"m2", "DOUBLE", "NO"},
|
||||||
new Object[]{"unique_dim1", "OTHER", "NO"}
|
new Object[]{"unique_dim1", "OTHER", "YES"}
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -437,12 +437,11 @@ public class CalciteQueryTest
|
||||||
new Object[]{"dim2", "VARCHAR", "YES"},
|
new Object[]{"dim2", "VARCHAR", "YES"},
|
||||||
new Object[]{"m1", "FLOAT", "NO"},
|
new Object[]{"m1", "FLOAT", "NO"},
|
||||||
new Object[]{"m2", "DOUBLE", "NO"},
|
new Object[]{"m2", "DOUBLE", "NO"},
|
||||||
new Object[]{"unique_dim1", "OTHER", "NO"}
|
new Object[]{"unique_dim1", "OTHER", "YES"}
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testInformationSchemaColumnsOnView() throws Exception
|
public void testInformationSchemaColumnsOnView() throws Exception
|
||||||
{
|
{
|
||||||
|
@ -2327,9 +2326,10 @@ public class CalciteQueryTest
|
||||||
+ "SUM(cnt) filter(WHERE dim2 = 'a'), "
|
+ "SUM(cnt) filter(WHERE dim2 = 'a'), "
|
||||||
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), "
|
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), "
|
||||||
+ "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), "
|
+ "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), "
|
||||||
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END) "
|
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END), "
|
||||||
|
+ "COUNT(DISTINCT CASE WHEN dim1 <> '1' THEN m1 END) "
|
||||||
+ "FROM druid.foo",
|
+ "FROM druid.foo",
|
||||||
ImmutableList.<Query>of(
|
ImmutableList.of(
|
||||||
Druids.newTimeseriesQueryBuilder()
|
Druids.newTimeseriesQueryBuilder()
|
||||||
.dataSource(CalciteTests.DATASOURCE1)
|
.dataSource(CalciteTests.DATASOURCE1)
|
||||||
.intervals(QSS(Filtration.eternity()))
|
.intervals(QSS(Filtration.eternity()))
|
||||||
|
@ -2380,13 +2380,23 @@ public class CalciteQueryTest
|
||||||
new FilteredAggregatorFactory(
|
new FilteredAggregatorFactory(
|
||||||
new LongMaxAggregatorFactory("a9", "cnt"),
|
new LongMaxAggregatorFactory("a9", "cnt"),
|
||||||
NOT(SELECTOR("dim1", "1", null))
|
NOT(SELECTOR("dim1", "1", null))
|
||||||
|
),
|
||||||
|
new FilteredAggregatorFactory(
|
||||||
|
new CardinalityAggregatorFactory(
|
||||||
|
"a10",
|
||||||
|
null,
|
||||||
|
DIMS(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
|
||||||
|
false,
|
||||||
|
true
|
||||||
|
),
|
||||||
|
NOT(SELECTOR("dim1", "1", null))
|
||||||
)
|
)
|
||||||
))
|
))
|
||||||
.context(TIMESERIES_CONTEXT_DEFAULT)
|
.context(TIMESERIES_CONTEXT_DEFAULT)
|
||||||
.build()
|
.build()
|
||||||
),
|
),
|
||||||
ImmutableList.of(
|
ImmutableList.of(
|
||||||
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L}
|
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L, 5L}
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -3455,6 +3465,57 @@ public class CalciteQueryTest
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCountDistinctOfCaseWhen() throws Exception
|
||||||
|
{
|
||||||
|
testQuery(
|
||||||
|
"SELECT\n"
|
||||||
|
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN m1 END),\n"
|
||||||
|
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN dim1 END),\n"
|
||||||
|
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN unique_dim1 END)\n"
|
||||||
|
+ "FROM druid.foo",
|
||||||
|
ImmutableList.of(
|
||||||
|
Druids.newTimeseriesQueryBuilder()
|
||||||
|
.dataSource(CalciteTests.DATASOURCE1)
|
||||||
|
.intervals(QSS(Filtration.eternity()))
|
||||||
|
.granularity(Granularities.ALL)
|
||||||
|
.aggregators(
|
||||||
|
AGGS(
|
||||||
|
new FilteredAggregatorFactory(
|
||||||
|
new CardinalityAggregatorFactory(
|
||||||
|
"a0",
|
||||||
|
null,
|
||||||
|
ImmutableList.of(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
|
||||||
|
false,
|
||||||
|
true
|
||||||
|
),
|
||||||
|
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
|
||||||
|
),
|
||||||
|
new FilteredAggregatorFactory(
|
||||||
|
new CardinalityAggregatorFactory(
|
||||||
|
"a1",
|
||||||
|
null,
|
||||||
|
ImmutableList.of(new DefaultDimensionSpec("dim1", "dim1", ValueType.STRING)),
|
||||||
|
false,
|
||||||
|
true
|
||||||
|
),
|
||||||
|
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
|
||||||
|
),
|
||||||
|
new FilteredAggregatorFactory(
|
||||||
|
new HyperUniquesAggregatorFactory("a2", "unique_dim1", false, true),
|
||||||
|
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.context(TIMESERIES_CONTEXT_DEFAULT)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
ImmutableList.of(
|
||||||
|
new Object[]{3L, 3L, 3L}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testExactCountDistinct() throws Exception
|
public void testExactCountDistinct() throws Exception
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue