mirror of https://github.com/apache/druid.git
Restore usage of filtered SUM (#17378)
This commit is contained in:
parent
05c3cbce08
commit
1a38434d8d
|
@ -625,6 +625,15 @@ public class QueryContext
|
|||
);
|
||||
}
|
||||
|
||||
public boolean isExtendedFilteredSumRewrite()
|
||||
{
|
||||
return getBoolean(
|
||||
QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED,
|
||||
QueryContexts.DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
public QueryResourceId getQueryResourceId()
|
||||
{
|
||||
return new QueryResourceId(getString(QueryContexts.QUERY_RESOURCE_ID));
|
||||
|
|
|
@ -89,6 +89,22 @@ public class QueryContexts
|
|||
public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit";
|
||||
public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
|
||||
public static final String CATALOG_VALIDATION_ENABLED = "catalogValidationEnabled";
|
||||
/**
|
||||
* Context parameter to enable/disable the extended filtered sum rewrite logic.
|
||||
*
|
||||
* Controls the rewrite of:
|
||||
* <pre>
|
||||
* SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
|
||||
* to
|
||||
* SUM(COL1) FILTER (WHERE COND)
|
||||
* </pre>
|
||||
* managed by {@link DruidAggregateCaseToFilterRule}. Defaults to true for performance,
|
||||
* but may produce incorrect results when the condition never matches (returns null where 0 is expected).
|
||||
* This is for testing and can be removed once a correct and high-performance rewrite
|
||||
* is implemented.
|
||||
*/
|
||||
public static final String EXTENDED_FILTERED_SUM_REWRITE_ENABLED = "extendedFilteredSumRewrite";
|
||||
|
||||
|
||||
// projection context keys
|
||||
public static final String NO_PROJECTIONS = "noProjections";
|
||||
|
@ -139,6 +155,7 @@ public class QueryContexts
|
|||
public static final boolean DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING = false;
|
||||
public static final boolean DEFAULT_CATALOG_VALIDATION_ENABLED = true;
|
||||
public static final boolean DEFAULT_USE_NESTED_FOR_UNKNOWN_TYPE_IN_SUBQUERY = false;
|
||||
public static final boolean DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED = true;
|
||||
|
||||
|
||||
@SuppressWarnings("unused") // Used by Jackson serialization
|
||||
|
|
|
@ -394,6 +394,17 @@ public class QueryContextTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExtendedFilteredSumRewrite()
|
||||
{
|
||||
assertTrue(QueryContext.empty().isExtendedFilteredSumRewrite());
|
||||
assertFalse(
|
||||
QueryContext
|
||||
.of(ImmutableMap.of(QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED, false))
|
||||
.isExtendedFilteredSumRewrite()
|
||||
);
|
||||
}
|
||||
|
||||
// This test is a bit silly. It is retained because another test uses the
|
||||
// LegacyContextQuery test.
|
||||
@Test
|
||||
|
|
|
@ -54,6 +54,7 @@ import org.apache.druid.sql.calcite.external.ExternalTableScanRule;
|
|||
import org.apache.druid.sql.calcite.rule.AggregatePullUpLookupRule;
|
||||
import org.apache.druid.sql.calcite.rule.CaseToCoalesceRule;
|
||||
import org.apache.druid.sql.calcite.rule.CoalesceLookupRule;
|
||||
import org.apache.druid.sql.calcite.rule.DruidAggregateCaseToFilterRule;
|
||||
import org.apache.druid.sql.calcite.rule.DruidLogicalValuesRule;
|
||||
import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule;
|
||||
import org.apache.druid.sql.calcite.rule.DruidRules;
|
||||
|
@ -119,7 +120,6 @@ public class CalciteRulesManager
|
|||
CoreRules.FILTER_PROJECT_TRANSPOSE,
|
||||
CoreRules.JOIN_PUSH_EXPRESSIONS,
|
||||
CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT,
|
||||
CoreRules.AGGREGATE_CASE_TO_FILTER,
|
||||
CoreRules.FILTER_AGGREGATE_TRANSPOSE,
|
||||
CoreRules.PROJECT_WINDOW_TRANSPOSE,
|
||||
CoreRules.MATCH,
|
||||
|
@ -495,6 +495,7 @@ public class CalciteRulesManager
|
|||
rules.addAll(BASE_RULES);
|
||||
rules.addAll(ABSTRACT_RULES);
|
||||
rules.addAll(ABSTRACT_RELATIONAL_RULES);
|
||||
rules.add(new DruidAggregateCaseToFilterRule(plannerContext.queryContext().isExtendedFilteredSumRewrite()));
|
||||
rules.addAll(configurableRuleSet(plannerContext));
|
||||
|
||||
if (plannerContext.getJoinAlgorithm().requiresSubquery()) {
|
||||
|
|
|
@ -0,0 +1,349 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.sql.calcite.rule;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.apache.calcite.plan.RelOptCluster;
|
||||
import org.apache.calcite.plan.RelOptRule;
|
||||
import org.apache.calcite.plan.RelOptRuleCall;
|
||||
import org.apache.calcite.rel.RelCollations;
|
||||
import org.apache.calcite.rel.core.Aggregate;
|
||||
import org.apache.calcite.rel.core.AggregateCall;
|
||||
import org.apache.calcite.rel.core.Project;
|
||||
import org.apache.calcite.rel.rules.AggregateCaseToFilterRule;
|
||||
import org.apache.calcite.rel.rules.SubstitutionRule;
|
||||
import org.apache.calcite.rel.type.RelDataType;
|
||||
import org.apache.calcite.rel.type.RelDataTypeFactory;
|
||||
import org.apache.calcite.rex.RexBuilder;
|
||||
import org.apache.calcite.rex.RexCall;
|
||||
import org.apache.calcite.rex.RexLiteral;
|
||||
import org.apache.calcite.rex.RexNode;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.SqlPostfixOperator;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.tools.RelBuilder;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Druid extension of {@link AggregateCaseToFilterRule}.
|
||||
*
|
||||
* Turning on extendedFilteredSumRewrite enables rewrites of:
|
||||
* <pre>
|
||||
* SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
|
||||
* </pre>
|
||||
* to:
|
||||
* <pre>
|
||||
* SUM(COL1) FILTER (WHERE COND)
|
||||
* </pre>
|
||||
* <p>
|
||||
* This rewrite improves performance but introduces a known inconsistency when
|
||||
* the condition never matches, as the expected result (0) is replaced with `null`.
|
||||
* <p>
|
||||
* Example behavior:
|
||||
* <pre>
|
||||
* +-----------------+--------------+----------+------+--------------+
|
||||
* | input row count | cond matches | valueCol | orig | filtered-SUM |
|
||||
* +-----------------+--------------+----------+------+--------------+
|
||||
* | 0 | * | * | null | null |
|
||||
* | >0 | none | * | 0 | null |
|
||||
* | >0 | all | null | null | null |
|
||||
* | >0 | N>0 | 1 | N | N |
|
||||
* +-----------------+--------------+----------+------+--------------+
|
||||
* </pre>
|
||||
*/
|
||||
public class DruidAggregateCaseToFilterRule extends RelOptRule implements SubstitutionRule
|
||||
{
|
||||
private boolean extendedFilteredSumRewrite;
|
||||
|
||||
public DruidAggregateCaseToFilterRule(boolean extendedFilteredSumRewrite)
|
||||
{
|
||||
super(operand(Aggregate.class, operand(Project.class, any())));
|
||||
this.extendedFilteredSumRewrite = extendedFilteredSumRewrite;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean matches(final RelOptRuleCall call)
|
||||
{
|
||||
final Aggregate aggregate = call.rel(0);
|
||||
final Project project = call.rel(1);
|
||||
|
||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||
final int singleArg = soleArgument(aggregateCall);
|
||||
if (singleArg >= 0
|
||||
&& isThreeArgCase(project.getProjects().get(singleArg))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMatch(RelOptRuleCall call)
|
||||
{
|
||||
final Aggregate aggregate = call.rel(0);
|
||||
final Project project = call.rel(1);
|
||||
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
|
||||
final List<RexNode> newProjects = new ArrayList<>(project.getProjects());
|
||||
|
||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||
AggregateCall newCall = transform(aggregateCall, project, newProjects);
|
||||
|
||||
if (newCall == null) {
|
||||
newCalls.add(aggregateCall);
|
||||
} else {
|
||||
newCalls.add(newCall);
|
||||
}
|
||||
}
|
||||
|
||||
if (newCalls.equals(aggregate.getAggCallList())) {
|
||||
return;
|
||||
}
|
||||
|
||||
final RelBuilder relBuilder = call.builder()
|
||||
.push(project.getInput())
|
||||
.project(newProjects);
|
||||
|
||||
final RelBuilder.GroupKey groupKey = relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
|
||||
|
||||
relBuilder.aggregate(groupKey, newCalls)
|
||||
.convert(aggregate.getRowType(), false);
|
||||
|
||||
call.transformTo(relBuilder.build());
|
||||
call.getPlanner().prune(aggregate);
|
||||
}
|
||||
|
||||
private @Nullable AggregateCall transform(AggregateCall call,
|
||||
Project project, List<RexNode> newProjects)
|
||||
{
|
||||
final int singleArg = soleArgument(call);
|
||||
if (singleArg < 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
final RexNode rexNode = project.getProjects().get(singleArg);
|
||||
if (!isThreeArgCase(rexNode)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
final RelOptCluster cluster = project.getCluster();
|
||||
final RexBuilder rexBuilder = cluster.getRexBuilder();
|
||||
final RexCall caseCall = (RexCall) rexNode;
|
||||
|
||||
// If one arg is null and the other is not, reverse them and set "flip",
|
||||
// which negates the filter.
|
||||
final boolean flip = RexLiteral.isNullLiteral(caseCall.operands.get(1))
|
||||
&& !RexLiteral.isNullLiteral(caseCall.operands.get(2));
|
||||
final RexNode arg1 = caseCall.operands.get(flip ? 2 : 1);
|
||||
final RexNode arg2 = caseCall.operands.get(flip ? 1 : 2);
|
||||
|
||||
// Operand 1: Filter
|
||||
final SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_NOT_TRUE : SqlStdOperatorTable.IS_TRUE;
|
||||
final RexNode filterFromCase = rexBuilder.makeCall(op, caseCall.operands.get(0));
|
||||
|
||||
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the
|
||||
// latter is present.
|
||||
final RexNode filter;
|
||||
if (call.filterArg >= 0) {
|
||||
filter = rexBuilder.makeCall(
|
||||
SqlStdOperatorTable.AND,
|
||||
project.getProjects().get(call.filterArg),
|
||||
filterFromCase
|
||||
);
|
||||
} else {
|
||||
filter = filterFromCase;
|
||||
}
|
||||
|
||||
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
|
||||
final SqlKind kind = call.getAggregation().getKind();
|
||||
if (call.isDistinct()) {
|
||||
// Just one style supported:
|
||||
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END)
|
||||
// =>
|
||||
// COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
|
||||
|
||||
if (kind == SqlKind.COUNT
|
||||
&& RexLiteral.isNullLiteral(arg2)) {
|
||||
newProjects.add(arg1);
|
||||
newProjects.add(filter);
|
||||
return AggregateCall.create(
|
||||
SqlStdOperatorTable.COUNT,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
call.rexList,
|
||||
ImmutableList.of(newProjects.size() - 2),
|
||||
newProjects.size() - 1,
|
||||
null,
|
||||
RelCollations.EMPTY,
|
||||
call.getType(),
|
||||
call.getName()
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// Four styles supported:
|
||||
//
|
||||
// A1: AGG(CASE WHEN x = 'foo' THEN expr END)
|
||||
// => AGG(expr) FILTER (x = 'foo')
|
||||
// A2: SUM0(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
|
||||
// => SUM0(cnt) FILTER (x = 'foo')
|
||||
// B: SUM0(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
|
||||
// => COUNT() FILTER (x = 'foo')
|
||||
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END)
|
||||
// => COUNT() FILTER (x = 'foo')
|
||||
|
||||
if (kind == SqlKind.COUNT // Case C
|
||||
&& arg1.isA(SqlKind.LITERAL)
|
||||
&& !RexLiteral.isNullLiteral(arg1)
|
||||
&& RexLiteral.isNullLiteral(arg2)) {
|
||||
newProjects.add(filter);
|
||||
return AggregateCall.create(
|
||||
SqlStdOperatorTable.COUNT,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
call.rexList, ImmutableList.of(), newProjects.size() - 1, null,
|
||||
RelCollations.EMPTY, call.getType(),
|
||||
call.getName());
|
||||
} else if (kind == SqlKind.SUM0 // Case B
|
||||
&& isIntLiteral(arg1, BigDecimal.ONE)
|
||||
&& isIntLiteral(arg2, BigDecimal.ZERO)) {
|
||||
|
||||
newProjects.add(filter);
|
||||
final RelDataType dataType = typeFactory
|
||||
.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), false);
|
||||
return AggregateCall.create(
|
||||
SqlStdOperatorTable.COUNT,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
call.rexList,
|
||||
ImmutableList.of(),
|
||||
newProjects.size() - 1,
|
||||
null,
|
||||
RelCollations.EMPTY,
|
||||
dataType,
|
||||
call.getName()
|
||||
);
|
||||
} else if ((RexLiteral.isNullLiteral(arg2) // Case A1
|
||||
&& call.getAggregation().allowsFilter())
|
||||
|| (kind == SqlKind.SUM0 // Case A2
|
||||
&& isIntLiteral(arg2, BigDecimal.ZERO))) {
|
||||
newProjects.add(arg1);
|
||||
newProjects.add(filter);
|
||||
return AggregateCall.create(
|
||||
call.getAggregation(),
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
call.rexList,
|
||||
ImmutableList.of(newProjects.size() - 2),
|
||||
newProjects.size() - 1,
|
||||
null,
|
||||
RelCollations.EMPTY,
|
||||
call.getType(),
|
||||
call.getName()
|
||||
);
|
||||
}
|
||||
|
||||
// Rewrites
|
||||
// D1: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
|
||||
// => SUM0(cnt) FILTER (x = 'foo')
|
||||
// D2: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
|
||||
// => COUNT() FILTER (x = 'foo')
|
||||
//
|
||||
// https://issues.apache.org/jira/browse/CALCITE-5953
|
||||
// have restricted this rewrite as in case there are no rows it may not be equvivalent;
|
||||
// however it may have some performance impact in Druid
|
||||
if (extendedFilteredSumRewrite &&
|
||||
kind == SqlKind.SUM && isIntLiteral(arg2, BigDecimal.ZERO)) {
|
||||
if (isIntLiteral(arg1, BigDecimal.ONE)) { // D2
|
||||
newProjects.add(filter);
|
||||
final RelDataType dataType = typeFactory.createTypeWithNullability(
|
||||
typeFactory.createSqlType(SqlTypeName.BIGINT), false
|
||||
);
|
||||
return AggregateCall.create(
|
||||
SqlStdOperatorTable.COUNT,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
call.rexList,
|
||||
ImmutableList.of(),
|
||||
newProjects.size() - 1,
|
||||
null,
|
||||
RelCollations.EMPTY,
|
||||
dataType,
|
||||
call.getName()
|
||||
);
|
||||
|
||||
} else { // D1
|
||||
newProjects.add(arg1);
|
||||
newProjects.add(filter);
|
||||
|
||||
RelDataType newType = typeFactory.createTypeWithNullability(call.getType(), true);
|
||||
return AggregateCall.create(
|
||||
call.getAggregation(),
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
call.rexList,
|
||||
ImmutableList.of(newProjects.size() - 2),
|
||||
newProjects.size() - 1,
|
||||
null,
|
||||
RelCollations.EMPTY,
|
||||
newType,
|
||||
call.getName()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the argument, if an aggregate call has a single argument, otherwise
|
||||
* -1.
|
||||
*/
|
||||
private static int soleArgument(AggregateCall aggregateCall)
|
||||
{
|
||||
return aggregateCall.getArgList().size() == 1
|
||||
? aggregateCall.getArgList().get(0)
|
||||
: -1;
|
||||
}
|
||||
|
||||
private static boolean isThreeArgCase(final RexNode rexNode)
|
||||
{
|
||||
return rexNode.getKind() == SqlKind.CASE
|
||||
&& ((RexCall) rexNode).operands.size() == 3;
|
||||
}
|
||||
|
||||
private static boolean isIntLiteral(RexNode rexNode, BigDecimal value)
|
||||
{
|
||||
return rexNode instanceof RexLiteral
|
||||
&& SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())
|
||||
&& value.equals(((RexLiteral) rexNode).getValueAs(BigDecimal.class));
|
||||
}
|
||||
}
|
|
@ -5188,7 +5188,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
@Test
|
||||
public void testFilteredAggregations()
|
||||
{
|
||||
cannotVectorizeUnlessFallback();
|
||||
Druids.TimeseriesQueryBuilder builder =
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
|
@ -5196,18 +5195,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.granularity(Granularities.ALL)
|
||||
.context(QUERY_CONTEXT_DEFAULT);
|
||||
if (NullHandling.sqlCompatible()) {
|
||||
cannotVectorizeUnlessFallback();
|
||||
builder = builder.virtualColumns(
|
||||
expressionVirtualColumn("v0", "substring(\"dim1\", 0, 1)", ColumnType.STRING),
|
||||
expressionVirtualColumn(
|
||||
"v1",
|
||||
"case_searched((\"dim1\" != '1'),1,0)",
|
||||
ColumnType.LONG
|
||||
),
|
||||
expressionVirtualColumn(
|
||||
"v2",
|
||||
"case_searched((\"dim1\" != '1'),\"cnt\",0)",
|
||||
ColumnType.LONG
|
||||
)
|
||||
expressionVirtualColumn("v0", "substring(\"dim1\", 0, 1)", ColumnType.STRING)
|
||||
)
|
||||
.aggregators(
|
||||
aggregators(
|
||||
|
@ -5234,7 +5224,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new CountAggregatorFactory("a4"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new LongSumAggregatorFactory("a5", "v1"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a5"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongSumAggregatorFactory("a6", "cnt"),
|
||||
equality("dim2", "a", ColumnType.STRING)
|
||||
|
@ -5246,7 +5239,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
)
|
||||
),
|
||||
new LongSumAggregatorFactory("a8", "v2"),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongSumAggregatorFactory("a8", "cnt"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongMaxAggregatorFactory("a9", "cnt"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
|
@ -5272,16 +5268,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
);
|
||||
} else {
|
||||
builder = builder.virtualColumns(
|
||||
expressionVirtualColumn(
|
||||
"v0",
|
||||
"case_searched((\"dim1\" != '1'),1,0)",
|
||||
ColumnType.LONG
|
||||
),
|
||||
expressionVirtualColumn(
|
||||
"v1",
|
||||
"case_searched((\"dim1\" != '1'),\"cnt\",0)",
|
||||
ColumnType.LONG
|
||||
))
|
||||
)
|
||||
.aggregators(
|
||||
aggregators(
|
||||
new FilteredAggregatorFactory(
|
||||
|
@ -5307,7 +5294,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new CountAggregatorFactory("a4"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new LongSumAggregatorFactory("a5", "v0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a5"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongSumAggregatorFactory("a6", "cnt"),
|
||||
equality("dim2", "a", ColumnType.STRING)
|
||||
|
@ -5319,7 +5309,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
)
|
||||
),
|
||||
new LongSumAggregatorFactory("a8", "v1"),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongSumAggregatorFactory("a8", "cnt"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongMaxAggregatorFactory("a9", "cnt"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
|
@ -5373,7 +5366,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
@Test
|
||||
public void testCaseFilteredAggregationWithGroupBy()
|
||||
{
|
||||
cannotVectorizeUnlessFallback();
|
||||
testQuery(
|
||||
"SELECT\n"
|
||||
+ " cnt,\n"
|
||||
|
@ -5386,15 +5378,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(dimensions(new DefaultDimensionSpec("cnt", "d0", ColumnType.LONG)))
|
||||
.setVirtualColumns(
|
||||
expressionVirtualColumn(
|
||||
"v0",
|
||||
"case_searched((\"dim1\" != '1'),1,0)",
|
||||
ColumnType.LONG
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new LongSumAggregatorFactory("a0", "v0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new LongSumAggregatorFactory("a1", "cnt")
|
||||
))
|
||||
.setPostAggregatorSpecs(
|
||||
|
@ -5409,6 +5397,52 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCaseFilteredAggregationWithGroupRewriteToSum()
|
||||
{
|
||||
testBuilder()
|
||||
.sql(
|
||||
"SELECT\n"
|
||||
+ " cnt,\n"
|
||||
+ " SUM(CASE WHEN dim1 <> '1' THEN 2 ELSE 0 END) + SUM(cnt)\n"
|
||||
+ "FROM druid.foo\n"
|
||||
+ "GROUP BY cnt"
|
||||
)
|
||||
.expectedQueries(
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(dimensions(new DefaultDimensionSpec("cnt", "d0", ColumnType.LONG)))
|
||||
.setVirtualColumns(
|
||||
expressionVirtualColumn("v0", "2", ColumnType.LONG)
|
||||
)
|
||||
.setAggregatorSpecs(
|
||||
aggregators(
|
||||
new FilteredAggregatorFactory(
|
||||
new LongSumAggregatorFactory("a0", "v0"),
|
||||
not(equality("dim1", "1", ColumnType.STRING))
|
||||
),
|
||||
new LongSumAggregatorFactory("a1", "cnt")
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
expressionPostAgg("p0", "(\"a0\" + \"a1\")", ColumnType.LONG)
|
||||
)
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
)
|
||||
)
|
||||
.expectedResults(
|
||||
ImmutableList.of(
|
||||
new Object[] {1L, 16L}
|
||||
)
|
||||
)
|
||||
.run();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testFilteredAggregationWithNotIn()
|
||||
{
|
||||
|
@ -9479,7 +9513,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
@Test
|
||||
public void testQueryWithSelectProjectAndIdentityProjectDoesNotRename()
|
||||
{
|
||||
cannotVectorizeUnlessFallback();
|
||||
msqIncompatible();
|
||||
testQuery(
|
||||
PLANNER_CONFIG_NO_HLL.withOverrides(
|
||||
|
@ -9506,25 +9539,30 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
"v0",
|
||||
"((\"__time\" >= 947005200000) && (\"__time\" < 1641402000000))",
|
||||
ColumnType.LONG
|
||||
),
|
||||
expressionVirtualColumn(
|
||||
"v1",
|
||||
"case_searched(((\"__time\" >= 947005200000) && (\"__time\" < 1641402000000)),1,0)",
|
||||
ColumnType.LONG
|
||||
)
|
||||
)
|
||||
.setDimensions(
|
||||
dimensions(
|
||||
new DefaultDimensionSpec("dim1", "d0", ColumnType.STRING),
|
||||
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
|
||||
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
|
||||
new DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(
|
||||
aggregators(
|
||||
new LongSumAggregatorFactory("a0", "v1"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0"),
|
||||
range(
|
||||
"__time",
|
||||
ColumnType.LONG,
|
||||
timestamp("2000-01-04T17:00:00"),
|
||||
timestamp("2022-01-05T17:00:00"),
|
||||
false,
|
||||
true
|
||||
)
|
||||
),
|
||||
new GroupingAggregatorFactory(
|
||||
"a1",
|
||||
ImmutableList.of("dim1", "v0")
|
||||
ImmutableList.of("v0", "dim1")
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -9549,9 +9587,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a1"),
|
||||
and(
|
||||
notNull("d0"),
|
||||
notNull("d1"),
|
||||
equality("a1", 0L, ColumnType.LONG),
|
||||
expressionFilter("\"d1\"")
|
||||
expressionFilter("\"d0\"")
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
!use druidtest://?numMergeBuffers=3
|
||||
!set outputformat mysql
|
||||
|
||||
-- empty input
|
||||
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < -1;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 0 | 0 | true | | |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
-- 0=-1,0
|
||||
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < 3;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 0 | 1 | false | | |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
|
||||
|
||||
-- 0=0,0
|
||||
SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0) FROM numfoo where l1 < 3;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 1 | 0 | false | 0 | 0 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
|
||||
-- 7=7,null
|
||||
SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7) FROM numfoo where 0 < l1 and l1 < 10;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 1 | 0 | true | | |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
|
||||
LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3], EXPR$4=[$4])
|
||||
LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT() FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
|
||||
LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1], $f4=[CASE(=($0, 7), $1, 0:BIGINT)])
|
||||
LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
|
||||
LogicalProject(l1=[$11], l2=[$12])
|
||||
LogicalTableScan(table=[[druid, numfoo]])
|
||||
|
||||
!druidPlan
|
||||
|
||||
!set extendedFilteredSumRewrite false
|
||||
!use druidtest://?numMergeBuffers=3
|
||||
|
||||
|
||||
-- empty input
|
||||
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < -1;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 0 | 0 | true | | |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
-- 0=-1,0
|
||||
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < 3;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 0 | 1 | false | 0 | |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
|
||||
|
||||
-- 0=0,0
|
||||
SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0) FROM numfoo where l1 < 3;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 1 | 0 | false | 0 | 0 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
|
||||
-- 7=7,null
|
||||
SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
|
||||
SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7) FROM numfoo where 0 < l1 and l1 < 10;
|
||||
+--------+--------+--------+--------+--------+
|
||||
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
|
||||
+--------+--------+--------+--------+--------+
|
||||
| 1 | 0 | true | | |
|
||||
+--------+--------+--------+--------+--------+
|
||||
(1 row)
|
||||
|
||||
!ok
|
||||
|
||||
LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3], EXPR$4=[$4])
|
||||
LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT() FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
|
||||
LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1], $f4=[CASE(=($0, 7), $1, 0:BIGINT)])
|
||||
LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
|
||||
LogicalProject(l1=[$11], l2=[$12])
|
||||
LogicalTableScan(table=[[druid, numfoo]])
|
||||
|
||||
!druidPlan
|
Loading…
Reference in New Issue