mirror of https://github.com/apache/druid.git
VirtualColumnRegistry reuse virtual column should take account of value type (#11546)
Co-authored-by: huangqixiang.871 <huangqixiang.871@bytedance.com>
This commit is contained in:
parent
ce4dd48bb8
commit
38ebaee0fd
|
@ -24,6 +24,7 @@ import com.google.common.collect.Iterables;
|
|||
import org.apache.calcite.rel.core.AggregateCall;
|
||||
import org.apache.calcite.rel.core.Project;
|
||||
import org.apache.calcite.rex.RexBuilder;
|
||||
import org.apache.calcite.rex.RexNode;
|
||||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
|
@ -38,6 +39,7 @@ import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
|||
import org.apache.druid.sql.calcite.aggregation.Aggregations;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.Expressions;
|
||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
@ -109,7 +111,12 @@ public class AvgSqlAggregator implements SqlAggregator
|
|||
expression = null;
|
||||
} else {
|
||||
// if the filter or anywhere else defined a virtual column for us, re-use it
|
||||
VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
|
||||
final RexNode resolutionArg = Expressions.fromFieldAccess(
|
||||
rowSignature,
|
||||
project,
|
||||
Iterables.getOnlyElement(aggregateCall.getArgList())
|
||||
);
|
||||
VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression(), resolutionArg.getType());
|
||||
fieldName = vc != null ? vc.getOutputName() : null;
|
||||
expression = vc != null ? null : arg.getExpression();
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ import javax.annotation.Nullable;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
|
@ -40,7 +41,7 @@ import java.util.stream.Collectors;
|
|||
public class VirtualColumnRegistry
|
||||
{
|
||||
private final RowSignature baseRowSignature;
|
||||
private final Map<String, VirtualColumn> virtualColumnsByExpression;
|
||||
private final Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression;
|
||||
private final Map<String, VirtualColumn> virtualColumnsByName;
|
||||
private final String virtualColumnPrefix;
|
||||
private int virtualColumnCounter;
|
||||
|
@ -48,7 +49,7 @@ public class VirtualColumnRegistry
|
|||
private VirtualColumnRegistry(
|
||||
RowSignature baseRowSignature,
|
||||
String virtualColumnPrefix,
|
||||
Map<String, VirtualColumn> virtualColumnsByExpression,
|
||||
Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression,
|
||||
Map<String, VirtualColumn> virtualColumnsByName
|
||||
)
|
||||
{
|
||||
|
@ -85,7 +86,8 @@ public class VirtualColumnRegistry
|
|||
ValueType valueType
|
||||
)
|
||||
{
|
||||
if (!virtualColumnsByExpression.containsKey(expression.getExpression())) {
|
||||
ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression.getExpression(), valueType);
|
||||
if (!virtualColumnsByExpression.containsKey(expressionWrapper)) {
|
||||
final String virtualColumnName = virtualColumnPrefix + virtualColumnCounter++;
|
||||
final VirtualColumn virtualColumn = expression.toVirtualColumn(
|
||||
virtualColumnName,
|
||||
|
@ -93,7 +95,7 @@ public class VirtualColumnRegistry
|
|||
plannerContext.getExprMacroTable()
|
||||
);
|
||||
virtualColumnsByExpression.put(
|
||||
expression.getExpression(),
|
||||
expressionWrapper,
|
||||
virtualColumn
|
||||
);
|
||||
virtualColumnsByName.put(
|
||||
|
@ -102,7 +104,7 @@ public class VirtualColumnRegistry
|
|||
);
|
||||
}
|
||||
|
||||
return virtualColumnsByExpression.get(expression.getExpression());
|
||||
return virtualColumnsByExpression.get(expressionWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -131,9 +133,10 @@ public class VirtualColumnRegistry
|
|||
}
|
||||
|
||||
@Nullable
|
||||
public VirtualColumn getVirtualColumnByExpression(String expression)
|
||||
public VirtualColumn getVirtualColumnByExpression(String expression, RelDataType type)
|
||||
{
|
||||
return virtualColumnsByExpression.get(expression);
|
||||
ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression, Calcites.getValueTypeForRelDataType(type));
|
||||
return virtualColumnsByExpression.get(expressionWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -164,4 +167,35 @@ public class VirtualColumnRegistry
|
|||
.map(this::getVirtualColumn)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static class ExpressionWrapper
|
||||
{
|
||||
private final String expression;
|
||||
private final ValueType valueType;
|
||||
|
||||
public ExpressionWrapper(String expression, ValueType valueType)
|
||||
{
|
||||
this.expression = expression;
|
||||
this.valueType = valueType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o)
|
||||
{
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
ExpressionWrapper expressionWrapper = (ExpressionWrapper) o;
|
||||
return Objects.equals(expression, expressionWrapper.expression) && valueType == expressionWrapper.valueType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode()
|
||||
{
|
||||
return Objects.hash(expression, valueType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18818,4 +18818,44 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
Collections.emptyList()
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCommonVirtualExpressionWithDifferentValueType() throws Exception
|
||||
{
|
||||
testQuery(
|
||||
"select\n"
|
||||
+ " dim1,\n"
|
||||
+ " sum(cast(0 as bigint)) as s1,\n"
|
||||
+ " sum(cast(0 as double)) as s2\n"
|
||||
+ "from druid.foo\n"
|
||||
+ "where dim1 = 'none'\n"
|
||||
+ "group by dim1\n"
|
||||
+ "limit 1",
|
||||
ImmutableList.of(new TopNQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(querySegmentSpec(Filtration.eternity()))
|
||||
.filters(selector("dim1", "none", null))
|
||||
.granularity(Granularities.ALL)
|
||||
.virtualColumns(
|
||||
expressionVirtualColumn(
|
||||
"v0",
|
||||
"'none'",
|
||||
ValueType.STRING
|
||||
)
|
||||
)
|
||||
.dimension(
|
||||
new DefaultDimensionSpec("v0", "d0")
|
||||
)
|
||||
.aggregators(
|
||||
aggregators(
|
||||
new LongSumAggregatorFactory("a0", null, "0", ExprMacroTable.nil()),
|
||||
new DoubleSumAggregatorFactory("a1", null, "0", ExprMacroTable.nil())
|
||||
))
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
.metric(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))
|
||||
.threshold(1)
|
||||
.build()),
|
||||
ImmutableList.of()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue