VirtualColumnRegistry: reuse of a virtual column should take its value type into account (#11546)

Co-authored-by: huangqixiang.871 <huangqixiang.871@bytedance.com>
This commit is contained in:
hqx871 2021-08-19 16:46:27 +08:00 committed by GitHub
parent ce4dd48bb8
commit 38ebaee0fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 8 deletions

View File

@ -24,6 +24,7 @@ import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
@ -38,6 +39,7 @@ import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@ -109,7 +111,12 @@ public class AvgSqlAggregator implements SqlAggregator
expression = null;
} else {
// if the filter or anywhere else defined a virtual column for us, re-use it
VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
final RexNode resolutionArg = Expressions.fromFieldAccess(
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression(), resolutionArg.getType());
fieldName = vc != null ? vc.getOutputName() : null;
expression = vc != null ? null : arg.getExpression();
}

View File

@ -31,6 +31,7 @@ import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
@ -40,7 +41,7 @@ import java.util.stream.Collectors;
public class VirtualColumnRegistry
{
private final RowSignature baseRowSignature;
private final Map<String, VirtualColumn> virtualColumnsByExpression;
private final Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression;
private final Map<String, VirtualColumn> virtualColumnsByName;
private final String virtualColumnPrefix;
private int virtualColumnCounter;
@ -48,7 +49,7 @@ public class VirtualColumnRegistry
private VirtualColumnRegistry(
RowSignature baseRowSignature,
String virtualColumnPrefix,
Map<String, VirtualColumn> virtualColumnsByExpression,
Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression,
Map<String, VirtualColumn> virtualColumnsByName
)
{
@ -85,7 +86,8 @@ public class VirtualColumnRegistry
ValueType valueType
)
{
if (!virtualColumnsByExpression.containsKey(expression.getExpression())) {
ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression.getExpression(), valueType);
if (!virtualColumnsByExpression.containsKey(expressionWrapper)) {
final String virtualColumnName = virtualColumnPrefix + virtualColumnCounter++;
final VirtualColumn virtualColumn = expression.toVirtualColumn(
virtualColumnName,
@ -93,7 +95,7 @@ public class VirtualColumnRegistry
plannerContext.getExprMacroTable()
);
virtualColumnsByExpression.put(
expression.getExpression(),
expressionWrapper,
virtualColumn
);
virtualColumnsByName.put(
@ -102,7 +104,7 @@ public class VirtualColumnRegistry
);
}
return virtualColumnsByExpression.get(expression.getExpression());
return virtualColumnsByExpression.get(expressionWrapper);
}
/**
@ -131,9 +133,10 @@ public class VirtualColumnRegistry
}
@Nullable
public VirtualColumn getVirtualColumnByExpression(String expression)
public VirtualColumn getVirtualColumnByExpression(String expression, RelDataType type)
{
return virtualColumnsByExpression.get(expression);
ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression, Calcites.getValueTypeForRelDataType(type));
return virtualColumnsByExpression.get(expressionWrapper);
}
/**
@ -164,4 +167,35 @@ public class VirtualColumnRegistry
.map(this::getVirtualColumn)
.collect(Collectors.toList());
}
/**
 * Lookup key that pairs an expression string with a {@link ValueType}. Two keys are equal only when both the
 * expression text and the value type match, so the same expression text registered under different value types
 * maps to different entries.
 */
private static class ExpressionWrapper
{
  private final String expression;
  private final ValueType valueType;

  /**
   * @param expression expression text forming the first half of the key
   * @param valueType  value type forming the second half of the key
   */
  public ExpressionWrapper(String expression, ValueType valueType)
  {
    this.expression = expression;
    this.valueType = valueType;
  }

  /**
   * Equal when {@code o} is an instance of exactly this class with the same expression text and the same value type.
   */
  @Override
  public boolean equals(Object o)
  {
    if (o == this) {
      return true;
    }
    if (o == null) {
      return false;
    }
    if (getClass() != o.getClass()) {
      return false;
    }
    final ExpressionWrapper that = (ExpressionWrapper) o;
    if (valueType != that.valueType) {
      return false;
    }
    return Objects.equals(expression, that.expression);
  }

  /**
   * Combines the hashes of both key parts; a null part contributes 0.
   */
  @Override
  public int hashCode()
  {
    // Standard 31-multiplier combination, seeded with 1, over (expression, valueType) in that order.
    int result = 31 + Objects.hashCode(expression);
    result = 31 * result + Objects.hashCode(valueType);
    return result;
  }
}
}

View File

@ -18818,4 +18818,44 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
Collections.emptyList()
);
}
/**
 * Sums the same literal expression once cast to bigint and once cast to double, and expects a TopN query in which
 * each sum is planned as its own aggregator over the expression {@code "0"}: a long sum ({@code a0}) and a double
 * sum ({@code a1}). The expected result set is empty.
 */
@Test
public void testCommonVirtualExpressionWithDifferentValueType() throws Exception
{
  final String sql = "select\n"
                     + " dim1,\n"
                     + " sum(cast(0 as bigint)) as s1,\n"
                     + " sum(cast(0 as double)) as s2\n"
                     + "from druid.foo\n"
                     + "where dim1 = 'none'\n"
                     + "group by dim1\n"
                     + "limit 1";

  testQuery(
      sql,
      ImmutableList.of(
          new TopNQueryBuilder()
              .dataSource(CalciteTests.DATASOURCE1)
              .intervals(querySegmentSpec(Filtration.eternity()))
              .filters(selector("dim1", "none", null))
              .granularity(Granularities.ALL)
              // The grouping dimension is expected as a constant string virtual column.
              .virtualColumns(expressionVirtualColumn("v0", "'none'", ValueType.STRING))
              .dimension(new DefaultDimensionSpec("v0", "d0"))
              // Same expression text "0" for both sums, but one long and one double aggregator.
              .aggregators(
                  aggregators(
                      new LongSumAggregatorFactory("a0", null, "0", ExprMacroTable.nil()),
                      new DoubleSumAggregatorFactory("a1", null, "0", ExprMacroTable.nil())
                  )
              )
              .context(QUERY_CONTEXT_DEFAULT)
              .metric(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))
              .threshold(1)
              .build()
      ),
      ImmutableList.of()
  );
}
}