use virtual columns for sql simple aggregators instead of inline expressions (#12251)

* use virtual columns for sql simple aggregators instead of inline expressions

* fixes

* always use virtual columns

* add more tests
This commit is contained in:
Clint Wylie 2022-03-03 15:05:28 -08:00 committed by GitHub
parent 36193955b6
commit 1c004ea47e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 337 additions and 118 deletions

View File

@ -72,6 +72,7 @@ import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
@ -12858,6 +12859,165 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
}
@Test
public void testGroupByFloatMaxExpressionVsVirtualColumn()
{
GroupByQuery query = makeQueryBuilder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("nil", "nil", ColumnType.STRING))
.setVirtualColumns(
new ExpressionVirtualColumn(
"v0",
"\"floatNumericNull\"",
ColumnType.FLOAT,
TestExprMacroTable.INSTANCE
)
)
.setAggregatorSpecs(
new FloatMinAggregatorFactory(
"min",
"floatNumericNull"
),
new FloatMinAggregatorFactory(
"minExpression",
null,
"\"floatNumericNull\"",
TestExprMacroTable.INSTANCE
),
new FloatMinAggregatorFactory(
"minVc",
"v0"
)
)
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();
List<ResultRow> expectedResults = Collections.singletonList(
makeRow(
query,
"2011-04-01",
"nil",
null,
"min",
NullHandling.replaceWithDefault() ? 0.0 : 10.0,
"minExpression",
NullHandling.replaceWithDefault() ? 0.0 : 10.0,
"minVc",
NullHandling.replaceWithDefault() ? 0.0 : 10.0
)
);
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
}
@Test
public void testGroupByFloatMinExpressionVsVirtualColumnWithNonFloatInputButMatchingVirtualColumnType()
{
// SQL should never plan anything like this, the virtual column would be inferred to be a string type, and
// would try to make a string min aggregator, which would throw an exception at the time of this comment since
// it doesn't exist...
GroupByQuery query = makeQueryBuilder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("nil", "nil", ColumnType.STRING))
.setVirtualColumns(
new ExpressionVirtualColumn(
"v0",
"\"placement\"",
ColumnType.FLOAT,
TestExprMacroTable.INSTANCE
)
)
.setAggregatorSpecs(
new FloatMinAggregatorFactory(
"minExpression",
null,
"\"placement\"",
TestExprMacroTable.INSTANCE
),
new FloatMinAggregatorFactory(
"minVc",
"v0"
)
)
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();
List<ResultRow> expectedResults = Collections.singletonList(
makeRow(
query,
"2011-04-01",
"nil",
null,
"minExpression",
NullHandling.replaceWithDefault() ? Float.POSITIVE_INFINITY : null,
"minVc",
NullHandling.defaultFloatValue()
)
);
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
}
@Test
public void testGroupByFloatMinExpressionVsVirtualColumnWithExplicitStringVirtualColumnTypedInput()
{
cannotVectorize();
// SQL should never plan anything like this, where the virtual column type mismatches the aggregator type
// but it still works ok
GroupByQuery query = makeQueryBuilder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("nil", "nil", ColumnType.STRING))
.setVirtualColumns(
new ExpressionVirtualColumn(
"v0",
"\"placement\"",
ColumnType.STRING,
TestExprMacroTable.INSTANCE
)
)
.setAggregatorSpecs(
new FloatMinAggregatorFactory(
"min",
"placement"
),
new FloatMinAggregatorFactory(
"minExpression",
null,
"\"placement\"",
TestExprMacroTable.INSTANCE
),
new FloatMinAggregatorFactory(
"minVc",
"v0"
)
)
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();
List<ResultRow> expectedResults = Collections.singletonList(
makeRow(
query,
"2011-04-01",
"nil",
null,
"min",
Float.POSITIVE_INFINITY,
"minExpression",
NullHandling.replaceWithDefault() ? Float.POSITIVE_INFINITY : null,
"minVc",
NullHandling.replaceWithDefault() ? Float.POSITIVE_INFINITY : null
)
);
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
}
private static ResultRow makeRow(final GroupByQuery query, final String timestamp, final Object... vals)
{
return GroupByQueryRunnerTestHelper.createExpectedRow(query, timestamp, vals);

View File

@ -91,11 +91,8 @@ public class AvgSqlAggregator implements SqlAggregator
project
);
final String fieldName;
final String expression;
final DruidExpression arg = Iterables.getOnlyElement(arguments);
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator.
@ -105,26 +102,24 @@ public class AvgSqlAggregator implements SqlAggregator
sumType = ValueType.DOUBLE;
}
final String fieldName;
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
// if the filter or anywhere else defined a virtual column for us, re-use it
final RexNode resolutionArg = Expressions.fromFieldAccess(
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
String vc = virtualColumnRegistry.getVirtualColumnByExpression(arg, resolutionArg.getType());
fieldName = vc != null ? vc : null;
expression = vc != null ? null : arg.getExpression();
fieldName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, resolutionArg.getType());
}
final String sumName = Calcites.makePrefixedName(name, "sum");
final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
sumType,
sumName,
fieldName,
expression,
macroTable
);

View File

@ -46,32 +46,30 @@ public class MaxSqlAggregator extends SimpleSqlAggregator
final String name,
final AggregateCall aggregateCall,
final ExprMacroTable macroTable,
final String fieldName,
final String expression
final String fieldName
)
{
final ColumnType valueType = Calcites.getColumnTypeForRelDataType(aggregateCall.getType());
if (valueType == null) {
return null;
}
return Aggregation.create(createMaxAggregatorFactory(valueType.getType(), name, fieldName, expression, macroTable));
return Aggregation.create(createMaxAggregatorFactory(valueType.getType(), name, fieldName, macroTable));
}
private static AggregatorFactory createMaxAggregatorFactory(
final ValueType aggregationType,
final String name,
final String fieldName,
final String expression,
final ExprMacroTable macroTable
)
{
switch (aggregationType) {
case LONG:
return new LongMaxAggregatorFactory(name, fieldName, expression, macroTable);
return new LongMaxAggregatorFactory(name, fieldName, null, macroTable);
case FLOAT:
return new FloatMaxAggregatorFactory(name, fieldName, expression, macroTable);
return new FloatMaxAggregatorFactory(name, fieldName, null, macroTable);
case DOUBLE:
return new DoubleMaxAggregatorFactory(name, fieldName, expression, macroTable);
return new DoubleMaxAggregatorFactory(name, fieldName, null, macroTable);
default:
throw new UnsupportedSQLQueryException("Max aggregation is not supported for '%s' type", aggregationType);
}

View File

@ -45,29 +45,27 @@ public class MinSqlAggregator extends SimpleSqlAggregator
final String name,
final AggregateCall aggregateCall,
final ExprMacroTable macroTable,
final String fieldName,
final String expression
final String fieldName
)
{
final ColumnType valueType = Calcites.getColumnTypeForRelDataType(aggregateCall.getType());
return Aggregation.create(createMinAggregatorFactory(valueType, name, fieldName, expression, macroTable));
return Aggregation.create(createMinAggregatorFactory(valueType, name, fieldName, macroTable));
}
private static AggregatorFactory createMinAggregatorFactory(
final ColumnType aggregationType,
final String name,
final String fieldName,
final String expression,
final ExprMacroTable macroTable
)
{
switch (aggregationType.getType()) {
case LONG:
return new LongMinAggregatorFactory(name, fieldName, expression, macroTable);
return new LongMinAggregatorFactory(name, fieldName, null, macroTable);
case FLOAT:
return new FloatMinAggregatorFactory(name, fieldName, expression, macroTable);
return new FloatMinAggregatorFactory(name, fieldName, null, macroTable);
case DOUBLE:
return new DoubleMinAggregatorFactory(name, fieldName, expression, macroTable);
return new DoubleMinAggregatorFactory(name, fieldName, null, macroTable);
default:
throw new UnsupportedSQLQueryException("MIN aggregator is not supported for '%s' type", aggregationType);
}

View File

@ -78,24 +78,21 @@ public abstract class SimpleSqlAggregator implements SqlAggregator
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final String fieldName;
final String expression;
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
fieldName = null;
expression = arg.getExpression();
// sharing is caring, make a virtual column to maximize re-use
fieldName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, aggregateCall.getType());
}
return getAggregation(name, aggregateCall, macroTable, fieldName, expression);
return getAggregation(name, aggregateCall, macroTable, fieldName);
}
abstract Aggregation getAggregation(
String name,
AggregateCall aggregateCall,
ExprMacroTable macroTable,
String fieldName,
String expression
String fieldName
);
}

View File

@ -46,32 +46,30 @@ public class SumSqlAggregator extends SimpleSqlAggregator
final String name,
final AggregateCall aggregateCall,
final ExprMacroTable macroTable,
final String fieldName,
final String expression
final String fieldName
)
{
final ColumnType valueType = Calcites.getColumnTypeForRelDataType(aggregateCall.getType());
if (valueType == null) {
return null;
}
return Aggregation.create(createSumAggregatorFactory(valueType.getType(), name, fieldName, expression, macroTable));
return Aggregation.create(createSumAggregatorFactory(valueType.getType(), name, fieldName, macroTable));
}
static AggregatorFactory createSumAggregatorFactory(
final ValueType aggregationType,
final String name,
final String fieldName,
final String expression,
final ExprMacroTable macroTable
)
{
switch (aggregationType) {
case LONG:
return new LongSumAggregatorFactory(name, fieldName, expression, macroTable);
return new LongSumAggregatorFactory(name, fieldName, null, macroTable);
case FLOAT:
return new FloatSumAggregatorFactory(name, fieldName, expression, macroTable);
return new FloatSumAggregatorFactory(name, fieldName, null, macroTable);
case DOUBLE:
return new DoubleSumAggregatorFactory(name, fieldName, expression, macroTable);
return new DoubleSumAggregatorFactory(name, fieldName, null, macroTable);
default:
throw new UnsupportedSQLQueryException("Sum aggregation is not supported for '%s' type", aggregationType);
}

View File

@ -81,7 +81,7 @@ public class BinaryOperatorConversion implements SqlOperatorConversion
);
}
private DruidExpression.DruidExpressionBuilder getOperatorFunction(RexNode rexNode)
private DruidExpression.DruidExpressionCreator getOperatorFunction(RexNode rexNode)
{
return operands -> {
if (operands.size() < 2) {

View File

@ -24,6 +24,7 @@ import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Chars;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
@ -50,7 +51,7 @@ import java.util.function.Function;
* {@link #toVirtualColumn(String, ColumnType, ExprMacroTable)}
*
* Approximate expression structure is retained in the {@link #arguments}, which when fed into the
* {@link ExpressionBuilder} that all {@link DruidExpression} must be created with will produce the final String
* {@link ExpressionGenerator} that all {@link DruidExpression} must be created with will produce the final String
* expression (which will be later parsed into {@link Expr} during native processing).
*
* This allows using the {@link DruidExpressionShuttle} to examine this expression "tree" and potentially rewrite some
@ -83,7 +84,7 @@ public class DruidExpression
// Must be sorted
private static final char[] SAFE_CHARS = " ,._-;:(){}[]<>!@#$%^&*`~?/".toCharArray();
private static final VirtualColumnBuilder DEFAULT_VIRTUAL_COLUMN_BUILDER = new ExpressionVirtualColumnBuilder();
private static final VirtualColumnCreator DEFAULT_VIRTUAL_COLUMN_BUILDER = new ExpressionVirtualColumnCreator();
static {
Arrays.sort(SAFE_CHARS);
@ -118,7 +119,7 @@ public class DruidExpression
return "null";
}
public static ExpressionBuilder functionCall(final String functionName)
public static ExpressionGenerator functionCall(final String functionName)
{
Preconditions.checkNotNull(functionName, "functionName");
@ -148,7 +149,7 @@ public class DruidExpression
@Deprecated
public static String functionCall(final String functionName, final List<DruidExpression> args)
{
return functionCall(functionName).buildExpression(args);
return functionCall(functionName).compile(args);
}
/**
@ -157,7 +158,7 @@ public class DruidExpression
@Deprecated
public static String functionCall(final String functionName, final DruidExpression... args)
{
return functionCall(functionName).buildExpression(Arrays.asList(args));
return functionCall(functionName).compile(Arrays.asList(args));
}
public static DruidExpression ofLiteral(
@ -169,7 +170,7 @@ public class DruidExpression
NodeType.LITERAL,
columnType,
null,
new LiteralExpressionBuilder(literal),
new LiteralExpressionGenerator(literal),
Collections.emptyList(),
null
);
@ -190,7 +191,7 @@ public class DruidExpression
NodeType.IDENTIFIER,
columnType,
simpleExtraction,
new IdentifierExpressionBuilder(column),
new IdentifierExpressionGenerator(column),
Collections.emptyList(),
null
);
@ -212,35 +213,35 @@ public class DruidExpression
public static DruidExpression ofVirtualColumn(
final ColumnType type,
final ExpressionBuilder expressionBuilder,
final ExpressionGenerator expressionGenerator,
final List<DruidExpression> arguments,
final VirtualColumnBuilder virtualColumnBuilder
final VirtualColumnCreator virtualColumnCreator
)
{
return new DruidExpression(NodeType.SPECIALIZED, type, null, expressionBuilder, arguments, virtualColumnBuilder);
return new DruidExpression(NodeType.SPECIALIZED, type, null, expressionGenerator, arguments, virtualColumnCreator);
}
public static DruidExpression ofExpression(
@Nullable final ColumnType columnType,
final ExpressionBuilder expressionBuilder,
final ExpressionGenerator expressionGenerator,
final List<DruidExpression> arguments
)
{
return new DruidExpression(NodeType.EXPRESSION, columnType, null, expressionBuilder, arguments, null);
return new DruidExpression(NodeType.EXPRESSION, columnType, null, expressionGenerator, arguments, null);
}
public static DruidExpression ofExpression(
@Nullable final ColumnType columnType,
final SimpleExtraction simpleExtraction,
final ExpressionBuilder expressionBuilder,
final ExpressionGenerator expressionGenerator,
final List<DruidExpression> arguments
)
{
return new DruidExpression(NodeType.EXPRESSION, columnType, simpleExtraction, expressionBuilder, arguments, null);
return new DruidExpression(NodeType.EXPRESSION, columnType, simpleExtraction, expressionGenerator, arguments, null);
}
/**
* @deprecated use {@link #ofExpression(ColumnType, SimpleExtraction, ExpressionBuilder, List)} instead to participate
* @deprecated use {@link #ofExpression(ColumnType, SimpleExtraction, ExpressionGenerator, List)} instead to participate
* in virtual column and expression optimization
*/
@Deprecated
@ -250,7 +251,7 @@ public class DruidExpression
NodeType.EXPRESSION,
null,
simpleExtraction,
new LiteralExpressionBuilder(expression),
new LiteralExpressionGenerator(expression),
Collections.emptyList(),
null
);
@ -267,14 +268,14 @@ public class DruidExpression
NodeType.EXPRESSION,
null,
SimpleExtraction.of(column, null),
new IdentifierExpressionBuilder(column),
new IdentifierExpressionGenerator(column),
Collections.emptyList(),
null
);
}
/**
* @deprecated use {@link #ofExpression(ColumnType, ExpressionBuilder, List)} instead to participate in virtual
* @deprecated use {@link #ofExpression(ColumnType, ExpressionGenerator, List)} instead to participate in virtual
* column and expression optimization
*/
@Deprecated
@ -284,7 +285,7 @@ public class DruidExpression
NodeType.EXPRESSION,
null,
null,
new LiteralExpressionBuilder(expression),
new LiteralExpressionGenerator(expression),
Collections.emptyList(),
null
);
@ -301,7 +302,7 @@ public class DruidExpression
NodeType.EXPRESSION,
null,
null,
new LiteralExpressionBuilder(functionCall(functionName, args)),
new LiteralExpressionGenerator(functionCall(functionName, args)),
Collections.emptyList(),
null
);
@ -313,8 +314,8 @@ public class DruidExpression
private final List<DruidExpression> arguments;
@Nullable
private final SimpleExtraction simpleExtraction;
private final ExpressionBuilder expressionBuilder;
private final VirtualColumnBuilder virtualColumnBuilder;
private final ExpressionGenerator expressionGenerator;
private final VirtualColumnCreator virtualColumnCreator;
private final Supplier<String> expression;
@ -322,18 +323,18 @@ public class DruidExpression
final NodeType nodeType,
@Nullable final ColumnType druidType,
@Nullable final SimpleExtraction simpleExtraction,
final ExpressionBuilder expressionBuilder,
final ExpressionGenerator expressionGenerator,
final List<DruidExpression> arguments,
@Nullable final VirtualColumnBuilder virtualColumnBuilder
@Nullable final VirtualColumnCreator virtualColumnCreator
)
{
this.nodeType = nodeType;
this.druidType = druidType;
this.simpleExtraction = simpleExtraction;
this.expressionBuilder = Preconditions.checkNotNull(expressionBuilder);
this.expressionGenerator = Preconditions.checkNotNull(expressionGenerator);
this.arguments = arguments;
this.virtualColumnBuilder = virtualColumnBuilder != null ? virtualColumnBuilder : DEFAULT_VIRTUAL_COLUMN_BUILDER;
this.expression = Suppliers.memoize(() -> this.expressionBuilder.buildExpression(this.arguments));
this.virtualColumnCreator = virtualColumnCreator != null ? virtualColumnCreator : DEFAULT_VIRTUAL_COLUMN_BUILDER;
this.expression = Suppliers.memoize(() -> this.expressionGenerator.compile(this.arguments));
}
public String getExpression()
@ -361,11 +362,17 @@ public class DruidExpression
return Preconditions.checkNotNull(simpleExtraction);
}
/**
* Get sub {@link DruidExpression} arguments of this expression
*/
public List<DruidExpression> getArguments()
{
return arguments;
}
/**
* Compile the {@link DruidExpression} into a string and parse it into a native Druid {@link Expr}
*/
public Expr parse(final ExprMacroTable macroTable)
{
return Parser.parse(expression.get(), macroTable);
@ -377,7 +384,7 @@ public class DruidExpression
final ExprMacroTable macroTable
)
{
return virtualColumnBuilder.build(name, outputType, expression.get(), macroTable);
return virtualColumnCreator.create(name, outputType, expression.get(), macroTable);
}
public NodeType getType()
@ -385,6 +392,17 @@ public class DruidExpression
return nodeType;
}
/**
* The {@link ColumnType} of this expression as inferred when this expression was created. This is likely the result
* of converting the output of {@link org.apache.calcite.rex.RexNode#getType()} using
* {@link org.apache.druid.sql.calcite.planner.Calcites#getColumnTypeForRelDataType(RelDataType)}, but may also be
* supplied by other means.
*
* This value is not currently used other than for tracking the types of the {@link DruidExpression} tree. The
* value passed to {@link #toVirtualColumn(String, ColumnType, ExprMacroTable)} will instead be whatever type "hint"
* was specified when the expression was added to the {@link org.apache.druid.sql.calcite.rel.VirtualColumnRegistry}.
*/
@Nullable
public ColumnType getDruidType()
{
return druidType;
@ -399,9 +417,9 @@ public class DruidExpression
nodeType,
druidType,
simpleExtraction == null ? null : extractionMap.apply(simpleExtraction),
(args) -> expressionMap.apply(expressionBuilder.buildExpression(args)),
(args) -> expressionMap.apply(expressionGenerator.compile(args)),
arguments,
virtualColumnBuilder
virtualColumnCreator
);
}
@ -411,21 +429,26 @@ public class DruidExpression
nodeType,
druidType,
simpleExtraction,
expressionBuilder,
expressionGenerator,
newArgs,
virtualColumnBuilder
virtualColumnCreator
);
}
/**
* Visit all sub {@link DruidExpression} (the {@link #arguments} of this expression), allowing the
* {@link DruidExpressionShuttle} to potentially rewrite these arguments with new {@link DruidExpression}, finally
* building a new version of this {@link DruidExpression} with updated {@link #arguments}.
*/
public DruidExpression visit(DruidExpressionShuttle shuttle)
{
return new DruidExpression(
nodeType,
druidType,
simpleExtraction,
expressionBuilder,
expressionGenerator,
shuttle.visitAll(arguments),
virtualColumnBuilder
virtualColumnCreator
);
}
@ -478,38 +501,39 @@ public class DruidExpression
}
}
/**
* Create a {@link DruidExpression} given some set of input argument sub-expressions
*/
@FunctionalInterface
public interface DruidExpressionBuilder
public interface DruidExpressionCreator
{
DruidExpression buildExpression(List<DruidExpression> arguments);
DruidExpression create(List<DruidExpression> arguments);
}
/**
* Used by {@link DruidExpression} to compile a string which can be parsed into an {@link Expr} from given the
* sub-expression arguments
*/
@FunctionalInterface
public interface ExpressionBuilder
public interface ExpressionGenerator
{
String buildExpression(List<DruidExpression> arguments);
}
@FunctionalInterface
public interface VirtualColumnBuilder
{
VirtualColumn build(String name, ColumnType outputType, String expression, ExprMacroTable macroTable);
String compile(List<DruidExpression> arguments);
}
/**
* Direct reference to a physical or virtual column
*/
public static class IdentifierExpressionBuilder implements ExpressionBuilder
public static class IdentifierExpressionGenerator implements ExpressionGenerator
{
private final String identifier;
public IdentifierExpressionBuilder(String identifier)
public IdentifierExpressionGenerator(String identifier)
{
this.identifier = escape(identifier);
}
@Override
public String buildExpression(List<DruidExpression> arguments)
public String compile(List<DruidExpression> arguments)
{
// identifier expression has no arguments
return "\"" + identifier + "\"";
@ -519,27 +543,37 @@ public class DruidExpression
/**
* Builds expressions for a static constant value
*/
public static class LiteralExpressionBuilder implements ExpressionBuilder
public static class LiteralExpressionGenerator implements ExpressionGenerator
{
private final String literal;
public LiteralExpressionBuilder(String literal)
public LiteralExpressionGenerator(String literal)
{
this.literal = literal;
}
@Override
public String buildExpression(List<DruidExpression> arguments)
public String compile(List<DruidExpression> arguments)
{
// literal expression has no arguments
return literal;
}
}
public static class ExpressionVirtualColumnBuilder implements VirtualColumnBuilder
/**
* Used by a {@link DruidExpression} to translate itself into a {@link VirtualColumn} to add to a native query when
* referenced by a projection, filter, aggregator, etc.
*/
@FunctionalInterface
public interface VirtualColumnCreator
{
VirtualColumn create(String name, ColumnType outputType, String expression, ExprMacroTable macroTable);
}
public static class ExpressionVirtualColumnCreator implements VirtualColumnCreator
{
@Override
public VirtualColumn build(String name, ColumnType outputType, String expression, ExprMacroTable macroTable)
public VirtualColumn create(String name, ColumnType outputType, String expression, ExprMacroTable macroTable)
{
return new ExpressionVirtualColumn(name, expression, outputType, macroTable);
}

View File

@ -119,7 +119,7 @@ public class OperatorConversions
final PlannerContext plannerContext,
final RowSignature rowSignature,
final RexNode rexNode,
final DruidExpression.ExpressionBuilder expressionBuilder
final DruidExpression.ExpressionGenerator expressionGenerator
)
{
return convertCall(
@ -128,7 +128,7 @@ public class OperatorConversions
rexNode,
(operands) -> DruidExpression.ofExpression(
Calcites.getColumnTypeForRelDataType(rexNode.getType()),
expressionBuilder,
expressionGenerator,
operands
)
);
@ -139,7 +139,7 @@ public class OperatorConversions
final PlannerContext plannerContext,
final RowSignature rowSignature,
final RexNode rexNode,
final DruidExpression.DruidExpressionBuilder expressionFunction
final DruidExpression.DruidExpressionCreator expressionFunction
)
{
final RexCall call = (RexCall) rexNode;
@ -154,7 +154,7 @@ public class OperatorConversions
return null;
}
return expressionFunction.buildExpression(druidExpressions);
return expressionFunction.create(druidExpressions);
}
@Deprecated
@ -211,7 +211,7 @@ public class OperatorConversions
final PlannerContext plannerContext,
final RowSignature rowSignature,
final RexNode rexNode,
final DruidExpression.DruidExpressionBuilder expressionFunction,
final DruidExpression.DruidExpressionCreator expressionFunction,
final PostAggregatorVisitor postAggregatorVisitor
)
{
@ -228,7 +228,7 @@ public class OperatorConversions
return null;
}
return expressionFunction.buildExpression(druidExpressions);
return expressionFunction.create(druidExpressions);
}
/**

View File

@ -341,7 +341,7 @@ public class MultiValueStringOperatorConversions
return null;
}
final DruidExpression.ExpressionBuilder builder = (args) -> {
final DruidExpression.ExpressionGenerator builder = (args) -> {
final StringBuilder expressionBuilder;
if (isAllowList()) {
expressionBuilder = new StringBuilder("filter((x) -> array_contains(");

View File

@ -63,7 +63,7 @@ public class StrposOperatorConversion implements SqlOperatorConversion
Calcites.getColumnTypeForRelDataType(rexNode.getType()),
(args) -> StringUtils.format(
"(%s + 1)",
DruidExpression.functionCall("strpos").buildExpression(args)
DruidExpression.functionCall("strpos").compile(args)
),
druidExpressions
)

View File

@ -262,6 +262,19 @@ public class VirtualColumnRegistry
return new ExpressionAndTypeHint(expression, typeHint);
}
/**
* Wrapper class for a {@link DruidExpression} and the output {@link ColumnType} "hint" that callers can specify when
* adding a virtual column with {@link #getOrCreateVirtualColumnForExpression(DruidExpression, RelDataType)} or
* {@link #getOrCreateVirtualColumnForExpression(DruidExpression, ColumnType)}. This "hint" will be passed into
* {@link DruidExpression#toVirtualColumn(String, ColumnType, ExprMacroTable)}.
*
* The type hint might be different than {@link DruidExpression#getDruidType()} since that value is the captured value
* of {@link org.apache.calcite.rex.RexNode#getType()} converted to the Druid type system, while callers might still
* explicitly specify a different type to use for the hint. Additionally, the method used to convert Calcite types to
* Druid types does not completely map the former to the latter, and the method typically used to do the conversion,
* {@link Calcites#getColumnTypeForRelDataType(RelDataType)}, might return null, where the caller might know what
* the type should be.
*/
private static class ExpressionAndTypeHint
{
private final DruidExpression expression;

View File

@ -966,10 +966,13 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(
expressionVirtualColumn("v0", "CAST(\"a0\", 'DOUBLE')", ColumnType.DOUBLE)
)
.setAggregatorSpecs(aggregators(new DoubleSumAggregatorFactory(
"_a0",
"v0",
null,
"CAST(\"a0\", 'DOUBLE')",
ExprMacroTable.nil()
)))
.setContext(QUERY_CONTEXT_DEFAULT)
@ -1014,10 +1017,13 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(
expressionVirtualColumn("v0", "CAST(\"a0\", 'DOUBLE')", ColumnType.DOUBLE)
)
.setAggregatorSpecs(aggregators(new DoubleSumAggregatorFactory(
"_a0",
"v0",
null,
"CAST(\"a0\", 'DOUBLE')",
ExprMacroTable.nil()
)))
.setContext(QUERY_CONTEXT_DEFAULT)
@ -1121,10 +1127,13 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(
expressionVirtualColumn("v0", "CAST(\"a0\", 'DOUBLE')", ColumnType.DOUBLE)
)
.setAggregatorSpecs(aggregators(new DoubleSumAggregatorFactory(
"_a0",
"v0",
null,
"CAST(\"a0\", 'DOUBLE')",
ExprMacroTable.nil()
)))
.setContext(QUERY_CONTEXT_DEFAULT)
@ -4976,19 +4985,26 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
+ " LN(SUM(cnt) + SUM(m1)),\n"
+ " MOD(SUM(cnt), 4),\n"
+ " SUM(CHARACTER_LENGTH(CAST(cnt * 10 AS VARCHAR))),\n"
+ " MAX(CHARACTER_LENGTH(dim2) + LN(m1))\n"
+ " MAX(CHARACTER_LENGTH(dim2) + LN(m1)),\n"
+ " MIN(CHARACTER_LENGTH(dim2) + LN(m1))\n"
+ "FROM druid.foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(
expressionVirtualColumn("v0", "(\"cnt\" * 3)", ColumnType.LONG),
expressionVirtualColumn("v1", "strlen(CAST((\"cnt\" * 10), 'STRING'))", ColumnType.LONG),
expressionVirtualColumn("v2", "(strlen(\"dim2\") + log(\"m1\"))", ColumnType.DOUBLE)
)
.aggregators(aggregators(
new LongSumAggregatorFactory("a0", null, "(\"cnt\" * 3)", macroTable),
new LongSumAggregatorFactory("a0", "v0", null, macroTable),
new LongSumAggregatorFactory("a1", "cnt"),
new DoubleSumAggregatorFactory("a2", "m1"),
new LongSumAggregatorFactory("a3", null, "strlen(CAST((\"cnt\" * 10), 'STRING'))", macroTable),
new DoubleMaxAggregatorFactory("a4", null, "(strlen(\"dim2\") + log(\"m1\"))", macroTable)
new LongSumAggregatorFactory("a3", "v1", null, macroTable),
new DoubleMaxAggregatorFactory("a4", "v2", null, macroTable),
new DoubleMinAggregatorFactory("a5", "v2", null, macroTable)
))
.postAggregators(
expressionPostAgg("p0", "log((\"a1\" + \"a2\"))"),
@ -4998,7 +5014,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
ImmutableList.of(
new Object[]{18L, 3.295836866004329, 2, 12L, 3f + (Math.log(5.0))}
new Object[]{18L, 3.295836866004329, 2, 12L, 3f + (Math.log(5.0)), useDefault ? 0.6931471805599453 : 1.0}
)
);
}
@ -5885,11 +5901,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(
expressionVirtualColumn("v0", "CAST(\"dim1\", 'LONG')", ColumnType.LONG)
)
.aggregators(aggregators(
new LongSumAggregatorFactory(
"a0",
"v0",
null,
"CAST(\"dim1\", 'LONG')",
CalciteTests.createExprMacroTable()
)
))
@ -5915,11 +5934,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(
expressionVirtualColumn("v0", "CAST(substring(\"dim1\", 0, 10), 'LONG')", ColumnType.LONG)
)
.aggregators(aggregators(
new LongSumAggregatorFactory(
"a0",
"v0",
null,
"CAST(substring(\"dim1\", 0, 10), 'LONG')",
CalciteTests.createExprMacroTable()
)
))
@ -8263,6 +8285,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
"v0",
"case_searched(((947005200000 <= \"__time\") && (\"__time\" < 1641402000000)),\"dim1\",null)",
ColumnType.STRING
),
expressionVirtualColumn(
"v1",
"case_searched(((947005200000 <= \"__time\") && (\"__time\" < 1641402000000)),1,0)",
ColumnType.LONG
)
)
.setDimensions(
@ -8278,8 +8305,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
aggregators(
new LongSumAggregatorFactory(
"a0",
"v1",
null,
"case_searched(((947005200000 <= \"__time\") && (\"__time\" < 1641402000000)),1,0)",
ExprMacroTable.nil()
),
new GroupingAggregatorFactory(
@ -12703,11 +12730,12 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new CountAggregatorFactory("a0"),
not(selector("v0", null, null))
),
new LongSumAggregatorFactory("a1:sum", null, "325323", TestExprMacroTable.INSTANCE),
new LongSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE),
new CountAggregatorFactory("a1:count")
);
virtualColumns = ImmutableList.of(
expressionVirtualColumn("v0", "'10.1'", ColumnType.STRING)
expressionVirtualColumn("v0", "'10.1'", ColumnType.STRING),
expressionVirtualColumn("v1", "325323", ColumnType.LONG)
);
} else {
aggs = ImmutableList.of(
@ -13687,19 +13715,17 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.filters(selector("dim1", "none", null))
.granularity(Granularities.ALL)
.virtualColumns(
expressionVirtualColumn(
"v0",
"'none'",
ColumnType.STRING
)
expressionVirtualColumn("v0", "'none'", ColumnType.STRING),
expressionVirtualColumn("v1", "0", ColumnType.LONG),
expressionVirtualColumn("v2", "0", ColumnType.DOUBLE)
)
.dimension(
new DefaultDimensionSpec("v0", "d0")
)
.aggregators(
aggregators(
new LongSumAggregatorFactory("a0", null, "0", ExprMacroTable.nil()),
new DoubleSumAggregatorFactory("a1", null, "0", ExprMacroTable.nil())
new LongSumAggregatorFactory("a0", "v1", null, ExprMacroTable.nil()),
new DoubleSumAggregatorFactory("a1", "v2", null, ExprMacroTable.nil())
))
.context(QUERY_CONTEXT_DEFAULT)
.metric(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))