SQL: Support for another form of filtered aggregator. (#4109)

* SQL: Support for another form of filtered aggregator.

* Fix comment, add test for MAX too.
This commit is contained in:
Gian Merlino 2017-03-27 15:22:36 -07:00 committed by Jonathan Wei
parent 73d9b31664
commit bbb61e638b
3 changed files with 35 additions and 9 deletions

View File

@ -25,6 +25,8 @@ import io.druid.segment.column.ValueType;
import io.druid.sql.calcite.schema.DruidSchema;
import io.druid.sql.calcite.schema.InformationSchema;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.type.SqlTypeName;
@ -181,4 +183,17 @@ public class Calcites
{
return new DateTime(0L, DateTimeZone.UTC).plusDays(date).withZoneRetainFields(timeZone);
}
/**
* Checks if a RexNode is a literal int or not. If this returns true, then {@code RexLiteral.intValue(literal)} can be
* used to get the value of the literal.
*
* @param rexNode the node
*
* @return true if this is an int
*/
public static boolean isIntLiteral(final RexNode rexNode)
{
return rexNode instanceof RexLiteral && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName());
}
}

View File

@ -807,7 +807,8 @@ public class GroupByRules
input = foe;
} else if (rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).getOperands().size() == 3) {
// Possibly a CASE-style filtered aggregation. Styles supported:
// A: SUM(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0)
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
// If the null and non-null args are switched, "flip" is set, which negates the filter.
@ -839,15 +840,15 @@ public class GroupByRules
forceCount = true;
input = null;
} else if (call.getAggregation().getKind() == SqlKind.SUM
&& arg1 instanceof RexLiteral
&& ((Number) RexLiteral.value(arg1)).intValue() == 1
&& arg2 instanceof RexLiteral
&& ((Number) RexLiteral.value(arg2)).intValue() == 0) {
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
// Case B
forceCount = true;
input = null;
} else if (RexLiteral.isNullLiteral(arg2)) {
// Maybe case A
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|| (kind == SqlKind.SUM
&& Calcites.isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
input = FieldOrExpression.fromRexNode(operatorTable, plannerContext, rowOrder, arg1);
if (input == null) {
return null;

View File

@ -1432,7 +1432,9 @@ public class CalciteQueryTest
+ "COUNT(CASE WHEN dim1 <> '1' THEN 'dummy' END), "
+ "SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END), "
+ "SUM(cnt) filter(WHERE dim2 = 'a'), "
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a') "
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), "
+ "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), "
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END) "
+ "FROM druid.foo",
ImmutableList.<Query>of(
Druids.newTimeseriesQueryBuilder()
@ -1474,13 +1476,21 @@ public class CalciteQueryTest
SELECTOR("dim2", "a", null),
NOT(SELECTOR("dim1", "1", null))
)
),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a8", "cnt"),
NOT(SELECTOR("dim1", "1", null))
),
new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"),
NOT(SELECTOR("dim1", "1", null))
)
))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L}
new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L, 5L, 1L}
)
);
}