From 184b202411ffe6d04037e6b670985a58c3bf53e5 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Mon, 14 Sep 2020 18:18:56 -0700 Subject: [PATCH] add computed Expr output types (#10370) * push down ValueType to ExprType conversion, tidy up * determine expr output type for given input types * revert unintended name change * add nullable * tidy up * fixup * more better * fix signatures * naming things is hard * fix inspection * javadoc * make default implementation of Expr.getOutputType that returns null * rename method * more test * add output for contains expr macro, split operation and function auto conversion --- .../apache/druid/math/expr/ApplyFunction.java | 86 +- .../math/expr/BinaryLogicalOperatorExpr.java | 79 +- .../druid/math/expr/BinaryOperatorExpr.java | 9 +- .../apache/druid/math/expr/ConstantExpr.java | 41 +- .../java/org/apache/druid/math/expr/Expr.java | 73 +- .../org/apache/druid/math/expr/ExprEval.java | 5 + .../druid/math/expr/ExprListenerImpl.java | 2 +- .../druid/math/expr/ExprMacroTable.java | 14 +- .../org/apache/druid/math/expr/ExprType.java | 148 ++- .../org/apache/druid/math/expr/Function.java | 1067 +++++++++++------ .../druid/math/expr/FunctionalExpr.java | 55 +- .../druid/math/expr/IdentifierExpr.java | 10 +- .../org/apache/druid/math/expr/Parser.java | 28 +- .../druid/math/expr/UnaryOperatorExpr.java | 21 +- .../org/apache/druid/math/expr/ExprTest.java | 40 +- .../druid/math/expr/OutputTypeTest.java | 463 +++++++ .../apache/druid/math/expr/ParserTest.java | 8 +- .../expressions/BloomFilterExprMacro.java | 10 +- .../druid/query/expression/ContainsExpr.java | 12 +- .../expression/IPv4AddressMatchExprMacro.java | 10 +- .../expression/IPv4AddressParseExprMacro.java | 9 + .../IPv4AddressStringifyExprMacro.java | 9 + .../druid/query/expression/LikeExprMacro.java | 10 +- .../query/expression/LookupExprMacro.java | 9 + .../expression/RegexpExtractExprMacro.java | 9 + .../query/expression/RegexpLikeExprMacro.java | 12 +- .../expression/TimestampCeilExprMacro.java | 16 + .../expression/TimestampExtractExprMacro.java | 15 + .../expression/TimestampFloorExprMacro.java | 16 + .../expression/TimestampFormatExprMacro.java | 9 + .../expression/TimestampParseExprMacro.java | 9 + .../expression/TimestampShiftExprMacro.java | 16 + .../druid/query/expression/TrimExprMacro.java | 18 +- .../segment/filter/ExpressionFilter.java | 4 +- .../join/filter/JoinFilterCorrelations.java | 4 +- .../segment/virtual/ExpressionSelectors.java | 57 +- ...RowBasedExpressionColumnValueSelector.java | 10 +- .../expression/RegexpLikeExprMacroTest.java | 17 +- .../ReductionOperatorConversionHelper.java | 4 +- .../druid/sql/calcite/rel/Projection.java | 4 +- 40 files changed, 1905 insertions(+), 533 deletions(-) create mode 100644 core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index 4bf2fa5e934..d6f4ed2bd87 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -23,6 +23,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import it.unimi.dsi.fastutil.objects.Object2IntArrayMap; import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; @@ -74,6 +75,15 @@ public interface ApplyFunction */ void validateArguments(LambdaExpr lambdaExpr, List args); + /** + * Compute the output type of this function for a given lambda and the argument expressions which will be applied as + * its inputs. + * + * @see Expr#getOutputType + */ + @Nullable + ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args); + /** * Base class for "map" functions, which are a class of {@link ApplyFunction} which take a lambda function that is * mapped to the values of an {@link IndexableMapLambdaObjectBinding} which is created from the outer @@ -87,6 +97,13 @@ public interface ApplyFunction return true; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + return ExprType.asArrayType(expr.getOutputType(new LambdaInputBindingTypes(inputTypes, expr, args))); + } + /** * Evaluate {@link LambdaExpr} against every index position of an {@link IndexableMapLambdaObjectBinding} */ @@ -274,7 +291,7 @@ public interface ApplyFunction accumulator = evaluated.value(); } if (accumulator instanceof Boolean) { - return ExprEval.of((boolean) accumulator, ExprType.LONG); + return ExprEval.ofLongBoolean((boolean) accumulator); } return ExprEval.bestEffortOf(accumulator); } @@ -282,8 +299,16 @@ public interface ApplyFunction @Override public boolean hasArrayOutput(LambdaExpr lambdaExpr) { - Expr.BindingDetails lambdaBindingDetails = lambdaExpr.analyzeInputs(); - return lambdaBindingDetails.isOutputArray(); + Expr.BindingAnalysis lambdaBindingAnalysis = lambdaExpr.analyzeInputs(); + return lambdaBindingAnalysis.isOutputArray(); + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + // output type is accumulator type, which is last argument + return args.get(args.size() - 1).getOutputType(inputTypes); } } @@ -481,6 +506,14 @@ public interface ApplyFunction ); } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + // output type is input array type + return args.get(0).getOutputType(inputTypes); + } + private Stream filter(T[] array, LambdaExpr expr, SettableLambdaBinding binding) { return Arrays.stream(array).filter(s -> expr.eval(binding.withBinding(expr.getIdentifier(), s)).asBoolean()); @@ -501,7 +534,7 @@ public interface ApplyFunction final Object[] array = arrayEval.asArray(); if (array == null) { - return ExprEval.of(false, ExprType.LONG); + return ExprEval.ofLongBoolean(false); } SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings); @@ -528,6 +561,13 @@ public interface ApplyFunction ); } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + return ExprType.LONG; + } + public abstract ExprEval match(Object[] values, LambdaExpr expr, SettableLambdaBinding bindings); } @@ -550,7 +590,7 @@ public interface ApplyFunction { boolean anyMatch = Arrays.stream(values) .anyMatch(o -> expr.eval(bindings.withBinding(expr.getIdentifier(), o)).asBoolean()); - return ExprEval.of(anyMatch, ExprType.LONG); + return ExprEval.ofLongBoolean(anyMatch); } } @@ -573,7 +613,7 @@ public interface ApplyFunction { boolean allMatch = Arrays.stream(values) .allMatch(o -> expr.eval(bindings.withBinding(expr.getIdentifier(), o)).asBoolean()); - return ExprEval.of(allMatch, ExprType.LONG); + return ExprEval.ofLongBoolean(allMatch); } } @@ -848,4 +888,38 @@ public interface ApplyFunction return this; } } + + /** + * Helper that can wrap another {@link Expr.InputBindingTypes} to use to supply the type information of a + * {@link LambdaExpr} when evaluating {@link ApplyFunctionExpr#getOutputType}. Lambda identifiers do not exist + * in the underlying {@link Expr.InputBindingTypes}, but can be created by mapping the lambda identifiers to the + * arguments that will be applied to them, to map the type information. + */ + class LambdaInputBindingTypes implements Expr.InputBindingTypes + { + private final Object2IntMap lambdaIdentifiers; + private final Expr.InputBindingTypes inputTypes; + private final List args; + + public LambdaInputBindingTypes(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + this.inputTypes = inputTypes; + this.args = args; + List identifiers = expr.getIdentifiers(); + this.lambdaIdentifiers = new Object2IntOpenHashMap<>(args.size()); + for (int i = 0; i < args.size(); i++) { + lambdaIdentifiers.put(identifiers.get(i), i); + } + } + + @Nullable + @Override + public ExprType getType(String name) + { + if (lambdaIdentifiers.containsKey(name)) { + return ExprType.elementType(args.get(lambdaIdentifiers.getInt(name)).getOutputType(inputTypes)); + } + return inputTypes.getType(name); + } + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java index dad35f30560..58cb5a08de8 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java @@ -42,7 +42,7 @@ class BinLtExpr extends BinaryEvalOpExprBase @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) < 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) < 0); } @Override @@ -57,6 +57,17 @@ class BinLtExpr extends BinaryEvalOpExprBase // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) < 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinLeqExpr extends BinaryEvalOpExprBase @@ -75,7 +86,7 @@ class BinLeqExpr extends BinaryEvalOpExprBase @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) <= 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) <= 0); } @Override @@ -90,6 +101,17 @@ class BinLeqExpr extends BinaryEvalOpExprBase // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) <= 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinGtExpr extends BinaryEvalOpExprBase @@ -108,7 +130,7 @@ class BinGtExpr extends BinaryEvalOpExprBase @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) > 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) > 0); } @Override @@ -123,6 +145,17 @@ class BinGtExpr extends BinaryEvalOpExprBase // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) > 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinGeqExpr extends BinaryEvalOpExprBase @@ -141,7 +174,7 @@ class BinGeqExpr extends BinaryEvalOpExprBase @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) >= 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) >= 0); } @Override @@ -156,6 +189,17 @@ class BinGeqExpr extends BinaryEvalOpExprBase // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) >= 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinEqExpr extends BinaryEvalOpExprBase @@ -174,7 +218,7 @@ class BinEqExpr extends BinaryEvalOpExprBase @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Objects.equals(left, right), ExprType.LONG); + return ExprEval.ofLongBoolean(Objects.equals(left, right)); } @Override @@ -188,6 +232,17 @@ class BinEqExpr extends BinaryEvalOpExprBase { return Evals.asDouble(left == right); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinNeqExpr extends BinaryEvalOpExprBase @@ -206,7 +261,7 @@ class BinNeqExpr extends BinaryEvalOpExprBase @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(!Objects.equals(left, right), ExprType.LONG); + return ExprEval.ofLongBoolean(!Objects.equals(left, right)); } @Override @@ -220,6 +275,17 @@ class BinNeqExpr extends BinaryEvalOpExprBase { return Evals.asDouble(left != right); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinAndExpr extends BinaryOpExprBase @@ -262,5 +328,4 @@ class BinOrExpr extends BinaryOpExprBase ExprEval leftVal = left.eval(bindings); return leftVal.asBoolean() ? leftVal : right.eval(bindings); } - } diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java index 9c390587bd4..9db527bf5b4 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java @@ -81,12 +81,19 @@ abstract class BinaryOpExprBase implements Expr protected abstract BinaryOpExprBase copy(Expr left, Expr right); @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { // currently all binary operators operate on scalar inputs return left.analyzeInputs().with(right).withScalarArguments(ImmutableSet.of(left, right)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.operatorAutoTypeConversion(left.getOutputType(inputTypes), right.getOutputType(inputTypes)); + } + @Override public boolean equals(Object o) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java index 4f6099cf366..ef090b5edd3 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java @@ -35,6 +35,20 @@ import java.util.Objects; */ abstract class ConstantExpr implements Expr { + final ExprType outputType; + + protected ConstantExpr(ExprType outputType) + { + this.outputType = outputType; + } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return outputType; + } + @Override public boolean isLiteral() { @@ -54,9 +68,9 @@ abstract class ConstantExpr implements Expr } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { - return new BindingDetails(); + return new BindingAnalysis(); } @Override @@ -71,6 +85,11 @@ abstract class ConstantExpr implements Expr */ abstract class NullNumericConstantExpr extends ConstantExpr { + protected NullNumericConstantExpr(ExprType outputType) + { + super(outputType); + } + @Override public Object getLiteralValue() { @@ -82,6 +101,8 @@ abstract class NullNumericConstantExpr extends ConstantExpr { return NULL_LITERAL; } + + } class LongExpr extends ConstantExpr @@ -90,6 +111,7 @@ class LongExpr extends ConstantExpr LongExpr(Long value) { + super(ExprType.LONG); this.value = Preconditions.checkNotNull(value, "value"); } @@ -133,6 +155,11 @@ class LongExpr extends ConstantExpr class NullLongExpr extends NullNumericConstantExpr { + NullLongExpr() + { + super(ExprType.LONG); + } + @Override public ExprEval eval(ObjectBinding bindings) { @@ -158,6 +185,7 @@ class LongArrayExpr extends ConstantExpr LongArrayExpr(Long[] value) { + super(ExprType.LONG_ARRAY); this.value = Preconditions.checkNotNull(value, "value"); } @@ -215,6 +243,7 @@ class StringExpr extends ConstantExpr StringExpr(@Nullable String value) { + super(ExprType.STRING); this.value = NullHandling.emptyToNullIfNeeded(value); } @@ -270,6 +299,7 @@ class StringArrayExpr extends ConstantExpr StringArrayExpr(String[] value) { + super(ExprType.STRING_ARRAY); this.value = Preconditions.checkNotNull(value, "value"); } @@ -338,6 +368,7 @@ class DoubleExpr extends ConstantExpr DoubleExpr(Double value) { + super(ExprType.DOUBLE); this.value = Preconditions.checkNotNull(value, "value"); } @@ -381,6 +412,11 @@ class DoubleExpr extends ConstantExpr class NullDoubleExpr extends NullNumericConstantExpr { + NullDoubleExpr() + { + super(ExprType.DOUBLE); + } + @Override public ExprEval eval(ObjectBinding bindings) { @@ -406,6 +442,7 @@ class DoubleArrayExpr extends ConstantExpr DoubleArrayExpr(Double[] value) { + super(ExprType.DOUBLE_ARRAY); this.value = Preconditions.checkNotNull(value, "value"); } diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java index e0a1525c7df..2a13be3f845 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java @@ -116,16 +116,39 @@ public interface Expr void visit(Visitor visitor); /** - * Programatically rewrite the {@link Expr} tree with a {@link Shuttle}.Each {@link Expr} is responsible for + * Programatically rewrite the {@link Expr} tree with a {@link Shuttle}. Each {@link Expr} is responsible for * ensuring the {@link Shuttle} can visit all of its {@link Expr} children, as well as updating its children * {@link Expr} with the results from the {@link Shuttle}, before finally visiting an updated form of itself. */ Expr visit(Shuttle shuttle); /** - * Examine the usage of {@link IdentifierExpr} children of an {@link Expr}, constructing a {@link BindingDetails} + * Examine the usage of {@link IdentifierExpr} children of an {@link Expr}, constructing a {@link BindingAnalysis} */ - BindingDetails analyzeInputs(); + BindingAnalysis analyzeInputs(); + + /** + * Given an {@link InputBindingTypes}, compute what the output {@link ExprType} will be for this expression. A return + * value of null indicates that the given type information was not enough to resolve the output type, so the + * expression must be evaluated using default {@link #eval} handling where types are only known after evaluation, + * through {@link ExprEval#type}. + */ + @Nullable + default ExprType getOutputType(InputBindingTypes inputTypes) + { + return null; + } + + /** + * Mechanism to supply input types for the bindings which will back {@link IdentifierExpr}, to use in the aid of + * inferring the output type of an expression with {@link #getOutputType}. A null value means that either the binding + * doesn't exist, or, that the type information is unavailable. + */ + interface InputBindingTypes + { + @Nullable + ExprType getType(String name); + } /** * Mechanism to supply values to back {@link IdentifierExpr} during expression evaluation @@ -180,7 +203,7 @@ public interface Expr * * This means in rare cases and mostly for "questionable" expressions which we still allow to function 'correctly', * these lists might not be fully reliable without a complete type inference system in place. Due to this shortcoming, - * boolean values {@link BindingDetails#hasInputArrays()} and {@link BindingDetails#isOutputArray()} are provided to + * boolean values {@link BindingAnalysis#hasInputArrays()} and {@link BindingAnalysis#isOutputArray()} are provided to * allow functions to explicitly declare that they utilize array typed values, used when determining if some types of * optimizations can be applied when constructing the expression column value selector. * @@ -194,7 +217,7 @@ public interface Expr * @see org.apache.druid.segment.virtual.ExpressionSelectors#makeColumnValueSelector */ @SuppressWarnings("JavadocReference") - class BindingDetails + class BindingAnalysis { private final ImmutableSet freeVariables; private final ImmutableSet scalarVariables; @@ -202,17 +225,17 @@ public interface Expr private final boolean hasInputArrays; private final boolean isOutputArray; - BindingDetails() + BindingAnalysis() { this(ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), false, false); } - BindingDetails(IdentifierExpr expr) + BindingAnalysis(IdentifierExpr expr) { this(ImmutableSet.of(expr), ImmutableSet.of(), ImmutableSet.of(), false, false); } - private BindingDetails( + private BindingAnalysis( ImmutableSet freeVariables, ImmutableSet scalarVariables, ImmutableSet arrayVariables, @@ -310,19 +333,19 @@ public interface Expr } /** - * Combine with {@link BindingDetails} from {@link Expr#analyzeInputs()} + * Combine with {@link BindingAnalysis} from {@link Expr#analyzeInputs()} */ - public BindingDetails with(Expr other) + public BindingAnalysis with(Expr other) { return with(other.analyzeInputs()); } /** - * Combine (union) another {@link BindingDetails} + * Combine (union) another {@link BindingAnalysis} */ - public BindingDetails with(BindingDetails other) + public BindingAnalysis with(BindingAnalysis other) { - return new BindingDetails( + return new BindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, other.freeVariables)), ImmutableSet.copyOf(Sets.union(scalarVariables, other.scalarVariables)), ImmutableSet.copyOf(Sets.union(arrayVariables, other.arrayVariables)), @@ -332,10 +355,10 @@ public interface Expr } /** - * Add set of arguments as {@link BindingDetails#scalarVariables} that are *directly* {@link IdentifierExpr}, + * Add set of arguments as {@link BindingAnalysis#scalarVariables} that are *directly* {@link IdentifierExpr}, * else they are ignored. */ - public BindingDetails withScalarArguments(Set scalarArguments) + public BindingAnalysis withScalarArguments(Set scalarArguments) { Set moreScalars = new HashSet<>(); for (Expr expr : scalarArguments) { @@ -344,7 +367,7 @@ public interface Expr moreScalars.add((IdentifierExpr) expr); } } - return new BindingDetails( + return new BindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, moreScalars)), ImmutableSet.copyOf(Sets.union(scalarVariables, moreScalars)), arrayVariables, @@ -354,10 +377,10 @@ public interface Expr } /** - * Add set of arguments as {@link BindingDetails#arrayVariables} that are *directly* {@link IdentifierExpr}, + * Add set of arguments as {@link BindingAnalysis#arrayVariables} that are *directly* {@link IdentifierExpr}, * else they are ignored. */ - BindingDetails withArrayArguments(Set arrayArguments) + BindingAnalysis withArrayArguments(Set arrayArguments) { Set arrayIdentifiers = new HashSet<>(); for (Expr expr : arrayArguments) { @@ -366,7 +389,7 @@ public interface Expr arrayIdentifiers.add((IdentifierExpr) expr); } } - return new BindingDetails( + return new BindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, arrayIdentifiers)), scalarVariables, ImmutableSet.copyOf(Sets.union(arrayVariables, arrayIdentifiers)), @@ -378,9 +401,9 @@ public interface Expr /** * Copy, setting if an expression has array inputs */ - BindingDetails withArrayInputs(boolean hasArrays) + BindingAnalysis withArrayInputs(boolean hasArrays) { - return new BindingDetails( + return new BindingAnalysis( freeVariables, scalarVariables, arrayVariables, @@ -392,9 +415,9 @@ public interface Expr /** * Copy, setting if an expression produces an array output */ - BindingDetails withArrayOutput(boolean isOutputArray) + BindingAnalysis withArrayOutput(boolean isOutputArray) { - return new BindingDetails( + return new BindingAnalysis( freeVariables, scalarVariables, arrayVariables, @@ -407,9 +430,9 @@ public interface Expr * Remove any {@link IdentifierExpr} that are from a {@link LambdaExpr}, since the {@link ApplyFunction} will * provide bindings for these variables. */ - BindingDetails removeLambdaArguments(Set lambda) + BindingAnalysis removeLambdaArguments(Set lambda) { - return new BindingDetails( + return new BindingAnalysis( ImmutableSet.copyOf(freeVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), ImmutableSet.copyOf(scalarVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), ImmutableSet.copyOf(arrayVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 61cdc26f6dd..1c02186296b 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -89,6 +89,11 @@ public abstract class ExprEval } } + public static ExprEval ofLongBoolean(boolean value) + { + return ExprEval.of(Evals.asLong(value)); + } + public static ExprEval bestEffortOf(@Nullable Object val) { if (val instanceof ExprEval) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java index ae41653950f..3f69f6e0b7e 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java @@ -482,7 +482,7 @@ public class ExprListenerImpl extends ExprBaseListener * {@link IdentifierExpr#identifier} be the same as {@link IdentifierExpr#binding} because they have * synthetic bindings set at evaluation time. This is done to aid in analysis needed for the automatic expression * translation which maps scalar expressions to multi-value inputs. See - * {@link Parser#applyUnappliedBindings(Expr, Expr.BindingDetails, List)}} for additional details. + * {@link Parser#applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)}} for additional details. */ private IdentifierExpr createIdentifierExpr(String binding) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java index f7cf1d0f648..616297a57c3 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java @@ -102,7 +102,7 @@ public class ExprMacroTable protected final Expr arg; // Use Supplier to memoize values as ExpressionSelectors#makeExprEvalSelector() can make repeated calls for them - private final Supplier analyzeInputsSupplier; + private final Supplier analyzeInputsSupplier; public BaseScalarUnivariateMacroFunctionExpr(String name, Expr arg) { @@ -119,7 +119,7 @@ public class ExprMacroTable } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { return analyzeInputsSupplier.get(); } @@ -150,7 +150,7 @@ public class ExprMacroTable return Objects.hash(name, arg); } - private BindingDetails supplyAnalyzeInputs() + private BindingAnalysis supplyAnalyzeInputs() { return arg.analyzeInputs().withScalarArguments(ImmutableSet.of(arg)); } @@ -165,7 +165,7 @@ public class ExprMacroTable protected final List args; // Use Supplier to memoize values as ExpressionSelectors#makeExprEvalSelector() can make repeated calls for them - private final Supplier analyzeInputsSupplier; + private final Supplier analyzeInputsSupplier; public BaseScalarMacroFunctionExpr(String name, final List args) { @@ -194,7 +194,7 @@ public class ExprMacroTable } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { return analyzeInputsSupplier.get(); } @@ -219,10 +219,10 @@ public class ExprMacroTable return Objects.hash(name, args); } - private BindingDetails supplyAnalyzeInputs() + private BindingAnalysis supplyAnalyzeInputs() { final Set argSet = Sets.newHashSetWithExpectedSize(args.size()); - BindingDetails accumulator = new BindingDetails(); + BindingAnalysis accumulator = new BindingAnalysis(); for (Expr arg : args) { accumulator = accumulator.with(arg); argSet.add(arg); diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 0bc1573bef5..3b9108de921 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -19,6 +19,12 @@ package org.apache.druid.math.expr; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.segment.column.ValueType; + +import javax.annotation.Nullable; + /** * Base 'value' types of Druid expression language, all {@link Expr} must evaluate to one of these types. */ @@ -29,5 +35,145 @@ public enum ExprType STRING, DOUBLE_ARRAY, LONG_ARRAY, - STRING_ARRAY + STRING_ARRAY; + + public boolean isNumeric() + { + return isNumeric(this); + } + + /** + * The expression system does not distinguish between {@link ValueType#FLOAT} and {@link ValueType#DOUBLE}, and + * cannot currently handle {@link ValueType#COMPLEX} inputs. This method will convert {@link ValueType#FLOAT} to + * {@link #DOUBLE}, or throw an exception if a {@link ValueType#COMPLEX} is encountered. + * + * @throws IllegalStateException + */ + public static ExprType fromValueType(@Nullable ValueType valueType) + { + if (valueType == null) { + throw new IllegalStateException("Unsupported unknown value type"); + } + switch (valueType) { + case LONG: + return LONG; + case LONG_ARRAY: + return LONG_ARRAY; + case FLOAT: + case DOUBLE: + return DOUBLE; + case DOUBLE_ARRAY: + return DOUBLE_ARRAY; + case STRING: + return STRING; + case STRING_ARRAY: + return STRING_ARRAY; + case COMPLEX: + default: + throw new ISE("Unsupported value type[%s]", valueType); + } + } + + public static boolean isNumeric(ExprType type) + { + return LONG.equals(type) || DOUBLE.equals(type); + } + + public static boolean isArray(@Nullable ExprType type) + { + return LONG_ARRAY.equals(type) || DOUBLE_ARRAY.equals(type) || STRING_ARRAY.equals(type); + } + + @Nullable + public static ExprType elementType(@Nullable ExprType type) + { + if (type != null) { + switch (type) { + case STRING_ARRAY: + return STRING; + case LONG_ARRAY: + return LONG; + case DOUBLE_ARRAY: + return DOUBLE; + } + } + return type; + } + + @Nullable + public static ExprType asArrayType(@Nullable ExprType elementType) + { + if (elementType != null) { + switch (elementType) { + case STRING: + return STRING_ARRAY; + case LONG: + return LONG_ARRAY; + case DOUBLE: + return DOUBLE_ARRAY; + } + } + return elementType; + } + + /** + * Given 2 'input' types, choose the most appropriate combined type, if possible + */ + @Nullable + public static ExprType operatorAutoTypeConversion(@Nullable ExprType type, @Nullable ExprType other) + { + if (type == null || other == null) { + // cannot auto conversion unknown types + return null; + } + // arrays cannot be auto converted + if (isArray(type) || isArray(other)) { + if (!type.equals(other)) { + throw new IAE("Cannot implicitly cast %s to %s", type, other); + } + return type; + } + // if both arguments are a string, type becomes a string + if (STRING.equals(type) && STRING.equals(other)) { + return STRING; + } + + return numericAutoTypeConversion(type, other); + } + + /** + * Given 2 'input' types, choose the most appropriate combined type, if possible + */ + @Nullable + public static ExprType functionAutoTypeConversion(@Nullable ExprType type, @Nullable ExprType other) + { + if (type == null || other == null) { + // cannot auto conversion unknown types + return null; + } + // arrays cannot be auto converted + if (isArray(type) || isArray(other)) { + if (!type.equals(other)) { + throw new IAE("Cannot implicitly cast %s to %s", type, other); + } + return type; + } + // if either argument is a string, type becomes a string + if (STRING.equals(type) || STRING.equals(other)) { + return STRING; + } + + return numericAutoTypeConversion(type, other); + } + + @Nullable + public static ExprType numericAutoTypeConversion(ExprType type, ExprType other) + { + // all numbers win over longs + if (LONG.equals(type) && LONG.equals(other)) { + return LONG; + } + // floats vs doubles would be handled here, but we currently only support doubles... + return DOUBLE; + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index a2086392987..2e27aab84ae 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -36,9 +36,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; -import java.util.EnumSet; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.BinaryOperator; import java.util.function.DoubleBinaryOperator; @@ -104,6 +104,14 @@ public interface Function */ void validateArguments(List args); + /** + * Compute the output type of this function for a given set of argument expression inputs. + * + * @see Expr#getOutputType + */ + @Nullable + ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args); + /** * Base class for a single variable input {@link Function} implementation */ @@ -180,6 +188,26 @@ public interface Function { return eval((long) param); } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return args.get(0).getOutputType(inputTypes); + } + } + + /** + * Many math functions always output a {@link Double} primitive, regardless of input type. + */ + abstract class DoubleUnivariateMathFunction extends UnivariateMathFunction + { + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.DOUBLE; + } } /** @@ -210,6 +238,26 @@ public interface Function { return eval((long) x, (long) y); } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.functionAutoTypeConversion(args.get(0).getOutputType(inputTypes), args.get(1).getOutputType(inputTypes)); + } + } + + /** + * Many math functions always output a {@link Double} primitive, regardless of input type. + */ + abstract class DoubleBivariateMathFunction extends BivariateMathFunction + { + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.DOUBLE; + } } /** @@ -325,6 +373,100 @@ public interface Function abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr); } + abstract class ReduceFunction implements Function + { + private final DoubleBinaryOperator doubleReducer; + private final LongBinaryOperator longReducer; + private final BinaryOperator stringReducer; + + ReduceFunction( + DoubleBinaryOperator doubleReducer, + LongBinaryOperator longReducer, + BinaryOperator stringReducer + ) + { + this.doubleReducer = doubleReducer; + this.longReducer = longReducer; + this.stringReducer = stringReducer; + } + + @Override + public void validateArguments(List args) + { + // anything goes + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType outputType = ExprType.LONG; + for (Expr expr : args) { + outputType = ExprType.functionAutoTypeConversion(outputType, expr.getOutputType(inputTypes)); + } + return outputType; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + if (args.isEmpty()) { + return ExprEval.of(null); + } + + // evaluate arguments and collect output type + List> evals = new ArrayList<>(); + ExprType outputType = ExprType.LONG; + + for (Expr expr : args) { + ExprEval exprEval = expr.eval(bindings); + ExprType exprType = exprEval.type(); + + if (isValidType(exprType)) { + outputType = ExprType.functionAutoTypeConversion(outputType, exprType); + } + + if (exprEval.value() != null) { + evals.add(exprEval); + } + } + + if (evals.isEmpty()) { + // The GREATEST/LEAST functions are not in the SQL standard. Emulate the behavior of postgres (return null if + // all expressions are null, otherwise skip null values) since it is used as a base for a wide number of + // databases. This also matches the behavior the the long/double greatest/least post aggregators. Some other + // databases (e.g., MySQL) return null if any expression is null. + // https://www.postgresql.org/docs/9.5/functions-conditional.html + // https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least + return ExprEval.of(null); + } + + switch (outputType) { + case DOUBLE: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(evals.stream().mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble()); + case LONG: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(evals.stream().mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong()); + default: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(evals.stream().map(ExprEval::asString).reduce(stringReducer).get()); + } + } + + private boolean isValidType(ExprType exprType) + { + switch (exprType) { + case DOUBLE: + case LONG: + case STRING: + return true; + default: + throw new IAE("Function[%s] does not accept %s types", name(), exprType); + } + } + } + // ------------------------------ implementations ------------------------------ class ParseLong implements Function @@ -343,6 +485,13 @@ public interface Function } } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { @@ -393,6 +542,13 @@ public interface Function throw new IAE("Function[%s] needs 0 argument", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.DOUBLE; + } } class Abs extends UnivariateMathFunction @@ -416,7 +572,7 @@ public interface Function } } - class Acos extends UnivariateMathFunction + class Acos extends DoubleUnivariateMathFunction { @Override public String name() @@ -431,7 +587,7 @@ public interface Function } } - class Asin extends UnivariateMathFunction + class Asin extends DoubleUnivariateMathFunction { @Override public String name() @@ -446,7 +602,7 @@ public interface Function } } - class Atan extends UnivariateMathFunction + class Atan extends DoubleUnivariateMathFunction { @Override public String name() @@ -461,7 +617,7 @@ public interface Function } } - class Cbrt extends UnivariateMathFunction + class Cbrt extends DoubleUnivariateMathFunction { @Override public String name() @@ -476,7 +632,7 @@ public interface Function } } - class Ceil extends UnivariateMathFunction + class Ceil extends DoubleUnivariateMathFunction { @Override public String name() @@ -491,7 +647,7 @@ public interface Function } } - class Cos extends UnivariateMathFunction + class Cos extends DoubleUnivariateMathFunction { @Override public String name() @@ -506,7 +662,7 @@ public interface Function } } - class Cosh extends UnivariateMathFunction + class Cosh extends DoubleUnivariateMathFunction { @Override public String name() @@ -521,7 +677,7 @@ public interface Function } } - class Cot extends UnivariateMathFunction + class Cot extends DoubleUnivariateMathFunction { @Override public String name() @@ -557,7 +713,7 @@ public interface Function } } - class Exp extends UnivariateMathFunction + class Exp extends DoubleUnivariateMathFunction { @Override public String name() @@ -572,7 +728,7 @@ public interface Function } } - class Expm1 extends UnivariateMathFunction + class Expm1 extends DoubleUnivariateMathFunction { @Override public String name() @@ -587,7 +743,7 @@ public interface Function } } - class Floor extends UnivariateMathFunction + class Floor extends DoubleUnivariateMathFunction { @Override public String name() @@ -617,7 +773,7 @@ public interface Function } } - class Log extends UnivariateMathFunction + class Log extends DoubleUnivariateMathFunction { @Override public String name() @@ -632,7 +788,7 @@ public interface Function } } - class Log10 extends UnivariateMathFunction + class Log10 extends DoubleUnivariateMathFunction { @Override public String name() @@ -647,7 +803,7 @@ public interface Function } } - class Log1p extends UnivariateMathFunction + class Log1p extends DoubleUnivariateMathFunction { @Override public String name() @@ -662,7 +818,7 @@ public interface Function } } - class NextUp extends UnivariateMathFunction + class NextUp extends DoubleUnivariateMathFunction { @Override public String name() @@ -677,7 +833,7 @@ public interface Function } } - class Rint extends UnivariateMathFunction + class Rint extends DoubleUnivariateMathFunction { @Override public String name() @@ -740,6 +896,13 @@ public interface Function } } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return args.get(0).getOutputType(inputTypes); + } + private ExprEval eval(ExprEval param) { return eval(param, 0); @@ -773,7 +936,7 @@ public interface Function } } - class Signum extends UnivariateMathFunction + class Signum extends DoubleUnivariateMathFunction { @Override public String name() @@ -788,7 +951,7 @@ public interface Function } } - class Sin extends UnivariateMathFunction + class Sin extends DoubleUnivariateMathFunction { @Override public String name() @@ -803,7 +966,7 @@ public interface Function } } - class Sinh extends UnivariateMathFunction + class Sinh extends DoubleUnivariateMathFunction { @Override public String name() @@ -818,7 +981,7 @@ public interface Function } } - class Sqrt extends UnivariateMathFunction + class Sqrt extends DoubleUnivariateMathFunction { @Override public String name() @@ -833,7 +996,7 @@ public interface Function } } - class Tan extends UnivariateMathFunction + class Tan extends DoubleUnivariateMathFunction { @Override public String name() @@ -848,7 +1011,7 @@ public interface Function } } - class Tanh extends UnivariateMathFunction + class Tanh extends DoubleUnivariateMathFunction { @Override public String name() @@ -863,7 +1026,7 @@ public interface Function } } - class ToDegrees extends UnivariateMathFunction + class ToDegrees extends DoubleUnivariateMathFunction { @Override public String name() @@ -878,7 +1041,7 @@ public interface Function } } - class ToRadians extends UnivariateMathFunction + class ToRadians extends DoubleUnivariateMathFunction { @Override public String name() @@ -893,7 +1056,7 @@ public interface Function } } - class Ulp extends UnivariateMathFunction + class Ulp extends DoubleUnivariateMathFunction { @Override public String name() @@ -908,7 +1071,7 @@ public interface Function } } - class Atan2 extends BivariateMathFunction + class Atan2 extends DoubleBivariateMathFunction { @Override public String name() @@ -923,7 +1086,7 @@ public interface Function } } - class CopySign extends BivariateMathFunction + class CopySign extends DoubleBivariateMathFunction { @Override public String name() @@ -938,7 +1101,7 @@ public interface Function } } - class Hypot extends BivariateMathFunction + class Hypot extends DoubleBivariateMathFunction { @Override public String name() @@ -953,7 +1116,7 @@ public interface Function } } - class Remainder extends BivariateMathFunction + class Remainder extends DoubleBivariateMathFunction { @Override public String name() @@ -1010,173 +1173,7 @@ public interface Function } } - class GreatestFunc extends ReduceFunc - { - public static final String NAME = "greatest"; - - public GreatestFunc() - { - super( - Math::max, - Math::max, - BinaryOperator.maxBy(Comparator.naturalOrder()) - ); - } - - @Override - public String name() - { - return NAME; - } - } - - class LeastFunc extends ReduceFunc - { - public static final String NAME = "least"; - - public LeastFunc() - { - super( - Math::min, - Math::min, - BinaryOperator.minBy(Comparator.naturalOrder()) - ); - } - - @Override - public String name() - { - return NAME; - } - } - - abstract class ReduceFunc implements Function - { - private final DoubleBinaryOperator doubleReducer; - private final LongBinaryOperator longReducer; - private final BinaryOperator stringReducer; - - ReduceFunc( - DoubleBinaryOperator doubleReducer, - LongBinaryOperator longReducer, - BinaryOperator stringReducer - ) - { - this.doubleReducer = doubleReducer; - this.longReducer = longReducer; - this.stringReducer = stringReducer; - } - - @Override - public void validateArguments(List args) - { - // anything goes - } - - @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) - { - if (args.isEmpty()) { - return ExprEval.of(null); - } - - ExprAnalysis exprAnalysis = analyzeExprs(args, bindings); - if (exprAnalysis.exprEvals.isEmpty()) { - // The GREATEST/LEAST functions are not in the SQL standard. Emulate the behavior of postgres (return null if - // all expressions are null, otherwise skip null values) since it is used as a base for a wide number of - // databases. This also matches the behavior the the long/double greatest/least post aggregators. Some other - // databases (e.g., MySQL) return null if any expression is null. - // https://www.postgresql.org/docs/9.5/functions-conditional.html - // https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least - return ExprEval.of(null); - } - - Stream> exprEvalStream = exprAnalysis.exprEvals.stream(); - switch (exprAnalysis.comparisonType) { - case DOUBLE: - //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble()); - case LONG: - //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong()); - default: - //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.map(ExprEval::asString).reduce(stringReducer).get()); - } - } - - /** - * Determines which {@link ExprType} to use to compare non-null evaluated expressions. - * - * @param exprs Expressions to analyze - * @param bindings Bindings for expressions - * - * @return Comparison type and non-null evaluated expressions. - */ - private ExprAnalysis analyzeExprs(List exprs, Expr.ObjectBinding bindings) - { - Set presentTypes = EnumSet.noneOf(ExprType.class); - List> exprEvals = new ArrayList<>(); - - for (Expr expr : exprs) { - ExprEval exprEval = expr.eval(bindings); - ExprType exprType = exprEval.type(); - - if (isValidType(exprType)) { - presentTypes.add(exprType); - } - - if (exprEval.value() != null) { - exprEvals.add(exprEval); - } - } - - ExprType comparisonType = getComparisionType(presentTypes); - return new ExprAnalysis(comparisonType, exprEvals); - } - - private boolean isValidType(ExprType exprType) - { - switch (exprType) { - case DOUBLE: - case LONG: - case STRING: - return true; - default: - throw new IAE("Function[%s] does not accept %s types", name(), exprType); - } - } - - /** - * Implements rules similar to: https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least - * - * @see org.apache.druid.sql.calcite.expression.builtin.ReductionOperatorConversionHelper#TYPE_INFERENCE - */ - private static ExprType getComparisionType(Set exprTypes) - { - if (exprTypes.contains(ExprType.STRING)) { - return ExprType.STRING; - } else if (exprTypes.contains(ExprType.DOUBLE)) { - return ExprType.DOUBLE; - } else { - return ExprType.LONG; - } - } - - private static class ExprAnalysis - { - final ExprType comparisonType; - final List> exprEvals; - - ExprAnalysis(ExprType comparisonType, List> exprEvals) - { - this.comparisonType = comparisonType; - this.exprEvals = exprEvals; - } - } - } - - class NextAfter extends BivariateMathFunction + class NextAfter extends DoubleBivariateMathFunction { @Override public String name() @@ -1191,7 +1188,7 @@ public interface Function } } - class Pow extends BivariateMathFunction + class Pow extends DoubleBivariateMathFunction { @Override public String name() @@ -1214,6 +1211,13 @@ public interface Function return "scalb"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.DOUBLE; + } + @Override protected ExprEval eval(ExprEval x, ExprEval y) { @@ -1221,102 +1225,6 @@ public interface Function } } - class ConditionFunc implements Function - { - @Override - public String name() - { - return "if"; - } - - @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) - { - ExprEval x = args.get(0).eval(bindings); - return x.asBoolean() ? args.get(1).eval(bindings) : args.get(2).eval(bindings); - } - - @Override - public void validateArguments(List args) - { - if (args.size() != 3) { - throw new IAE("Function[%s] needs 3 arguments", name()); - } - } - } - - /** - * "Searched CASE" function, similar to {@code CASE WHEN boolean_expr THEN result [ELSE else_result] END} in SQL. - */ - class CaseSearchedFunc implements Function - { - @Override - public String name() - { - return "case_searched"; - } - - @Override - public ExprEval apply(final List args, final Expr.ObjectBinding bindings) - { - for (int i = 0; i < args.size(); i += 2) { - if (i == args.size() - 1) { - // ELSE else_result. - return args.get(i).eval(bindings); - } else if (args.get(i).eval(bindings).asBoolean()) { - // Matching WHEN boolean_expr THEN result - return args.get(i + 1).eval(bindings); - } - } - - return ExprEval.of(null); - } - - @Override - public void validateArguments(List args) - { - if (args.size() < 2) { - throw new IAE("Function[%s] must have at least 2 arguments", name()); - } - } - } - - /** - * "Simple CASE" function, similar to {@code CASE expr WHEN value THEN result [ELSE else_result] END} in SQL. - */ - class CaseSimpleFunc implements Function - { - @Override - public String name() - { - return "case_simple"; - } - - @Override - public ExprEval apply(final List args, final Expr.ObjectBinding bindings) - { - for (int i = 1; i < args.size(); i += 2) { - if (i == args.size() - 1) { - // ELSE else_result. - return args.get(i).eval(bindings); - } else if (new BinEqExpr("==", args.get(0), args.get(i)).eval(bindings).asBoolean()) { - // Matching WHEN value THEN result - return args.get(i + 1).eval(bindings); - } - } - - return ExprEval.of(null); - } - - @Override - public void validateArguments(List args) - { - if (args.size() < 3) { - throw new IAE("Function[%s] must have at least 3 arguments", name()); - } - } - } - class CastFunc extends BivariateFunction { @Override @@ -1376,68 +1284,176 @@ public interface Function // unknown cast, can't safely assume either way return Collections.emptySet(); } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + // can only know cast output type if cast to argument is constant + if (args.get(1).isLiteral()) { + return ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); + } + return null; + } } - class TimestampFromEpochFunc implements Function + class GreatestFunc extends ReduceFunction + { + public static final String NAME = "greatest"; + + public GreatestFunc() + { + super( + Math::max, + Math::max, + BinaryOperator.maxBy(Comparator.naturalOrder()) + ); + } + + @Override + public String name() + { + return NAME; + } + } + + class LeastFunc extends ReduceFunction + { + public static final String NAME = "least"; + + public LeastFunc() + { + super( + Math::min, + Math::min, + BinaryOperator.minBy(Comparator.naturalOrder()) + ); + } + + @Override + public String name() + { + return NAME; + } + } + + class ConditionFunc implements Function { @Override public String name() { - return "timestamp"; + return "if"; } @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - ExprEval value = args.get(0).eval(bindings); - if (value.type() != ExprType.STRING) { - throw new IAE("first argument should be string type but got %s type", value.type()); - } - - DateTimes.UtcFormatter formatter = DateTimes.ISO_DATE_OPTIONAL_TIME; - if (args.size() > 1) { - ExprEval format = args.get(1).eval(bindings); - if (format.type() != ExprType.STRING) { - throw new IAE("second argument should be string type but got %s type", format.type()); - } - formatter = DateTimes.wrapFormatter(DateTimeFormat.forPattern(format.asString())); - } - DateTime date; - try { - date = formatter.parse(value.asString()); - } - catch (IllegalArgumentException e) { - throw new IAE(e, "invalid value %s", value.asString()); - } - return toValue(date); + ExprEval x = args.get(0).eval(bindings); + return x.asBoolean() ? args.get(1).eval(bindings) : args.get(2).eval(bindings); } @Override public void validateArguments(List args) { - if (args.size() != 1 && args.size() != 2) { - throw new IAE("Function[%s] needs 1 or 2 arguments", name()); + if (args.size() != 3) { + throw new IAE("Function[%s] needs 3 arguments", name()); } } - protected ExprEval toValue(DateTime date) + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - return ExprEval.of(date.getMillis()); + // output type is defined by else + return args.get(2).getOutputType(inputTypes); } } - class UnixTimestampFunc extends TimestampFromEpochFunc + /** + * "Searched CASE" function, similar to {@code CASE WHEN boolean_expr THEN result [ELSE else_result] END} in SQL. + */ + class CaseSearchedFunc implements Function { @Override public String name() { - return "unix_timestamp"; + return "case_searched"; } @Override - protected final ExprEval toValue(DateTime date) + public ExprEval apply(final List args, final Expr.ObjectBinding bindings) { - return ExprEval.of(date.getMillis() / 1000); + for (int i = 0; i < args.size(); i += 2) { + if (i == args.size() - 1) { + // ELSE else_result. + return args.get(i).eval(bindings); + } else if (args.get(i).eval(bindings).asBoolean()) { + // Matching WHEN boolean_expr THEN result + return args.get(i + 1).eval(bindings); + } + } + + return ExprEval.of(null); + } + + @Override + public void validateArguments(List args) + { + if (args.size() < 2) { + throw new IAE("Function[%s] must have at least 2 arguments", name()); + } + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + // output type is defined by else + return args.get(args.size() - 1).getOutputType(inputTypes); + } + } + + /** + * "Simple CASE" function, similar to {@code CASE expr WHEN value THEN result [ELSE else_result] END} in SQL. + */ + class CaseSimpleFunc implements Function + { + @Override + public String name() + { + return "case_simple"; + } + + @Override + public ExprEval apply(final List args, final Expr.ObjectBinding bindings) + { + for (int i = 1; i < args.size(); i += 2) { + if (i == args.size() - 1) { + // ELSE else_result. + return args.get(i).eval(bindings); + } else if (new BinEqExpr("==", args.get(0), args.get(i)).eval(bindings).asBoolean()) { + // Matching WHEN value THEN result + return args.get(i + 1).eval(bindings); + } + } + + return ExprEval.of(null); + } + + @Override + public void validateArguments(List args) + { + if (args.size() < 3) { + throw new IAE("Function[%s] must have at least 3 arguments", name()); + } + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + // output type is defined by else + return args.get(args.size() - 1).getOutputType(inputTypes); } } @@ -1463,6 +1479,75 @@ public interface Function throw new IAE("Function[%s] needs 2 arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return args.get(0).getOutputType(inputTypes); + } + } + + class IsNullFunc implements Function + { + @Override + public String name() + { + return "isnull"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval expr = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(expr.value() == null); + } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); + } + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + } + + class IsNotNullFunc implements Function + { + @Override + public String name() + { + return "notnull"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval expr = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(expr.value() != null); + } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); + } + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class ConcatFunc implements Function @@ -1506,6 +1591,13 @@ public interface Function { // anything goes } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class StrlenFunc implements Function @@ -1530,6 +1622,13 @@ public interface Function throw new IAE("Function[%s] needs 1 argument", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class StringFormatFunc implements Function @@ -1564,6 +1663,13 @@ public interface Function throw new IAE("Function[%s] needs 1 or more arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class StrposFunc implements Function @@ -1602,6 +1708,13 @@ public interface Function throw new IAE("Function[%s] needs 2 or 3 arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class SubstringFunc implements Function @@ -1645,6 +1758,13 @@ public interface Function throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class RightFunc extends StringLongFunction @@ -1655,6 +1775,13 @@ public interface Function return "right"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(@Nullable String x, int y) { @@ -1680,6 +1807,13 @@ public interface Function return "left"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(@Nullable String x, int y) { @@ -1723,6 +1857,13 @@ public interface Function throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class LowerFunc implements Function @@ -1750,6 +1891,13 @@ public interface Function throw new IAE("Function[%s] needs 1 argument", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class UpperFunc implements Function @@ -1777,6 +1925,13 @@ public interface Function throw new IAE("Function[%s] needs 1 argument", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class ReverseFunc extends UnivariateFunction @@ -1787,6 +1942,13 @@ public interface Function return "reverse"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(ExprEval param) { @@ -1809,6 +1971,13 @@ public interface Function return "repeat"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(String x, int y) { @@ -1816,54 +1985,6 @@ public interface Function } } - class IsNullFunc implements Function - { - @Override - public String name() - { - return "isnull"; - } - - @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) - { - final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.of(expr.value() == null, ExprType.LONG); - } - - @Override - public void validateArguments(List args) - { - if (args.size() != 1) { - throw new IAE("Function[%s] needs 1 argument", name()); - } - } - } - - class IsNotNullFunc implements Function - { - @Override - public String name() - { - return "notnull"; - } - - @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) - { - final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.of(expr.value() != null, ExprType.LONG); - } - - @Override - public void validateArguments(List args) - { - if (args.size() != 1) { - throw new IAE("Function[%s] needs 1 argument", name()); - } - } - } - class LpadFunc implements Function { @Override @@ -1894,6 +2015,13 @@ public interface Function throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class RpadFunc implements Function @@ -1926,6 +2054,83 @@ public interface Function throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + } + + class TimestampFromEpochFunc implements Function + { + @Override + public String name() + { + return "timestamp"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + ExprEval value = args.get(0).eval(bindings); + if (value.type() != ExprType.STRING) { + throw new IAE("first argument should be string type but got %s type", value.type()); + } + + DateTimes.UtcFormatter formatter = DateTimes.ISO_DATE_OPTIONAL_TIME; + if (args.size() > 1) { + ExprEval format = args.get(1).eval(bindings); + if (format.type() != ExprType.STRING) { + throw new IAE("second argument should be string type but got %s type", format.type()); + } + formatter = DateTimes.wrapFormatter(DateTimeFormat.forPattern(format.asString())); + } + DateTime date; + try { + date = formatter.parse(value.asString()); + } + catch (IllegalArgumentException e) { + throw new IAE(e, "invalid value %s", value.asString()); + } + return toValue(date); + } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1 && args.size() != 2) { + throw new IAE("Function[%s] needs 1 or 2 arguments", name()); + } + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + + protected ExprEval toValue(DateTime date) + { + return ExprEval.of(date.getMillis()); + } + } + + class UnixTimestampFunc extends TimestampFromEpochFunc + { + @Override + public String name() + { + return "unix_timestamp"; + } + + @Override + protected final ExprEval toValue(DateTime date) + { + return ExprEval.of(date.getMillis() / 1000); + } } class SubMonthFunc implements Function @@ -1958,6 +2163,13 @@ public interface Function throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class ArrayConstructorFunction implements Function @@ -2064,6 +2276,17 @@ public interface Function throw new IAE("Function[%s] needs at least 1 argument", name()); } } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType type = ExprType.LONG; + for (Expr arg : args) { + type = ExprType.functionAutoTypeConversion(type, arg.getOutputType(inputTypes)); + } + return ExprType.asArrayType(type); + } } class ArrayLengthFunction implements Function @@ -2110,6 +2333,13 @@ public interface Function } } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override public Set getScalarInputs(List args) { @@ -2133,6 +2363,13 @@ public interface Function } } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING_ARRAY; + } + @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { @@ -2167,6 +2404,13 @@ public interface Function return "array_to_string"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2189,6 +2433,13 @@ public interface Function return "array_offset"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.elementType(args.get(0).getOutputType(inputTypes)); + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2210,6 +2461,13 @@ public interface Function return "array_ordinal"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.elementType(args.get(0).getOutputType(inputTypes)); + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2231,6 +2489,13 @@ public interface Function return "array_offset_of"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2262,6 +2527,13 @@ public interface Function return "array_ordinal_of"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2298,6 +2570,14 @@ public interface Function return true; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType arrayType = args.get(0).getOutputType(inputTypes); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2354,6 +2634,14 @@ public interface Function return true; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType arrayType = args.get(0).getOutputType(inputTypes); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + @Override ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) { @@ -2409,12 +2697,19 @@ public interface Function return true; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) { final Object[] array1 = lhsExpr.asArray(); final Object[] array2 = rhsExpr.asArray(); - return ExprEval.of(Arrays.asList(array1).containsAll(Arrays.asList(array2)), ExprType.LONG); + return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(array2))); } } @@ -2426,6 +2721,13 @@ public interface Function return "array_overlap"; } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) { @@ -2435,7 +2737,7 @@ public interface Function for (Object check : array1) { any |= array2.contains(check); } - return ExprEval.of(any, ExprType.LONG); + return ExprEval.ofLongBoolean(any); } } @@ -2455,6 +2757,13 @@ public interface Function } } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return args.get(0).getOutputType(inputTypes); + } + @Override public Set getScalarInputs(List args) { @@ -2534,6 +2843,14 @@ public interface Function } } + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType arrayType = args.get(1).getOutputType(inputTypes); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + @Override public Set getScalarInputs(List args) { diff --git a/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java b/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java index 2b3474a6c82..e81d5bafd2c 100644 --- a/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java @@ -105,13 +105,19 @@ class LambdaExpr implements Expr } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { final Set lambdaArgs = args.stream().map(IdentifierExpr::toString).collect(Collectors.toSet()); - BindingDetails bodyDetails = expr.analyzeInputs(); + BindingAnalysis bodyDetails = expr.analyzeInputs(); return bodyDetails.removeLambdaArguments(lambdaArgs); } + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return expr.getOutputType(inputTypes); + } + @Override public boolean equals(Object o) { @@ -187,9 +193,9 @@ class FunctionExpr implements Expr } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { - BindingDetails accumulator = new BindingDetails(); + BindingAnalysis accumulator = new BindingAnalysis(); for (Expr arg : args) { accumulator = accumulator.with(arg); @@ -200,6 +206,12 @@ class FunctionExpr implements Expr .withArrayOutput(function.hasArrayOutput()); } + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return function.getOutputType(inputTypes, args); + } + @Override public boolean equals(Object o) { @@ -232,9 +244,9 @@ class ApplyFunctionExpr implements Expr final String name; final LambdaExpr lambdaExpr; final ImmutableList argsExpr; - final BindingDetails bindingDetails; - final BindingDetails lambdaBindingDetails; - final ImmutableList argsBindingDetails; + final BindingAnalysis bindingAnalysis; + final BindingAnalysis lambdaBindingAnalysis; + final ImmutableList argsBindingAnalyses; ApplyFunctionExpr(ApplyFunction function, String name, LambdaExpr expr, List args) { @@ -247,21 +259,21 @@ class ApplyFunctionExpr implements Expr // apply function expressions are examined during expression selector creation, so precompute and cache the // binding details of children - ImmutableList.Builder argBindingDetailsBuilder = ImmutableList.builder(); - BindingDetails accumulator = new BindingDetails(); + ImmutableList.Builder argBindingDetailsBuilder = ImmutableList.builder(); + BindingAnalysis accumulator = new BindingAnalysis(); for (Expr arg : argsExpr) { - BindingDetails argDetails = arg.analyzeInputs(); + BindingAnalysis argDetails = arg.analyzeInputs(); argBindingDetailsBuilder.add(argDetails); accumulator = accumulator.with(argDetails); } - lambdaBindingDetails = lambdaExpr.analyzeInputs(); + lambdaBindingAnalysis = lambdaExpr.analyzeInputs(); - bindingDetails = accumulator.with(lambdaBindingDetails) - .withArrayArguments(function.getArrayInputs(argsExpr)) - .withArrayInputs(true) - .withArrayOutput(function.hasArrayOutput(lambdaExpr)); - argsBindingDetails = argBindingDetailsBuilder.build(); + bindingAnalysis = accumulator.with(lambdaBindingAnalysis) + .withArrayArguments(function.getArrayInputs(argsExpr)) + .withArrayInputs(true) + .withArrayOutput(function.hasArrayOutput(lambdaExpr)); + argsBindingAnalyses = argBindingDetailsBuilder.build(); } @Override @@ -306,9 +318,16 @@ class ApplyFunctionExpr implements Expr } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { - return bindingDetails; + return bindingAnalysis; + } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return function.getOutputType(inputTypes, lambdaExpr, argsExpr); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java index d23657a3bd9..437370641c6 100644 --- a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java @@ -102,9 +102,15 @@ class IdentifierExpr implements Expr } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { - return new BindingDetails(this); + return new BindingAnalysis(this); + } + + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return inputTypes.getType(binding); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index d8fb564c4f2..c9388bff17f 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -169,7 +169,7 @@ public class Parser * @param bindingsToApply * @return */ - public static Expr applyUnappliedBindings(Expr expr, Expr.BindingDetails bindingDetails, List bindingsToApply) + public static Expr applyUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List bindingsToApply) { if (bindingsToApply.isEmpty()) { // nothing to do, expression is fine as is @@ -177,7 +177,7 @@ public class Parser } // filter the list of bindings to those which are used in this expression List unappliedBindingsInExpression = bindingsToApply.stream() - .filter(x -> bindingDetails.getRequiredBindings().contains(x)) + .filter(x -> bindingAnalysis.getRequiredBindings().contains(x)) .collect(Collectors.toList()); // any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten @@ -193,7 +193,7 @@ public class Parser List newArgs = new ArrayList<>(); for (Expr arg : fnExpr.args) { if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) { - Expr newArg = applyUnappliedBindings(arg, bindingDetails, unappliedBindingsInExpression); + Expr newArg = applyUnappliedBindings(arg, bindingAnalysis, unappliedBindingsInExpression); newArgs.add(newArg); } else { newArgs.add(arg); @@ -207,7 +207,7 @@ public class Parser } ); - Expr.BindingDetails newExprBindings = newExpr.analyzeInputs(); + Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs(); final Set expectedArrays = newExprBindings.getArrayVariables(); List remainingUnappliedBindings = @@ -288,11 +288,11 @@ public class Parser // recursively evaluate arguments to ensure they are properly transformed into arrays as necessary Set unappliedInThisApply = unappliedArgs.stream() - .filter(u -> !expr.bindingDetails.getArrayBindings().contains(u)) + .filter(u -> !expr.bindingAnalysis.getArrayBindings().contains(u)) .collect(Collectors.toSet()); List unappliedIdentifiers = - expr.bindingDetails + expr.bindingAnalysis .getFreeVariables() .stream() .filter(x -> unappliedInThisApply.contains(x.getBindingIfIdentifier())) @@ -304,7 +304,7 @@ public class Parser newArgs.add( applyUnappliedBindings( expr.argsExpr.get(i), - expr.argsBindingDetails.get(i), + expr.argsBindingAnalyses.get(i), unappliedIdentifiers ) ); @@ -312,11 +312,11 @@ public class Parser // this will _not_ include the lambda identifiers.. anything in this list needs to be applied List unappliedLambdaBindings = - expr.lambdaBindingDetails.getFreeVariables() - .stream() - .filter(x -> unappliedArgs.contains(x.getBindingIfIdentifier())) - .map(x -> new IdentifierExpr(x.getIdentifier(), x.getBinding())) - .collect(Collectors.toList()); + expr.lambdaBindingAnalysis.getFreeVariables() + .stream() + .filter(x -> unappliedArgs.contains(x.getBindingIfIdentifier())) + .map(x -> new IdentifierExpr(x.getIdentifier(), x.getBinding())) + .collect(Collectors.toList()); if (unappliedLambdaBindings.isEmpty()) { return new ApplyFunctionExpr(expr.function, expr.name, expr.lambdaExpr, newArgs); @@ -397,10 +397,10 @@ public class Parser /** * Validate that an expression uses input bindings in a type consistent manner. */ - public static void validateExpr(Expr expression, Expr.BindingDetails bindingDetails) + public static void validateExpr(Expr expression, Expr.BindingAnalysis bindingAnalysis) { final Set conflicted = - Sets.intersection(bindingDetails.getScalarBindings(), bindingDetails.getArrayBindings()); + Sets.intersection(bindingAnalysis.getScalarBindings(), bindingAnalysis.getArrayBindings()); if (!conflicted.isEmpty()) { throw new RE("Invalid expression: %s; %s used as both scalar and array variables", expression, conflicted); } diff --git a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java index 5a41e904250..3d68430ea65 100644 --- a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java @@ -24,6 +24,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; +import javax.annotation.Nullable; import java.util.Objects; /** @@ -59,12 +60,19 @@ abstract class UnaryExpr implements Expr } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { // currently all unary operators only operate on scalar inputs return expr.analyzeInputs().withScalarArguments(ImmutableSet.of(expr)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return expr.getOutputType(inputTypes); + } + @Override public boolean equals(Object o) { @@ -163,4 +171,15 @@ class UnaryNotExpr extends UnaryExpr { return StringUtils.format("!%s", expr); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java index ff12669bbc6..6dfa61d6d18 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java @@ -113,7 +113,7 @@ public class ExprTest { EqualsVerifier.forClass(ApplyFunctionExpr.class) .usingGetClass() - .withIgnoredFields("function", "bindingDetails", "lambdaBindingDetails", "argsBindingDetails") + .withIgnoredFields("function", "bindingAnalysis", "lambdaBindingAnalysis", "argsBindingAnalyses") .verify(); } @@ -132,37 +132,55 @@ public class ExprTest @Test public void testEqualsContractForStringExpr() { - EqualsVerifier.forClass(StringExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(StringExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForDoubleExpr() { - EqualsVerifier.forClass(DoubleExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(DoubleExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForLongExpr() { - EqualsVerifier.forClass(LongExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(LongExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForStringArrayExpr() { - EqualsVerifier.forClass(StringArrayExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(StringArrayExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForLongArrayExpr() { - EqualsVerifier.forClass(LongArrayExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(LongArrayExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForDoubleArrayExpr() { - EqualsVerifier.forClass(DoubleArrayExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(DoubleArrayExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test @@ -179,12 +197,16 @@ public class ExprTest @Test public void testEqualsContractForNullLongExpr() { - EqualsVerifier.forClass(NullLongExpr.class).verify(); + EqualsVerifier.forClass(NullLongExpr.class) + .withIgnoredFields("outputType") + .verify(); } @Test public void testEqualsContractForNullDoubleExpr() { - EqualsVerifier.forClass(NullDoubleExpr.class).verify(); + EqualsVerifier.forClass(NullDoubleExpr.class) + .withIgnoredFields("outputType") + .verify(); } } diff --git a/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java new file mode 100644 index 00000000000..7b977d3b9ac --- /dev/null +++ b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java @@ -0,0 +1,463 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.math.expr; + +import com.google.common.collect.ImmutableMap; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.util.Map; + +public class OutputTypeTest extends InitializedNullHandlingTest +{ + private final Expr.InputBindingTypes inputTypes = inputTypesFromMap( + ImmutableMap.builder().put("x", ExprType.STRING) + .put("x_", ExprType.STRING) + .put("y", ExprType.LONG) + .put("y_", ExprType.LONG) + .put("z", ExprType.DOUBLE) + .put("z_", ExprType.DOUBLE) + .put("a", ExprType.STRING_ARRAY) + .put("a_", ExprType.STRING_ARRAY) + .put("b", ExprType.LONG_ARRAY) + .put("b_", ExprType.LONG_ARRAY) + .put("c", ExprType.DOUBLE_ARRAY) + .put("c_", ExprType.DOUBLE_ARRAY) + .build() + ); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testConstantsAndIdentifiers() + { + assertOutputType("'hello'", inputTypes, ExprType.STRING); + assertOutputType("23", inputTypes, ExprType.LONG); + assertOutputType("3.2", inputTypes, ExprType.DOUBLE); + assertOutputType("['a', 'b']", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("[1,2,3]", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("[1.0]", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("x", inputTypes, ExprType.STRING); + assertOutputType("y", inputTypes, ExprType.LONG); + assertOutputType("z", inputTypes, ExprType.DOUBLE); + assertOutputType("a", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("b", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("c", inputTypes, ExprType.DOUBLE_ARRAY); + } + + @Test + public void testUnaryOperators() + { + assertOutputType("-1", inputTypes, ExprType.LONG); + assertOutputType("-1.1", inputTypes, ExprType.DOUBLE); + assertOutputType("-y", inputTypes, ExprType.LONG); + assertOutputType("-z", inputTypes, ExprType.DOUBLE); + + assertOutputType("!'true'", inputTypes, ExprType.LONG); + assertOutputType("!1", inputTypes, ExprType.LONG); + assertOutputType("!1.1", inputTypes, ExprType.DOUBLE); + assertOutputType("!x", inputTypes, ExprType.LONG); + assertOutputType("!y", inputTypes, ExprType.LONG); + assertOutputType("!z", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testBinaryMathOperators() + { + assertOutputType("1+1", inputTypes, ExprType.LONG); + assertOutputType("1-1", inputTypes, ExprType.LONG); + assertOutputType("1*1", inputTypes, ExprType.LONG); + assertOutputType("1/1", inputTypes, ExprType.LONG); + assertOutputType("1^1", inputTypes, ExprType.LONG); + assertOutputType("1%1", inputTypes, ExprType.LONG); + + assertOutputType("y+y_", inputTypes, ExprType.LONG); + assertOutputType("y-y_", inputTypes, ExprType.LONG); + assertOutputType("y*y_", inputTypes, ExprType.LONG); + assertOutputType("y/y_", inputTypes, ExprType.LONG); + assertOutputType("y^y_", inputTypes, ExprType.LONG); + assertOutputType("y%y_", inputTypes, ExprType.LONG); + + assertOutputType("y+z", inputTypes, ExprType.DOUBLE); + assertOutputType("y-z", inputTypes, ExprType.DOUBLE); + assertOutputType("y*z", inputTypes, ExprType.DOUBLE); + assertOutputType("y/z", inputTypes, ExprType.DOUBLE); + assertOutputType("y^z", inputTypes, ExprType.DOUBLE); + assertOutputType("y%z", inputTypes, ExprType.DOUBLE); + + assertOutputType("z+z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z-z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z*z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z/z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z^z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z%z_", inputTypes, ExprType.DOUBLE); + + assertOutputType("y>y_", inputTypes, ExprType.LONG); + assertOutputType("y_=y", inputTypes, ExprType.LONG); + assertOutputType("y_==y", inputTypes, ExprType.LONG); + assertOutputType("y_!=y", inputTypes, ExprType.LONG); + assertOutputType("y_ && y", inputTypes, ExprType.LONG); + assertOutputType("y_ || y", inputTypes, ExprType.LONG); + + assertOutputType("z>y_", inputTypes, ExprType.DOUBLE); + assertOutputType("z=z", inputTypes, ExprType.DOUBLE); + assertOutputType("z==y", inputTypes, ExprType.DOUBLE); + assertOutputType("z!=y", inputTypes, ExprType.DOUBLE); + assertOutputType("z && y", inputTypes, ExprType.DOUBLE); + assertOutputType("y || z", inputTypes, ExprType.DOUBLE); + + assertOutputType("z>z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z=z", inputTypes, ExprType.DOUBLE); + assertOutputType("z==z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z!=z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z && z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z_ || z", inputTypes, ExprType.DOUBLE); + + assertOutputType("1*(2 + 3.0)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testUnivariateMathFunctions() + { + assertOutputType("pi()", inputTypes, ExprType.DOUBLE); + assertOutputType("abs(x)", inputTypes, ExprType.STRING); + assertOutputType("abs(y)", inputTypes, ExprType.LONG); + assertOutputType("abs(z)", inputTypes, ExprType.DOUBLE); + assertOutputType("cos(y)", inputTypes, ExprType.DOUBLE); + assertOutputType("cos(z)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testBivariateMathFunctions() + { + assertOutputType("div(y,y_)", inputTypes, ExprType.LONG); + assertOutputType("div(y,z_)", inputTypes, ExprType.DOUBLE); + assertOutputType("div(z,z_)", inputTypes, ExprType.DOUBLE); + + assertOutputType("max(y,y_)", inputTypes, ExprType.LONG); + assertOutputType("max(y,z_)", inputTypes, ExprType.DOUBLE); + assertOutputType("max(z,z_)", inputTypes, ExprType.DOUBLE); + + assertOutputType("hypot(y,y_)", inputTypes, ExprType.DOUBLE); + assertOutputType("hypot(y,z_)", inputTypes, ExprType.DOUBLE); + assertOutputType("hypot(z,z_)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testConditionalFunctions() + { + assertOutputType("if(y, 'foo', 'bar')", inputTypes, ExprType.STRING); + assertOutputType("if(y,2,3)", inputTypes, ExprType.LONG); + assertOutputType("if(y,2,3.0)", inputTypes, ExprType.DOUBLE); + + assertOutputType( + "case_simple(x,'baz','is baz','foo','is foo','is other')", + inputTypes, + ExprType.STRING + ); + assertOutputType( + "case_simple(y,2,2,3,3,4)", + inputTypes, + ExprType.LONG + ); + assertOutputType( + "case_simple(z,2.0,2.0,3.0,3.0,4.0)", + inputTypes, + ExprType.DOUBLE + ); + + assertOutputType( + "case_searched(x=='baz','is baz',x=='foo','is foo','is other')", + inputTypes, + ExprType.STRING + ); + assertOutputType( + "case_searched(y==1,1,y==2,2,0)", + inputTypes, + ExprType.LONG + ); + assertOutputType( + "case_searched(z==1.0,1.0,z==2.0,2.0,0.0)", + inputTypes, + ExprType.DOUBLE + ); + + assertOutputType("nvl(x, 'foo')", inputTypes, ExprType.STRING); + assertOutputType("nvl(y, 1)", inputTypes, ExprType.LONG); + assertOutputType("nvl(z, 2.0)", inputTypes, ExprType.DOUBLE); + assertOutputType("isnull(x)", inputTypes, ExprType.LONG); + assertOutputType("isnull(y)", inputTypes, ExprType.LONG); + assertOutputType("isnull(z)", inputTypes, ExprType.LONG); + assertOutputType("notnull(x)", inputTypes, ExprType.LONG); + assertOutputType("notnull(y)", inputTypes, ExprType.LONG); + assertOutputType("notnull(z)", inputTypes, ExprType.LONG); + } + + @Test + public void testStringFunctions() + { + assertOutputType("concat(x, 'foo')", inputTypes, ExprType.STRING); + assertOutputType("concat(y, 'foo')", inputTypes, ExprType.STRING); + assertOutputType("concat(z, 'foo')", inputTypes, ExprType.STRING); + + assertOutputType("strlen(x)", inputTypes, ExprType.LONG); + assertOutputType("format('%s', x)", inputTypes, ExprType.STRING); + assertOutputType("format('%s', y)", inputTypes, ExprType.STRING); + assertOutputType("format('%s', z)", inputTypes, ExprType.STRING); + assertOutputType("strpos(x, x_)", inputTypes, ExprType.LONG); + assertOutputType("strpos(x, y)", inputTypes, ExprType.LONG); + assertOutputType("strpos(x, z)", inputTypes, ExprType.LONG); + assertOutputType("substring(x, 1, 2)", inputTypes, ExprType.STRING); + assertOutputType("left(x, 1)", inputTypes, ExprType.STRING); + assertOutputType("right(x, 1)", inputTypes, ExprType.STRING); + assertOutputType("replace(x, 'foo', '')", inputTypes, ExprType.STRING); + assertOutputType("lower(x)", inputTypes, ExprType.STRING); + assertOutputType("upper(x)", inputTypes, ExprType.STRING); + assertOutputType("reverse(x)", inputTypes, ExprType.STRING); + assertOutputType("repeat(x, 4)", inputTypes, ExprType.STRING); + } + + @Test + public void testArrayFunctions() + { + assertOutputType("array(1, 2, 3)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array(1, 2, 3.0)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_length(a)", inputTypes, ExprType.LONG); + assertOutputType("array_length(b)", inputTypes, ExprType.LONG); + assertOutputType("array_length(c)", inputTypes, ExprType.LONG); + + assertOutputType("string_to_array(x, ',')", inputTypes, ExprType.STRING_ARRAY); + + assertOutputType("array_to_string(a, ',')", inputTypes, ExprType.STRING); + assertOutputType("array_to_string(b, ',')", inputTypes, ExprType.STRING); + assertOutputType("array_to_string(c, ',')", inputTypes, ExprType.STRING); + + assertOutputType("array_offset(a, 1)", inputTypes, ExprType.STRING); + assertOutputType("array_offset(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_offset(c, 1)", inputTypes, ExprType.DOUBLE); + + assertOutputType("array_ordinal(a, 1)", inputTypes, ExprType.STRING); + assertOutputType("array_ordinal(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_ordinal(c, 1)", inputTypes, ExprType.DOUBLE); + + assertOutputType("array_offset_of(a, 'a')", inputTypes, ExprType.LONG); + assertOutputType("array_offset_of(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_offset_of(c, 1.0)", inputTypes, ExprType.LONG); + + assertOutputType("array_ordinal_of(a, 'a')", inputTypes, ExprType.LONG); + assertOutputType("array_ordinal_of(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_ordinal_of(c, 1.0)", inputTypes, ExprType.LONG); + + assertOutputType("array_append(x, x_)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_append(a, x_)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_append(y, y_)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_append(b, y_)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_append(z, z_)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("array_append(c, z_)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_concat(x, a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_concat(a, a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_concat(y, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_concat(b, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_concat(z, c)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("array_concat(c, c)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_contains(a, 'a')", inputTypes, ExprType.LONG); + assertOutputType("array_contains(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_contains(c, 2.0)", inputTypes, ExprType.LONG); + + assertOutputType("array_overlap(a, a)", inputTypes, ExprType.LONG); + assertOutputType("array_overlap(b, b)", inputTypes, ExprType.LONG); + assertOutputType("array_overlap(c, c)", inputTypes, ExprType.LONG); + + assertOutputType("array_slice(a, 1, 2)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_slice(b, 1, 2)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_slice(c, 1, 2)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_prepend(x, a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_prepend(x, x_)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_prepend(y, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_prepend(y, y_)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_prepend(z, c)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("array_prepend(z, z_)", inputTypes, ExprType.DOUBLE_ARRAY); + } + + @Test + public void testReduceFunctions() + { + assertOutputType("greatest('B', x, 'A')", inputTypes, ExprType.STRING); + assertOutputType("greatest(y, 0)", inputTypes, ExprType.LONG); + assertOutputType("greatest(34.0, z, 5.0, 767.0)", inputTypes, ExprType.DOUBLE); + + assertOutputType("least('B', x, 'A')", inputTypes, ExprType.STRING); + assertOutputType("least(y, 0)", inputTypes, ExprType.LONG); + assertOutputType("least(34.0, z, 5.0, 767.0)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testApplyFunctions() + { + assertOutputType("map((x) -> concat(x, 'foo'), x)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("map((x) -> x + x, y)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("map((x) -> x + x, z)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("map((x) -> concat(x, 'foo'), a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("map((x) -> x + x, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("map((x) -> x + x, c)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType( + "cartesian_map((x, y) -> concat(x, y), ['foo', 'bar', 'baz', 'foobar'], ['bar', 'baz'])", + inputTypes, + ExprType.STRING_ARRAY + ); + assertOutputType("fold((x, acc) -> x + acc, y, 0)", inputTypes, ExprType.LONG); + assertOutputType("fold((x, acc) -> x + acc, y, y)", inputTypes, ExprType.LONG); + assertOutputType("fold((x, acc) -> x + acc, y, 1.0)", inputTypes, ExprType.DOUBLE); + assertOutputType("fold((x, acc) -> x + acc, y, z)", inputTypes, ExprType.DOUBLE); + + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, 0)", inputTypes, ExprType.LONG); + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, y)", inputTypes, ExprType.LONG); + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, 1.0)", inputTypes, ExprType.DOUBLE); + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, z)", inputTypes, ExprType.DOUBLE); + + assertOutputType("filter((x) -> x == 'foo', a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("filter((x) -> x > 1, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("filter((x) -> x > 1, c)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("any((x) -> x == 'foo', a)", inputTypes, ExprType.LONG); + assertOutputType("any((x) -> x > 1, b)", inputTypes, ExprType.LONG); + assertOutputType("any((x) -> x > 1.2, c)", inputTypes, ExprType.LONG); + + assertOutputType("all((x) -> x == 'foo', a)", inputTypes, ExprType.LONG); + assertOutputType("all((x) -> x > 1, b)", inputTypes, ExprType.LONG); + assertOutputType("all((x) -> x > 1.2, c)", inputTypes, ExprType.LONG); + } + + + @Test + public void testOperatorAutoConversion() + { + // nulls output nulls + Assert.assertNull(ExprType.operatorAutoTypeConversion(ExprType.LONG, null)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(null, ExprType.LONG)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, null)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(null, ExprType.DOUBLE)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(ExprType.STRING, null)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(null, ExprType.STRING)); + // only long stays long + Assert.assertEquals(ExprType.LONG, ExprType.operatorAutoTypeConversion(ExprType.LONG, ExprType.LONG)); + // only string stays string + Assert.assertEquals(ExprType.STRING, ExprType.operatorAutoTypeConversion(ExprType.STRING, ExprType.STRING)); + // for operators, doubles is the catch all + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.LONG, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, ExprType.LONG)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, ExprType.STRING)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.STRING, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.STRING, ExprType.LONG)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.LONG, ExprType.STRING)); + // unless it is an array, and those have to be the same + Assert.assertEquals(ExprType.LONG_ARRAY, ExprType.operatorAutoTypeConversion(ExprType.LONG_ARRAY, ExprType.LONG_ARRAY)); + Assert.assertEquals( + ExprType.DOUBLE_ARRAY, + ExprType.operatorAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.DOUBLE_ARRAY) + ); + Assert.assertEquals( + ExprType.STRING_ARRAY, + ExprType.operatorAutoTypeConversion(ExprType.STRING_ARRAY, ExprType.STRING_ARRAY) + ); + } + + @Test + public void testFunctionAutoConversion() + { + // nulls output nulls + Assert.assertNull(ExprType.functionAutoTypeConversion(ExprType.LONG, null)); + Assert.assertNull(ExprType.functionAutoTypeConversion(null, ExprType.LONG)); + Assert.assertNull(ExprType.functionAutoTypeConversion(ExprType.DOUBLE, null)); + Assert.assertNull(ExprType.functionAutoTypeConversion(null, ExprType.DOUBLE)); + Assert.assertNull(ExprType.functionAutoTypeConversion(ExprType.STRING, null)); + Assert.assertNull(ExprType.functionAutoTypeConversion(null, ExprType.STRING)); + // only long stays long + Assert.assertEquals(ExprType.LONG, ExprType.functionAutoTypeConversion(ExprType.LONG, ExprType.LONG)); + // any double makes all doubles + Assert.assertEquals(ExprType.DOUBLE, ExprType.functionAutoTypeConversion(ExprType.LONG, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.functionAutoTypeConversion(ExprType.DOUBLE, ExprType.LONG)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.functionAutoTypeConversion(ExprType.DOUBLE, ExprType.DOUBLE)); + // any string makes become string + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.LONG, ExprType.STRING)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.LONG)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.DOUBLE, ExprType.STRING)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.STRING)); + // unless it is an array, and those have to be the same + Assert.assertEquals(ExprType.LONG_ARRAY, ExprType.functionAutoTypeConversion(ExprType.LONG_ARRAY, ExprType.LONG_ARRAY)); + Assert.assertEquals( + ExprType.DOUBLE_ARRAY, + ExprType.functionAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.DOUBLE_ARRAY) + ); + Assert.assertEquals( + ExprType.STRING_ARRAY, + ExprType.functionAutoTypeConversion(ExprType.STRING_ARRAY, ExprType.STRING_ARRAY) + ); + } + + @Test + public void testAutoConversionArrayMismatchArrays() + { + expectedException.expect(IAE.class); + ExprType.functionAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG_ARRAY); + } + + @Test + public void testAutoConversionArrayMismatchArrayScalar() + { + expectedException.expect(IAE.class); + ExprType.functionAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG); + } + + @Test + public void testAutoConversionArrayMismatchScalarArray() + { + expectedException.expect(IAE.class); + ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.LONG_ARRAY); + } + + private void assertOutputType(String expression, Expr.InputBindingTypes inputTypes, ExprType outputType) + { + final Expr expr = Parser.parse(expression, ExprMacroTable.nil(), false); + Assert.assertEquals(outputType, expr.getOutputType(inputTypes)); + } + + Expr.InputBindingTypes inputTypesFromMap(Map types) + { + return types::get; + } +} diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index b1ef6736ed5..1ebd71f1cc4 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -577,7 +577,7 @@ public class ParserTest extends InitializedNullHandlingTest ) { final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - final Expr.BindingDetails deets = parsed.analyzeInputs(); + final Expr.BindingAnalysis deets = parsed.analyzeInputs(); Assert.assertEquals(expression, expected, parsed.toString()); Assert.assertEquals(expression, identifiers, deets.getRequiredBindingsList()); Assert.assertEquals(expression, scalars, deets.getScalarVariables()); @@ -586,7 +586,7 @@ public class ParserTest extends InitializedNullHandlingTest final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); final Expr roundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil()); Assert.assertEquals(parsed.stringify(), roundTrip.stringify()); - final Expr.BindingDetails roundTripDeets = roundTrip.analyzeInputs(); + final Expr.BindingAnalysis roundTripDeets = roundTrip.analyzeInputs(); Assert.assertEquals(expression, identifiers, roundTripDeets.getRequiredBindingsList()); Assert.assertEquals(expression, scalars, roundTripDeets.getScalarVariables()); Assert.assertEquals(expression, arrays, roundTripDeets.getArrayVariables()); @@ -600,7 +600,7 @@ public class ParserTest extends InitializedNullHandlingTest ) { final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - Expr.BindingDetails deets = parsed.analyzeInputs(); + Expr.BindingAnalysis deets = parsed.analyzeInputs(); Parser.validateExpr(parsed, deets); final Expr transformed = Parser.applyUnappliedBindings(parsed, deets, identifiers); Assert.assertEquals(expression, unapplied, parsed.toString()); @@ -608,7 +608,7 @@ public class ParserTest extends InitializedNullHandlingTest final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); final Expr parsedRoundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil()); - Expr.BindingDetails roundTripDeets = parsedRoundTrip.analyzeInputs(); + Expr.BindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs(); Parser.validateExpr(parsedRoundTrip, roundTripDeets); final Expr transformedRoundTrip = Parser.applyUnappliedBindings(parsedRoundTrip, roundTripDeets, identifiers); Assert.assertEquals(expression, unapplied, parsedRoundTrip.toString()); diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java index dcc3a16ccad..8e7d04ccb7d 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java @@ -29,6 +29,7 @@ import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.filter.BloomKFilter; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.io.IOException; import java.util.List; @@ -108,7 +109,7 @@ public class BloomFilterExprMacro implements ExprMacroTable.ExprMacro break; } - return ExprEval.of(matches, ExprType.LONG); + return ExprEval.ofLongBoolean(matches); } private boolean nullMatch() @@ -123,6 +124,13 @@ public class BloomFilterExprMacro implements ExprMacroTable.ExprMacro Expr newArg = arg.visit(shuttle); return shuttle.visit(new BloomExpr(newArg)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } return new BloomExpr(arg); diff --git a/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java b/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java index f9550f32429..f36311229e7 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java +++ b/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java @@ -28,6 +28,7 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.function.Function; /** @@ -62,13 +63,20 @@ class ContainsExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr if (s == null) { // same behavior as regexp_like. - return ExprEval.of(false, ExprType.LONG); + return ExprEval.ofLongBoolean(false); } else { final boolean doesContain = searchFunction.apply(s); - return ExprEval.of(doesContain, ExprType.LONG); + return ExprEval.ofLongBoolean(doesContain); } } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public Expr visit(Expr.Shuttle shuttle) { diff --git a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java index 5e9cc85fe54..1aff62d199f 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java @@ -28,6 +28,7 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; /** @@ -98,7 +99,7 @@ public class IPv4AddressMatchExprMacro implements ExprMacroTable.ExprMacro default: match = false; } - return ExprEval.of(match, ExprType.LONG); + return ExprEval.ofLongBoolean(match); } private boolean isStringMatch(String stringValue) @@ -118,6 +119,13 @@ public class IPv4AddressMatchExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new IPv4AddressMatchExpr(newArg, subnetInfo)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java index fdf67b4cc5b..a75fa323fdb 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java @@ -23,8 +23,10 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.net.Inet4Address; import java.util.List; @@ -92,6 +94,13 @@ public class IPv4AddressParseExprMacro implements ExprMacroTable.ExprMacro Expr newArg = arg.visit(shuttle); return shuttle.visit(new IPv4AddressParseExpr(newArg)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } return new IPv4AddressParseExpr(arg); diff --git a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java index 4aea0aa3718..17431a0e592 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java @@ -23,8 +23,10 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.net.Inet4Address; import java.util.List; @@ -91,6 +93,13 @@ public class IPv4AddressStringifyExprMacro implements ExprMacroTable.ExprMacro Expr newArg = arg.visit(shuttle); return shuttle.visit(new IPv4AddressStringifyExpr(newArg)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } } return new IPv4AddressStringifyExpr(arg); diff --git a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java index 2332b2858ea..d5bbf02dad0 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java @@ -29,6 +29,7 @@ import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.filter.LikeDimFilter; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class LikeExprMacro implements ExprMacroTable.ExprMacro @@ -81,7 +82,7 @@ public class LikeExprMacro implements ExprMacroTable.ExprMacro @Override public ExprEval eval(final ObjectBinding bindings) { - return ExprEval.of(likeMatcher.matches(arg.eval(bindings).asString()), ExprType.LONG); + return ExprEval.ofLongBoolean(likeMatcher.matches(arg.eval(bindings).asString())); } @Override @@ -91,6 +92,13 @@ public class LikeExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new LikeExtractExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java index a827aea2601..6ff028778a4 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java @@ -26,10 +26,12 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.query.lookup.RegisteredLookupExtractionFn; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class LookupExprMacro implements ExprMacroTable.ExprMacro @@ -94,6 +96,13 @@ public class LookupExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new LookupExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java index 9bef704a663..3964c1793d7 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java @@ -25,8 +25,10 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -97,6 +99,13 @@ public class RegexpExtractExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new RegexpExtractExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java index 83735e86349..9279c84774d 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java @@ -28,6 +28,7 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -76,10 +77,10 @@ public class RegexpLikeExprMacro implements ExprMacroTable.ExprMacro if (s == null) { // True nulls do not match anything. Note: this branch only executes in SQL-compatible null handling mode. - return ExprEval.of(false, ExprType.LONG); + return ExprEval.ofLongBoolean(false); } else { final Matcher matcher = pattern.matcher(s); - return ExprEval.of(matcher.find(), ExprType.LONG); + return ExprEval.ofLongBoolean(matcher.find()); } } @@ -90,6 +91,13 @@ public class RegexpLikeExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new RegexpLikeExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java index 8d6a628d97a..6779bf6ddf7 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java @@ -27,9 +27,11 @@ import org.apache.druid.java.util.common.granularity.PeriodGranularity; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTime; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -93,6 +95,13 @@ public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new TimestampCeilExpr(newArgs)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public boolean equals(Object o) { @@ -153,5 +162,12 @@ public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampCeilDynamicExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java index d3184dd5ee3..27807690187 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java @@ -25,11 +25,13 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.joda.time.chrono.ISOChronology; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro @@ -162,6 +164,19 @@ public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new TimestampExtractExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + switch (unit) { + case CENTURY: + case MILLENNIUM: + return ExprType.DOUBLE; + default: + return ExprType.LONG; + } + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java index aef159ae7ce..a3a95306c0a 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java @@ -25,8 +25,10 @@ import org.apache.druid.java.util.common.granularity.PeriodGranularity; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -111,6 +113,13 @@ public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new TimestampFloorExpr(newArgs)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public boolean equals(Object o) { @@ -155,5 +164,12 @@ public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampFloorDynamicExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java index e7f46966685..455d445fe98 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java @@ -25,12 +25,14 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.ISODateTimeFormat; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class TimestampFormatExprMacro implements ExprMacroTable.ExprMacro @@ -97,6 +99,13 @@ public class TimestampFormatExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new TimestampFormatExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java index 535c7332554..935a2b7cbae 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java @@ -25,6 +25,7 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; @@ -33,6 +34,7 @@ import org.joda.time.format.DateTimeParser; import org.joda.time.format.ISODateTimeFormat; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class TimestampParseExprMacro implements ExprMacroTable.ExprMacro @@ -100,6 +102,13 @@ public class TimestampParseExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new TimestampParseExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java index b3f8d1e767f..259d054e411 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java @@ -24,11 +24,13 @@ import org.apache.druid.java.util.common.granularity.PeriodGranularity; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.Chronology; import org.joda.time.Period; import org.joda.time.chrono.ISOChronology; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.stream.Collectors; @@ -101,6 +103,13 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampShiftExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } private static class TimestampShiftDynamicExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr @@ -127,5 +136,12 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampShiftDynamicExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java index c7ce44f8fe9..f019edc93e4 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java @@ -26,8 +26,10 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -168,6 +170,13 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro return shuttle.visit(new TrimStaticCharsExpr(mode, newStringExpr, chars, charsExpr)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { @@ -290,13 +299,20 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro } @Override - public BindingDetails analyzeInputs() + public BindingAnalysis analyzeInputs() { return stringExpr.analyzeInputs() .with(charsExpr) .withScalarArguments(ImmutableSet.of(stringExpr, charsExpr)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public boolean equals(Object o) { diff --git a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java index 880e90ced2e..acf0dbeaf04 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java @@ -45,7 +45,7 @@ import java.util.Set; public class ExpressionFilter implements Filter { private final Supplier expr; - private final Supplier bindingDetails; + private final Supplier bindingDetails; private final FilterTuning filterTuning; public ExpressionFilter(final Supplier expr, final FilterTuning filterTuning) @@ -107,7 +107,7 @@ public class ExpressionFilter implements Filter @Override public boolean supportsBitmapIndex(final BitmapIndexSelector selector) { - final Expr.BindingDetails details = this.bindingDetails.get(); + final Expr.BindingAnalysis details = this.bindingDetails.get(); if (details.getRequiredBindings().isEmpty()) { // Constant expression. diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java index 84fbccd8a57..ed9fe075625 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java @@ -380,8 +380,8 @@ public class JoinFilterCorrelations String identifier = lhsExpr.getBindingIfIdentifier(); if (identifier == null) { // We push down if the function only requires base table columns - Expr.BindingDetails bindingDetails = lhsExpr.analyzeInputs(); - Set requiredBindings = bindingDetails.getRequiredBindings(); + Expr.BindingAnalysis bindingAnalysis = lhsExpr.analyzeInputs(); + Set requiredBindings = bindingAnalysis.getRequiredBindings(); if (joinableClauses.areSomeColumnsFromJoin(requiredBindings)) { break; diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index 5dd6a997043..f2a0571d610 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -136,9 +136,9 @@ public class ExpressionSelectors Expr expression ) { - final Expr.BindingDetails exprDetails = expression.analyzeInputs(); - Parser.validateExpr(expression, exprDetails); - final List columns = exprDetails.getRequiredBindingsList(); + final Expr.BindingAnalysis bindingAnalysis = expression.analyzeInputs(); + Parser.validateExpr(expression, bindingAnalysis); + final List columns = bindingAnalysis.getRequiredBindingsList(); if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -155,7 +155,7 @@ public class ExpressionSelectors && capabilities.getType() == ValueType.STRING && capabilities.isDictionaryEncoded().isTrue() && capabilities.hasMultipleValues().isFalse() - && exprDetails.getArrayBindings().isEmpty()) { + && bindingAnalysis.getArrayBindings().isEmpty()) { // Optimization for expressions that hit one scalar string column and nothing else. return new SingleStringInputCachingExpressionColumnValueSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -165,22 +165,22 @@ public class ExpressionSelectors } final Pair, Set> arrayUsage = - examineColumnSelectorFactoryArrays(columnSelectorFactory, exprDetails, columns); + examineColumnSelectorFactoryArrays(columnSelectorFactory, bindingAnalysis, columns); final Set actualArrays = arrayUsage.lhs; final Set unknownIfArrays = arrayUsage.rhs; final List needsApplied = columns.stream() - .filter(c -> actualArrays.contains(c) && !exprDetails.getArrayBindings().contains(c)) + .filter(c -> actualArrays.contains(c) && !bindingAnalysis.getArrayBindings().contains(c)) .collect(Collectors.toList()); final Expr finalExpr; if (needsApplied.size() > 0) { - finalExpr = Parser.applyUnappliedBindings(expression, exprDetails, needsApplied); + finalExpr = Parser.applyUnappliedBindings(expression, bindingAnalysis, needsApplied); } else { finalExpr = expression; } - final Expr.ObjectBinding bindings = createBindings(exprDetails, columnSelectorFactory); + final Expr.ObjectBinding bindings = createBindings(bindingAnalysis, columnSelectorFactory); if (bindings.equals(ExprUtils.nilBindings())) { // Optimization for constant expressions. @@ -192,7 +192,7 @@ public class ExpressionSelectors if (unknownIfArrays.size() > 0) { return new RowBasedExpressionColumnValueSelector( finalExpr, - exprDetails, + bindingAnalysis, bindings, unknownIfArrays ); @@ -212,9 +212,9 @@ public class ExpressionSelectors @Nullable final ExtractionFn extractionFn ) { - final Expr.BindingDetails exprDetails = expression.analyzeInputs(); - Parser.validateExpr(expression, exprDetails); - final List columns = exprDetails.getRequiredBindingsList(); + final Expr.BindingAnalysis bindingAnalysis = expression.analyzeInputs(); + Parser.validateExpr(expression, bindingAnalysis); + final List columns = bindingAnalysis.getRequiredBindingsList(); if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -226,7 +226,7 @@ public class ExpressionSelectors if (capabilities != null && capabilities.getType() == ValueType.STRING && capabilities.isDictionaryEncoded().isTrue() - && canMapOverDictionary(exprDetails, capabilities.hasMultipleValues()) + && canMapOverDictionary(bindingAnalysis, capabilities.hasMultipleValues()) ) { return new SingleStringInputDimensionSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -236,14 +236,14 @@ public class ExpressionSelectors } final Pair, Set> arrayUsage = - examineColumnSelectorFactoryArrays(columnSelectorFactory, exprDetails, columns); + examineColumnSelectorFactoryArrays(columnSelectorFactory, bindingAnalysis, columns); final Set actualArrays = arrayUsage.lhs; final Set unknownIfArrays = arrayUsage.rhs; final ColumnValueSelector baseSelector = makeExprEvalSelector(columnSelectorFactory, expression); final boolean multiVal = actualArrays.size() > 0 || - exprDetails.getArrayBindings().size() > 0 || + bindingAnalysis.getArrayBindings().size() > 0 || unknownIfArrays.size() > 0; if (baseSelector instanceof ConstantExprEvalSelector) { @@ -344,30 +344,30 @@ public class ExpressionSelectors * This function should only be called if you have already determined that an expression is over a single column, * and that single column has a dictionary. * - * @param exprDetails result of calling {@link Expr#analyzeInputs()} on an expression + * @param bindingAnalysis result of calling {@link Expr#analyzeInputs()} on an expression * @param hasMultipleValues result of calling {@link ColumnCapabilities#hasMultipleValues()} */ public static boolean canMapOverDictionary( - final Expr.BindingDetails exprDetails, + final Expr.BindingAnalysis bindingAnalysis, final ColumnCapabilities.Capable hasMultipleValues ) { - Preconditions.checkState(exprDetails.getRequiredBindings().size() == 1, "requiredBindings.size == 1"); - return !hasMultipleValues.isUnknown() && !exprDetails.hasInputArrays() && !exprDetails.isOutputArray(); + Preconditions.checkState(bindingAnalysis.getRequiredBindings().size() == 1, "requiredBindings.size == 1"); + return !hasMultipleValues.isUnknown() && !bindingAnalysis.hasInputArrays() && !bindingAnalysis.isOutputArray(); } /** - * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingDetails} which + * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingAnalysis} which * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they * are used as array or scalar inputs */ private static Expr.ObjectBinding createBindings( - Expr.BindingDetails bindingDetails, + Expr.BindingAnalysis bindingAnalysis, ColumnSelectorFactory columnSelectorFactory ) { final Map> suppliers = new HashMap<>(); - final List columns = bindingDetails.getRequiredBindingsList(); + final List columns = bindingAnalysis.getRequiredBindingsList(); for (String columnName : columns) { final ColumnCapabilities columnCapabilities = columnSelectorFactory .getColumnCapabilities(columnName); @@ -376,16 +376,13 @@ public class ExpressionSelectors final Supplier supplier; if (nativeType == ValueType.FLOAT) { - ColumnValueSelector selector = columnSelectorFactory - .makeColumnValueSelector(columnName); + ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(columnName); supplier = makeNullableNumericSupplier(selector, selector::getFloat); } else if (nativeType == ValueType.LONG) { - ColumnValueSelector selector = columnSelectorFactory - .makeColumnValueSelector(columnName); + ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(columnName); supplier = makeNullableNumericSupplier(selector, selector::getLong); } else if (nativeType == ValueType.DOUBLE) { - ColumnValueSelector selector = columnSelectorFactory - .makeColumnValueSelector(columnName); + ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(columnName); supplier = makeNullableNumericSupplier(selector, selector::getDouble); } else if (nativeType == ValueType.STRING) { supplier = supplierFromDimensionSelector( @@ -604,7 +601,7 @@ public class ExpressionSelectors */ private static Pair, Set> examineColumnSelectorFactoryArrays( ColumnSelectorFactory columnSelectorFactory, - Expr.BindingDetails exprDetails, + Expr.BindingAnalysis bindingAnalysis, List columns ) { @@ -618,7 +615,7 @@ public class ExpressionSelectors } else if ( capabilities.getType().equals(ValueType.STRING) && capabilities.hasMultipleValues().isMaybeTrue() && - !exprDetails.getArrayBindings().contains(column) + !bindingAnalysis.getArrayBindings().contains(column) ) { unknownIfArrays.add(column); } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java index 727f1e4a243..5a33bc771b4 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java @@ -40,22 +40,22 @@ import java.util.stream.Collectors; public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValueSelector { private final List unknownColumns; - private final Expr.BindingDetails baseExprBindingDetails; + private final Expr.BindingAnalysis baseBindingAnalysis; private final Set ignoredColumns; private final Int2ObjectMap transformedCache; public RowBasedExpressionColumnValueSelector( Expr expression, - Expr.BindingDetails baseExprBindingDetails, + Expr.BindingAnalysis baseBindingAnalysis, Expr.ObjectBinding bindings, Set unknownColumnsSet ) { super(expression, bindings); this.unknownColumns = unknownColumnsSet.stream() - .filter(x -> !baseExprBindingDetails.getArrayBindings().contains(x)) + .filter(x -> !baseBindingAnalysis.getArrayBindings().contains(x)) .collect(Collectors.toList()); - this.baseExprBindingDetails = baseExprBindingDetails; + this.baseBindingAnalysis = baseBindingAnalysis; this.ignoredColumns = new HashSet<>(); this.transformedCache = new Int2ObjectArrayMap<>(unknownColumns.size()); } @@ -79,7 +79,7 @@ public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValue if (transformedCache.containsKey(key)) { return transformedCache.get(key).eval(bindings); } - Expr transformed = Parser.applyUnappliedBindings(expression, baseExprBindingDetails, arrayBindings); + Expr transformed = Parser.applyUnappliedBindings(expression, baseBindingAnalysis, arrayBindings); transformedCache.put(key, transformed); return transformed.eval(bindings); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java index a6bdfb36a03..b57db64b232 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java @@ -22,7 +22,6 @@ package org.apache.druid.query.expression; import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.ExprType; import org.apache.druid.math.expr.Parser; import org.junit.Assert; import org.junit.Test; @@ -53,7 +52,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase { final ExprEval result = eval("regexp_like(a, 'f.o')", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -63,7 +62,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase { final ExprEval result = eval("regexp_like(a, 'f.x')", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(false, ExprType.LONG).value(), + ExprEval.ofLongBoolean(false).value(), result.value() ); } @@ -77,7 +76,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -87,7 +86,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase { final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -101,7 +100,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -111,7 +110,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase { final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -125,7 +124,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase final ExprEval result = eval("regexp_like(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -135,7 +134,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase { final ExprEval result = eval("regexp_like(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(NullHandling.replaceWithDefault(), ExprType.LONG).value(), + ExprEval.ofLongBoolean(NullHandling.replaceWithDefault()).value(), result.value() ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java index f76d8352bbb..5aa9a9e3645 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -37,8 +37,8 @@ class ReductionOperatorConversionHelper * Implements type precedence rules similar to: * https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least * - * @see org.apache.druid.math.expr.Function.ReduceFunc#apply - * @see org.apache.druid.math.expr.Function.ReduceFunc#getComparisionType + * @see org.apache.druid.math.expr.Function.ReduceFunction#apply + * @see org.apache.druid.math.expr.ExprType#functionAutoTypeConversion */ static final SqlReturnTypeInference TYPE_INFERENCE = opBinding -> { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java index d353a4ae116..02cc3ce2aae 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java @@ -316,8 +316,8 @@ public class Projection } // Check if a cast is necessary. - final ExprType toExprType = Expressions.exprTypeForValueType(columnValueType); - final ExprType fromExprType = Expressions.exprTypeForValueType( + final ExprType toExprType = ExprType.fromValueType(columnValueType); + final ExprType fromExprType = ExprType.fromValueType( Calcites.getValueTypeForRelDataType(rexNode.getType()) );