diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index bf6a8d2e050..af52555793b 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -35,9 +35,14 @@ import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Comparator; +import java.util.EnumSet; import java.util.List; import java.util.Objects; import java.util.Set; +import java.util.function.BinaryOperator; +import java.util.function.DoubleBinaryOperator; +import java.util.function.LongBinaryOperator; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -48,7 +53,7 @@ import java.util.stream.Stream; * Do NOT remove "unused" members in this class. They are used by generated Antlr */ @SuppressWarnings("unused") -interface Function +public interface Function { /** * Name of the function. @@ -976,6 +981,172 @@ interface Function } } + class GreatestFunc extends ReduceFunc + { + public static final String NAME = "greatest"; + + public GreatestFunc() + { + super( + Math::max, + Math::max, + BinaryOperator.maxBy(Comparator.naturalOrder()) + ); + } + + @Override + public String name() + { + return NAME; + } + } + + class LeastFunc extends ReduceFunc + { + public static final String NAME = "least"; + + public LeastFunc() + { + super( + Math::min, + Math::min, + BinaryOperator.minBy(Comparator.naturalOrder()) + ); + } + + @Override + public String name() + { + return NAME; + } + } + + abstract class ReduceFunc implements Function + { + private final DoubleBinaryOperator doubleReducer; + private final LongBinaryOperator longReducer; + private final BinaryOperator stringReducer; + + ReduceFunc( + DoubleBinaryOperator doubleReducer, + LongBinaryOperator longReducer, + BinaryOperator stringReducer + ) + { + this.doubleReducer = doubleReducer; + this.longReducer = longReducer; + this.stringReducer = stringReducer; + } + + @Override + public void validateArguments(List args) + { + // anything goes + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + if (args.isEmpty()) { + return ExprEval.of(null); + } + + ExprAnalysis exprAnalysis = analyzeExprs(args, bindings); + if (exprAnalysis.exprEvals.isEmpty()) { + // The GREATEST/LEAST functions are not in the SQL standard. Emulate the behavior of postgres (return null if + // all expressions are null, otherwise skip null values) since it is used as a base for a wide number of + // databases. This also matches the behavior the the long/double greatest/least post aggregators. Some other + // databases (e.g., MySQL) return null if any expression is null. + // https://www.postgresql.org/docs/9.5/functions-conditional.html + // https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least + return ExprEval.of(null); + } + + Stream> exprEvalStream = exprAnalysis.exprEvals.stream(); + switch (exprAnalysis.comparisonType) { + case DOUBLE: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(exprEvalStream.mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble()); + case LONG: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(exprEvalStream.mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong()); + default: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(exprEvalStream.map(ExprEval::asString).reduce(stringReducer).get()); + } + } + + /** + * Determines which {@link ExprType} to use to compare non-null evaluated expressions. + * + * @param exprs Expressions to analyze + * @param bindings Bindings for expressions + * + * @return Comparison type and non-null evaluated expressions. + */ + private ExprAnalysis analyzeExprs(List exprs, Expr.ObjectBinding bindings) + { + Set presentTypes = EnumSet.noneOf(ExprType.class); + List> exprEvals = new ArrayList<>(); + + for (Expr expr : exprs) { + ExprEval exprEval = expr.eval(bindings); + ExprType exprType = exprEval.type(); + + if (isValidType(exprType)) { + presentTypes.add(exprType); + } + + if (exprEval.value() != null) { + exprEvals.add(exprEval); + } + } + + ExprType comparisonType = getComparisionType(presentTypes); + return new ExprAnalysis(comparisonType, exprEvals); + } + + private boolean isValidType(ExprType exprType) + { + switch (exprType) { + case DOUBLE: + case LONG: + case STRING: + return true; + default: + throw new IAE("Function[%s] does not accept %s types", name(), exprType); + } + } + + /** + * Implements rules similar to: https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least + * + * @see org.apache.druid.sql.calcite.expression.builtin.ReductionOperatorConversionHelper#TYPE_INFERENCE + */ + private static ExprType getComparisionType(Set exprTypes) + { + if (exprTypes.contains(ExprType.STRING)) { + return ExprType.STRING; + } else if (exprTypes.contains(ExprType.DOUBLE)) { + return ExprType.DOUBLE; + } else { + return ExprType.LONG; + } + } + + private static class ExprAnalysis + { + final ExprType comparisonType; + final List> exprEvals; + + ExprAnalysis(ExprType comparisonType, List> exprEvals) + { + this.comparisonType = comparisonType; + this.exprEvals = exprEvals; + } + } + } + class NextAfter extends BivariateMathFunction { @Override @@ -2390,6 +2561,7 @@ interface Function throw new RE("Unable to prepend to unknown type %s", arrayExpr.type()); } + private Stream prepend(T val, T[] array) { List l = new ArrayList<>(Arrays.asList(array)); diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index c2b84ffaa7d..9239e778390 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -26,6 +26,8 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.annotation.Nullable; + public class FunctionTest extends InitializedNullHandlingTest { private Expr.ObjectBinding bindings; @@ -68,11 +70,11 @@ public class FunctionTest extends InitializedNullHandlingTest if (NullHandling.replaceWithDefault()) { assertExpr("concat(x,' ',nonexistent,' ',y)", "foo 2"); } else { - assertExpr("concat(x,' ',nonexistent,' ',y)", null); + assertArrayExpr("concat(x,' ',nonexistent,' ',y)", null); } assertExpr("concat(z)", "3.1"); - assertExpr("concat()", null); + assertArrayExpr("concat()", null); } @Test @@ -144,9 +146,9 @@ public class FunctionTest extends InitializedNullHandlingTest assertExpr("lpad(x, 5, 'ab')", "abfoo"); assertExpr("lpad(x, 4, 'ab')", "afoo"); assertExpr("lpad(x, 2, 'ab')", "fo"); - assertExpr("lpad(x, 0, 'ab')", null); - assertExpr("lpad(x, 5, null)", null); - assertExpr("lpad(null, 5, x)", null); + assertArrayExpr("lpad(x, 0, 'ab')", null); + assertArrayExpr("lpad(x, 5, null)", null); + assertArrayExpr("lpad(null, 5, x)", null); } @Test @@ -155,18 +157,18 @@ public class FunctionTest extends InitializedNullHandlingTest assertExpr("rpad(x, 5, 'ab')", "fooab"); assertExpr("rpad(x, 4, 'ab')", "fooa"); assertExpr("rpad(x, 2, 'ab')", "fo"); - assertExpr("rpad(x, 0, 'ab')", null); - assertExpr("rpad(x, 5, null)", null); - assertExpr("rpad(null, 5, x)", null); + assertArrayExpr("rpad(x, 0, 'ab')", null); + assertArrayExpr("rpad(x, 5, null)", null); + assertArrayExpr("rpad(null, 5, x)", null); } @Test public void testArrayConstructor() { - assertExpr("array(1, 2, 3, 4)", new Long[]{1L, 2L, 3L, 4L}); - assertExpr("array(1, 2, 3, 'bar')", new Long[]{1L, 2L, 3L, null}); - assertExpr("array(1.0)", new Double[]{1.0}); - assertExpr("array('foo', 'bar')", new String[]{"foo", "bar"}); + assertArrayExpr("array(1, 2, 3, 4)", new Long[]{1L, 2L, 3L, 4L}); + assertArrayExpr("array(1, 2, 3, 'bar')", new Long[]{1L, 2L, 3L, null}); + assertArrayExpr("array(1.0)", new Double[]{1.0}); + assertArrayExpr("array('foo', 'bar')", new String[]{"foo", "bar"}); } @Test @@ -180,7 +182,7 @@ public class FunctionTest extends InitializedNullHandlingTest public void testArrayOffset() { assertExpr("array_offset([1, 2, 3], 2)", 3L); - assertExpr("array_offset([1, 2, 3], 3)", null); + assertArrayExpr("array_offset([1, 2, 3], 3)", null); assertExpr("array_offset(a, 2)", "baz"); } @@ -188,7 +190,7 @@ public class FunctionTest extends InitializedNullHandlingTest public void testArrayOrdinal() { assertExpr("array_ordinal([1, 2, 3], 3)", 3L); - assertExpr("array_ordinal([1, 2, 3], 4)", null); + assertArrayExpr("array_ordinal([1, 2, 3], 4)", null); assertExpr("array_ordinal(a, 3)", "baz"); } @@ -228,20 +230,20 @@ public class FunctionTest extends InitializedNullHandlingTest @Test public void testArrayAppend() { - assertExpr("array_append([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); - assertExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, null}); - assertExpr("array_append([], 1)", new String[]{"1"}); - assertExpr("array_append([], 1)", new Long[]{1L}); + assertArrayExpr("array_append([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); + assertArrayExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, null}); + assertArrayExpr("array_append([], 1)", new String[]{"1"}); + assertArrayExpr("array_append([], 1)", new Long[]{1L}); } @Test public void testArrayConcat() { - assertExpr("array_concat([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 2L, 4L, 6L}); - assertExpr("array_concat([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); - assertExpr("array_concat(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L}); - assertExpr("array_concat(map(y -> y * 3, b), [1, 2, 3])", new Long[]{3L, 6L, 9L, 12L, 15L, 1L, 2L, 3L}); - assertExpr("array_concat(0, 1)", new Long[]{0L, 1L}); + assertArrayExpr("array_concat([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 2L, 4L, 6L}); + assertArrayExpr("array_concat([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); + assertArrayExpr("array_concat(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L}); + assertArrayExpr("array_concat(map(y -> y * 3, b), [1, 2, 3])", new Long[]{3L, 6L, 9L, 12L, 15L, 1L, 2L, 3L}); + assertArrayExpr("array_concat(0, 1)", new Long[]{0L, 1L}); } @Test @@ -255,43 +257,99 @@ public class FunctionTest extends InitializedNullHandlingTest @Test public void testStringToArray() { - assertExpr("string_to_array('1,2,3', ',')", new String[]{"1", "2", "3"}); - assertExpr("string_to_array('1', ',')", new String[]{"1"}); - assertExpr("string_to_array(array_to_string(a, ','), ',')", new String[]{"foo", "bar", "baz", "foobar"}); + assertArrayExpr("string_to_array('1,2,3', ',')", new String[]{"1", "2", "3"}); + assertArrayExpr("string_to_array('1', ',')", new String[]{"1"}); + assertArrayExpr("string_to_array(array_to_string(a, ','), ',')", new String[]{"foo", "bar", "baz", "foobar"}); } @Test public void testArrayCast() { - assertExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"}); - assertExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0}); - assertExpr("cast(c, 'LONG_ARRAY')", new Long[]{3L, 4L, 5L}); - assertExpr("cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')", new Long[]{1L, 2L, 3L, 4L, 5L}); - assertExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L}); + assertArrayExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"}); + assertArrayExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0}); + assertArrayExpr("cast(c, 'LONG_ARRAY')", new Long[]{3L, 4L, 5L}); + assertArrayExpr("cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')", new Long[]{1L, 2L, 3L, 4L, 5L}); + assertArrayExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L}); } @Test public void testArraySlice() { - assertExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[] {2L, 3L}); - assertExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[] {3.2, 4.3}); - assertExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[] {null, null}); - assertExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[] {}); - assertExpr("array_slice([1, 2, 3, 4], 5, 7)", null); - assertExpr("array_slice([1, 2, 3, 4], 2, 1)", null); + assertArrayExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[] {2L, 3L}); + assertArrayExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[] {3.2, 4.3}); + assertArrayExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[] {null, null}); + assertArrayExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[] {}); + assertArrayExpr("array_slice([1, 2, 3, 4], 5, 7)", null); + assertArrayExpr("array_slice([1, 2, 3, 4], 2, 1)", null); } @Test public void testArrayPrepend() { - assertExpr("array_prepend(4, [1, 2, 3])", new Long[]{4L, 1L, 2L, 3L}); - assertExpr("array_prepend('bar', [1, 2, 3])", new Long[]{null, 1L, 2L, 3L}); - assertExpr("array_prepend(1, [])", new String[]{"1"}); - assertExpr("array_prepend(1, [])", new Long[]{1L}); - assertExpr("array_prepend(1, [])", new Double[]{1.0}); + assertArrayExpr("array_prepend(4, [1, 2, 3])", new Long[]{4L, 1L, 2L, 3L}); + assertArrayExpr("array_prepend('bar', [1, 2, 3])", new Long[]{null, 1L, 2L, 3L}); + assertArrayExpr("array_prepend(1, [])", new String[]{"1"}); + assertArrayExpr("array_prepend(1, [])", new Long[]{1L}); + assertArrayExpr("array_prepend(1, [])", new Double[]{1.0}); } - private void assertExpr(final String expression, final Object expectedResult) + @Test + public void testGreatest() + { + // Same types + assertExpr("greatest(y, 0)", 2L); + assertExpr("greatest(34.0, z, 5.0, 767.0", 767.0); + assertExpr("greatest('B', x, 'A')", "foo"); + + // Different types + assertExpr("greatest(-1, z, 'A')", "A"); + assertExpr("greatest(-1, z)", 3.1); + assertExpr("greatest(1, 'A')", "A"); + + // Invalid types + try { + assertExpr("greatest(1, ['A'])", null); + Assert.fail("Did not throw IllegalArgumentException"); + } + catch (IllegalArgumentException e) { + Assert.assertEquals("Function[greatest] does not accept STRING_ARRAY types", e.getMessage()); + } + + // Null handling + assertExpr("greatest()", null); + assertExpr("greatest(null, null)", null); + assertExpr("greatest(1, null, 'A')", "A"); + } + + @Test + public void testLeast() + { + // Same types + assertExpr("least(y, 0)", 0L); + assertExpr("least(34.0, z, 5.0, 767.0", 3.1); + assertExpr("least('B', x, 'A')", "A"); + + // Different types + assertExpr("least(-1, z, 'A')", "-1"); + assertExpr("least(-1, z)", -1.0); + assertExpr("least(1, 'A')", "1"); + + // Invalid types + try { + assertExpr("least(1, [2, 3])", null); + Assert.fail("Did not throw IllegalArgumentException"); + } + catch (IllegalArgumentException e) { + Assert.assertEquals("Function[least] does not accept LONG_ARRAY types", e.getMessage()); + } + + // Null handling + assertExpr("least()", null); + assertExpr("least(null, null)", null); + assertExpr("least(1, null, 'A')", "1"); + } + + private void assertExpr(final String expression, @Nullable final Object expectedResult) { final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); Assert.assertEquals(expression, expectedResult, expr.eval(bindings).value()); @@ -307,7 +365,7 @@ public class FunctionTest extends InitializedNullHandlingTest Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify()); } - private void assertExpr(final String expression, final Object[] expectedResult) + private void assertArrayExpr(final String expression, @Nullable final Object[] expectedResult) { final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); Assert.assertArrayEquals(expression, expectedResult, expr.eval(bindings).asArray()); diff --git a/docs/misc/math-expr.md b/docs/misc/math-expr.md index 6407433982e..8793d63c048 100644 --- a/docs/misc/math-expr.md +++ b/docs/misc/math-expr.md @@ -181,6 +181,22 @@ See javadoc of java.lang.Math for detailed explanation for each function. | all(lambda,arr) | returns 1 if all elements in the array matches the lambda expression, else 0 | +### Reduction functions + +Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as +arguments, then the result is `NULL`. The expressions must all be convertible to a common data type, which will be the +type of the result: +* If all arguments are `NULL`, the result is `NULL`. Otherwise, `NULL` arguments are ignored. +* If the arguments comprise a mix of numbers and strings, the arguments are interpreted as strings. +* If all arguments are integer numbers, the arguments are interpreted as longs. +* If all arguments are numbers and at least one argument is a double, the arguments are interpreted as doubles. + +| function | description | +| --- | --- | +| greatest([expr1, ...]) | Evaluates zero or more expressions and returns the maximum value based on comparisons as described above. | +| least([expr1, ...]) | Evaluates zero or more expressions and returns the minimum value based on comparisons as described above. | + + ## IP address functions For the IPv4 address functions, the `address` argument can either be an IPv4 dotted-decimal string (e.g., "192.168.0.1") or an IP address represented as a long (e.g., 3232235521). The `subnet` argument should be a string formatted as an IPv4 address subnet in CIDR notation (e.g., "192.168.0.0/16"). diff --git a/docs/querying/sql.md b/docs/querying/sql.md index bd228e712f1..42d38922b28 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -198,8 +198,6 @@ Only the COUNT aggregation can accept DISTINCT. |`SUM(expr)`|Sums numbers.| |`MIN(expr)`|Takes the minimum of numbers.| |`MAX(expr)`|Takes the maximum of numbers.| -|`LEAST(expr1, [expr2, ...])`|Takes the minimum of numbers across one or more expression(s).| -|`GREATEST(expr1, [expr2, ...])`|Takes the maximum of numbers across one or more expression(s).| |`AVG(expr)`|Averages numbers.| |`APPROX_COUNT_DISTINCT(expr)`|Counts distinct values of expr, which can be a regular column or a hyperUnique column. This is always approximate, regardless of the value of "useApproximateCountDistinct". This uses Druid's built-in "cardinality" or "hyperUnique" aggregators. See also `COUNT(DISTINCT expr)`.| |`APPROX_COUNT_DISTINCT_DS_HLL(expr, [lgK, tgtHllType])`|Counts distinct values of expr, which can be a regular column or an [HLL sketch](../development/extensions-core/datasketches-hll.html) column. The `lgK` and `tgtHllType` parameters are described in the HLL sketch documentation. This is always approximate, regardless of the value of "useApproximateCountDistinct". See also `COUNT(DISTINCT expr)`. The [DataSketches extension](../development/extensions-core/datasketches-extension.html) must be loaded to use this function.| @@ -334,6 +332,22 @@ simplest way to write literal timestamps in other time zones is to use TIME_PARS |timestamp_expr { + | - } |Add or subtract an amount of time from a timestamp. interval_expr can include interval literals like `INTERVAL '2' HOUR`, and may include interval arithmetic as well. This operator treats days as uniformly 86400 seconds long, and does not take into account daylight savings time. To account for daylight savings time, use TIME_SHIFT instead.| +### Reduction functions + +Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as +arguments, then the result is `NULL`. The expressions must all be convertible to a common data type, which will be the +type of the result: +* If all argument are `NULL`, the result is `NULL`. Otherwise, `NULL` arguments are ignored. +* If the arguments comprise a mix of numbers and strings, the arguments are interpreted as strings. +* If all arguments are integer numbers, the arguments are interpreted as longs. +* If all arguments are numbers and at least one argument is a double, the arguments are interpreted as doubles. + +|Function|Notes| +|--------|-----| +|`GREATEST([expr1, ...])`|Evaluates zero or more expressions and returns the maximum value based on comparisons as described above.| +|`LEAST([expr1, ...])`|Evaluates zero or more expressions and returns the minimum value based on comparisons as described above.| + + ### IP address functions For the IPv4 address functions, the `address` argument can either be an IPv4 dotted-decimal string diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GreatestSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GreatestSqlAggregator.java deleted file mode 100644 index defead45ebf..00000000000 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GreatestSqlAggregator.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.sql.calcite.aggregation.builtin; - -import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.InferTypes; -import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlTypeTransforms; -import org.apache.calcite.util.Optionality; -import org.apache.druid.java.util.common.ISE; -import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.query.aggregation.AggregatorFactory; -import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; -import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; -import org.apache.druid.query.aggregation.PostAggregator; -import org.apache.druid.query.aggregation.post.DoubleGreatestPostAggregator; -import org.apache.druid.query.aggregation.post.LongGreatestPostAggregator; -import org.apache.druid.segment.column.ValueType; - -import java.util.List; - -/** - * Calcite integration class for Greatest post aggregators of Long & Double types. - * It applies Max aggregators over the provided fields/expressions & combines their results via Field access post aggregators. - */ -public class GreatestSqlAggregator extends MultiColumnSqlAggregator -{ - private static final SqlAggFunction FUNCTION_INSTANCE = new GreatestSqlAggFunction(); - private static final String NAME = "GREATEST"; - - @Override - public SqlAggFunction calciteFunction() - { - return FUNCTION_INSTANCE; - } - - @Override - AggregatorFactory createAggregatorFactory( - ValueType valueType, - String prefixedName, - FieldInfo fieldInfo, - ExprMacroTable macroTable - ) - { - final AggregatorFactory aggregatorFactory; - switch (valueType) { - case LONG: - aggregatorFactory = new LongMaxAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable); - break; - case FLOAT: - case DOUBLE: - aggregatorFactory = new DoubleMaxAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable); - break; - default: - throw new ISE("Cannot create aggregator factory for type[%s]", valueType); - } - return aggregatorFactory; - } - - @Override - PostAggregator createFinalPostAggregator( - ValueType valueType, - String name, - List postAggregators - ) - { - final PostAggregator finalPostAggregator; - switch (valueType) { - case LONG: - finalPostAggregator = new LongGreatestPostAggregator(name, postAggregators); - break; - case FLOAT: - case DOUBLE: - finalPostAggregator = new DoubleGreatestPostAggregator(name, postAggregators); - break; - default: - throw new ISE("Cannot create aggregator factory for type[%s]", valueType); - } - return finalPostAggregator; - } - - /** - * Calcite SQL function definition - */ - private static class GreatestSqlAggFunction extends SqlAggFunction - { - GreatestSqlAggFunction() - { - /* - * The constructor params are explained as follows, - * name: SQL function name - * sqlIdentifier: null for built-in functions - * kind: SqlKind.GREATEST - * returnTypeInference: biggest operand type & nullable if any of the operands is nullable - * operandTypeInference: same as return type - * operandTypeChecker: variadic function with at least one argument - * funcType: System - * requiresOrder: No - * requiresOver: No - * requiresGroupOrder: Not allowed - */ - super( - NAME, - null, - SqlKind.GREATEST, - ReturnTypes.cascade(ReturnTypes.LEAST_RESTRICTIVE, SqlTypeTransforms.TO_NULLABLE), - InferTypes.RETURN_TYPE, - OperandTypes.ONE_OR_MORE, - SqlFunctionCategory.SYSTEM, - false, - false, - Optionality.FORBIDDEN - ); - } - } -} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LeastSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LeastSqlAggregator.java deleted file mode 100644 index e11c24cce41..00000000000 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LeastSqlAggregator.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.sql.calcite.aggregation.builtin; - -import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.InferTypes; -import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlTypeTransforms; -import org.apache.calcite.util.Optionality; -import org.apache.druid.java.util.common.ISE; -import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.query.aggregation.AggregatorFactory; -import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; -import org.apache.druid.query.aggregation.LongMinAggregatorFactory; -import org.apache.druid.query.aggregation.PostAggregator; -import org.apache.druid.query.aggregation.post.DoubleLeastPostAggregator; -import org.apache.druid.query.aggregation.post.LongLeastPostAggregator; -import org.apache.druid.segment.column.ValueType; - -import java.util.List; - -/** - * Calcite integration class for Least post aggregators of Long & Double types. - * It applies Min aggregators over the provided fields/expressions & combines their results via Field access post aggregators. - */ -public class LeastSqlAggregator extends MultiColumnSqlAggregator -{ - private static final SqlAggFunction FUNCTION_INSTANCE = new LeastSqlAggFunction(); - private static final String NAME = "LEAST"; - - @Override - public SqlAggFunction calciteFunction() - { - return FUNCTION_INSTANCE; - } - - @Override - AggregatorFactory createAggregatorFactory( - ValueType valueType, - String prefixedName, - FieldInfo fieldInfo, - ExprMacroTable macroTable - ) - { - final AggregatorFactory aggregatorFactory; - switch (valueType) { - case LONG: - aggregatorFactory = new LongMinAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable); - break; - case FLOAT: - case DOUBLE: - aggregatorFactory = new DoubleMinAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable); - break; - default: - throw new ISE("Cannot create aggregator factory for type[%s]", valueType); - } - return aggregatorFactory; - } - - @Override - PostAggregator createFinalPostAggregator( - ValueType valueType, - String name, - List postAggregators - ) - { - final PostAggregator finalPostAggregator; - switch (valueType) { - case LONG: - finalPostAggregator = new LongLeastPostAggregator(name, postAggregators); - break; - case FLOAT: - case DOUBLE: - finalPostAggregator = new DoubleLeastPostAggregator(name, postAggregators); - break; - default: - throw new ISE("Cannot create aggregator factory for type[%s]", valueType); - } - return finalPostAggregator; - } - - /** - * Calcite SQL function definition - */ - private static class LeastSqlAggFunction extends SqlAggFunction - { - LeastSqlAggFunction() - { - /* - * The constructor params are explained as follows, - * name: SQL function name - * sqlIdentifier: null for built-in functions - * kind: SqlKind.LEAST - * returnTypeInference: biggest operand type & nullable if any of the operands is nullable - * operandTypeInference: same as return type - * operandTypeChecker: variadic function with at least one argument - * funcType: System - * requiresOrder: No - * requiresOver: No - * requiresGroupOrder: Not allowed - */ - super( - NAME, - null, - SqlKind.LEAST, - ReturnTypes.cascade(ReturnTypes.LEAST_RESTRICTIVE, SqlTypeTransforms.TO_NULLABLE), - InferTypes.RETURN_TYPE, - OperandTypes.ONE_OR_MORE, - SqlFunctionCategory.SYSTEM, - false, - false, - Optionality.FORBIDDEN - ); - } - } -} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java new file mode 100644 index 00000000000..ebdb99931ad --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression.builtin; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.math.expr.Function; +import org.apache.druid.sql.calcite.expression.DirectOperatorConversion; +import org.apache.druid.sql.calcite.expression.OperatorConversions; + +public class GreatestOperatorConversion extends DirectOperatorConversion +{ + private static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder(StringUtils.toUpperCase(Function.GreatestFunc.NAME)) + .operandTypeChecker(OperandTypes.VARIADIC) + .returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE) + .build(); + + public GreatestOperatorConversion() + { + super(SQL_FUNCTION, Function.GreatestFunc.NAME); + } + + @Override + public SqlOperator calciteOperator() + { + return SQL_FUNCTION; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java new file mode 100644 index 00000000000..09578c6eaa9 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression.builtin; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.math.expr.Function; +import org.apache.druid.sql.calcite.expression.DirectOperatorConversion; +import org.apache.druid.sql.calcite.expression.OperatorConversions; + +public class LeastOperatorConversion extends DirectOperatorConversion +{ + private static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder(StringUtils.toUpperCase(Function.LeastFunc.NAME)) + .operandTypeChecker(OperandTypes.VARIADIC) + .returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE) + .build(); + + public LeastOperatorConversion() + { + super(SQL_FUNCTION, Function.LeastFunc.NAME); + } + + @Override + public SqlOperator calciteOperator() + { + return SQL_FUNCTION; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java new file mode 100644 index 00000000000..595955e1c0d --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression.builtin; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.planner.Calcites; + +class ReductionOperatorConversionHelper +{ + private ReductionOperatorConversionHelper() + { + } + + /** + * Implements type precedence rules similar to: + * https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least + * + * @see org.apache.druid.math.expr.Function.ReduceFunc#apply + * @see org.apache.druid.math.expr.Function.ReduceFunc#getComparisionType + */ + static final SqlReturnTypeInference TYPE_INFERENCE = + opBinding -> { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + + final int n = opBinding.getOperandCount(); + if (n == 0) { + return typeFactory.createSqlType(SqlTypeName.NULL); + } + + SqlTypeName returnSqlTypeName = SqlTypeName.NULL; + boolean hasDouble = false; + + for (int i = 0; i < n; i++) { + RelDataType type = opBinding.getOperandType(i); + SqlTypeName sqlTypeName = type.getSqlTypeName(); + ValueType valueType = Calcites.getValueTypeForSqlTypeName(sqlTypeName); + + // Return types are listed in order of preference: + if (valueType == ValueType.STRING) { + returnSqlTypeName = sqlTypeName; + break; + } else if (valueType == ValueType.DOUBLE || valueType == ValueType.FLOAT) { + returnSqlTypeName = SqlTypeName.DOUBLE; + hasDouble = true; + } else if (valueType == ValueType.LONG && !hasDouble) { + returnSqlTypeName = SqlTypeName.BIGINT; + } else if (sqlTypeName != SqlTypeName.NULL) { + throw new IAE("Argument %d has invalid type: %s", i, sqlTypeName); + } + } + + return typeFactory.createSqlType(returnSqlTypeName); + }; +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index 731f036bca5..0fd811a0d28 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -37,8 +37,6 @@ import org.apache.druid.sql.calcite.aggregation.builtin.ApproxCountDistinctSqlAg import org.apache.druid.sql.calcite.aggregation.builtin.AvgSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.CountSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator; -import org.apache.druid.sql.calcite.aggregation.builtin.GreatestSqlAggregator; -import org.apache.druid.sql.calcite.aggregation.builtin.LeastSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.MaxSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.MinSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.SumSqlAggregator; @@ -66,11 +64,13 @@ import org.apache.druid.sql.calcite.expression.builtin.ConcatOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.DateTruncOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.ExtractOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.FloorOperatorConversion; +import org.apache.druid.sql.calcite.expression.builtin.GreatestOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.IPv4AddressMatchOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.IPv4AddressParseOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.IPv4AddressStringifyOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.LPadOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.LTrimOperatorConversion; +import org.apache.druid.sql.calcite.expression.builtin.LeastOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.LeftOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.LikeOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.MillisToTimestampOperatorConversion; @@ -126,8 +126,6 @@ public class DruidOperatorTable implements SqlOperatorTable .add(EarliestLatestAnySqlAggregator.ANY_VALUE) .add(new MinSqlAggregator()) .add(new MaxSqlAggregator()) - .add(new GreatestSqlAggregator()) - .add(new LeastSqlAggregator()) .add(new SumSqlAggregator()) .add(new SumZeroSqlAggregator()) .build(); @@ -219,6 +217,12 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new StringToMultiValueStringOperatorConversion()) .build(); + private static final List REDUCTION_OPERATOR_CONVERSIONS = + ImmutableList.builder() + .add(new GreatestOperatorConversion()) + .add(new LeastOperatorConversion()) + .build(); + private static final List IPV4ADDRESS_OPERATOR_CONVERSIONS = ImmutableList.builder() .add(new IPv4AddressMatchOperatorConversion()) @@ -282,6 +286,7 @@ public class DruidOperatorTable implements SqlOperatorTable .addAll(VALUE_COERCION_OPERATOR_CONVERSIONS) .addAll(ARRAY_OPERATOR_CONVERSIONS) .addAll(MULTIVALUE_STRING_OPERATOR_CONVERSIONS) + .addAll(REDUCTION_OPERATOR_CONVERSIONS) .addAll(IPV4ADDRESS_OPERATOR_CONVERSIONS) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 8aabb454500..b34b3ddbf20 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -26,7 +26,6 @@ import org.apache.calcite.tools.ValidationException; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.IAE; -import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.JodaUtils; import org.apache.druid.java.util.common.granularity.Granularities; @@ -39,7 +38,6 @@ import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; -import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; @@ -63,12 +61,8 @@ import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory; import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory; import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; -import org.apache.druid.query.aggregation.post.DoubleGreatestPostAggregator; -import org.apache.druid.query.aggregation.post.DoubleLeastPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; -import org.apache.druid.query.aggregation.post.LongGreatestPostAggregator; -import org.apache.druid.query.aggregation.post.LongLeastPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.query.extraction.RegexDimExtractionFn; @@ -6269,444 +6263,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } - @Test - public void testGreatestLongAndDoubleWithGroupBy() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - - testQuery( - "SELECT * FROM (" - + " SELECT greatest(cntl1, cntl2), greatest(cntd1, cntd2) FROM (\n" - + " SELECT TIME_FLOOR(__time, 'P1D') AS t,\n" - + " count(1) AS cntl1, 10 AS cntl2,\n" - + " (1.2 + count(1)) AS cntd1, 10.2 AS cntd2\n" - + " FROM \"foo\"\n" - + " GROUP BY 1\n" - + " )" - + ")\n", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new QueryDataSource( - GroupByQuery.builder() - .setDataSource(CalciteTests.DATASOURCE1) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setVirtualColumns( - expressionVirtualColumn( - "v0", - "timestamp_floor(\"__time\",'P1D',null,'UTC')", - ValueType.LONG - ) - ) - .setDimensions(dimensions(new DefaultDimensionSpec( - "v0", - "d0", - ValueType.LONG - ))) - .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new LongMaxAggregatorFactory("_a0:0", "a0"), - new LongMaxAggregatorFactory("_a0:1", null, "10", ExprMacroTable.nil()), - new DoubleMaxAggregatorFactory("_a1:0", null, "(1.2 + \"a0\")", ExprMacroTable.nil()), - new DoubleMaxAggregatorFactory("_a1:1", null, "10.2", ExprMacroTable.nil()) - )) - .setPostAggregatorSpecs( - ImmutableList.of( - new LongGreatestPostAggregator( - "_a0", - ImmutableList.of( - new FieldAccessPostAggregator(null, "_a0:0"), - new FieldAccessPostAggregator(null, "_a0:1") - ) - ), - new DoubleGreatestPostAggregator( - "_a1", - ImmutableList.of( - new FieldAccessPostAggregator(null, "_a1:0"), - new FieldAccessPostAggregator(null, "_a1:1") - ) - ) - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of(new Object[]{10L, 10.2D}) - ); - } - - @Test - public void testLeastLongAndDoubleWithGroupBy() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - - testQuery( - "SELECT * FROM (" - + " SELECT least(cntl1, cntl2), least(cntd1, cntd2) FROM (\n" - + " SELECT TIME_FLOOR(__time, 'P1D') AS t,\n" - + " count(1) AS cntl1, 10 AS cntl2,\n" - + " (1.2 + count(1)) AS cntd1, 10.2 AS cntd2\n" - + " FROM \"foo\"\n" - + " GROUP BY 1\n" - + " )" - + ")\n", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new QueryDataSource( - GroupByQuery.builder() - .setDataSource(CalciteTests.DATASOURCE1) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setVirtualColumns( - expressionVirtualColumn( - "v0", - "timestamp_floor(\"__time\",'P1D',null,'UTC')", - ValueType.LONG - ) - ) - .setDimensions(dimensions(new DefaultDimensionSpec( - "v0", - "d0", - ValueType.LONG - ))) - .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new LongMinAggregatorFactory("_a0:0", "a0"), - new LongMinAggregatorFactory("_a0:1", null, "10", ExprMacroTable.nil()), - new DoubleMinAggregatorFactory("_a1:0", null, "(1.2 + \"a0\")", ExprMacroTable.nil()), - new DoubleMinAggregatorFactory("_a1:1", null, "10.2", ExprMacroTable.nil()) - )) - .setPostAggregatorSpecs( - ImmutableList.of( - new LongLeastPostAggregator( - "_a0", - ImmutableList.of( - new FieldAccessPostAggregator(null, "_a0:0"), - new FieldAccessPostAggregator(null, "_a0:1") - ) - ), - new DoubleLeastPostAggregator( - "_a1", - ImmutableList.of( - new FieldAccessPostAggregator(null, "_a1:0"), - new FieldAccessPostAggregator(null, "_a1:1") - ) - ) - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of(new Object[]{1L, 2.2D}) - ); - } - - @Test - public void testGreatestSingleColumnPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - - testQuery( - "SELECT\n" - + " greatest(cnt), greatest(m1), greatest(m2)\n" - + " FROM \"foo\"\n", - ImmutableList.of( - Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .granularity(Granularities.ALL) - .aggregators(aggregators( - new LongMaxAggregatorFactory("a0:0", "cnt"), - new DoubleMaxAggregatorFactory("a1:0", "m1"), - new DoubleMaxAggregatorFactory("a2:0", "m2") - )) - .postAggregators(ImmutableList.of( - new LongGreatestPostAggregator( - "a0", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a0:0") - ) - ), - new DoubleGreatestPostAggregator( - "a1", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a1:0") - ) - ), - new DoubleGreatestPostAggregator( - "a2", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a2:0") - ) - ) - ) - ) - .context(TIMESERIES_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of(new Object[]{1L, 6.0F, 6.0D}) - ); - } - - @Test - public void testLeastSingleColumnPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - - testQuery( - "SELECT\n" - + " least(cnt), least(m1), least(m2)\n" - + " FROM \"foo\"\n", - ImmutableList.of( - Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .granularity(Granularities.ALL) - .aggregators(aggregators( - new LongMinAggregatorFactory("a0:0", "cnt"), - new DoubleMinAggregatorFactory("a1:0", "m1"), - new DoubleMinAggregatorFactory("a2:0", "m2") - )) - .postAggregators(ImmutableList.of( - new LongLeastPostAggregator( - "a0", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a0:0") - ) - ), - new DoubleLeastPostAggregator( - "a1", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a1:0") - ) - ), - new DoubleLeastPostAggregator( - "a2", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a2:0") - ) - ) - ) - ) - .context(TIMESERIES_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of(new Object[]{1L, 1.0F, 1.0D}) - ); - } - - @Test - public void testGreatestCombinationPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - - testQuery( - "SELECT\n" - + " greatest(cnt, 10, 10 * 2 + 3),\n" - + " greatest(m1, 10.0, 10.2 * 2.0 + 3.0),\n" - + " greatest(m2, 10.0, 10.2 * 2.0 + 3.0)\n" - + " FROM \"foo\"\n", - ImmutableList.of( - Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .granularity(Granularities.ALL) - .aggregators(aggregators( - new LongMaxAggregatorFactory("a0:0", "cnt"), - new LongMaxAggregatorFactory("a0:1", null, "10", ExprMacroTable.nil()), - new LongMaxAggregatorFactory("a0:2", null, "23", ExprMacroTable.nil()), - new DoubleMaxAggregatorFactory("a1:0", "m1"), - new DoubleMaxAggregatorFactory("a1:1", null, "10.0", ExprMacroTable.nil()), - new DoubleMaxAggregatorFactory("a1:2", null, "23.4", ExprMacroTable.nil()), - new DoubleMaxAggregatorFactory("a2:0", "m2"), - new DoubleMaxAggregatorFactory("a2:1", null, "10.0", ExprMacroTable.nil()), - new DoubleMaxAggregatorFactory("a2:2", null, "23.4", ExprMacroTable.nil()) - )) - .postAggregators(ImmutableList.of( - new LongGreatestPostAggregator( - "a0", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a0:0"), - new FieldAccessPostAggregator(null, "a0:1"), - new FieldAccessPostAggregator(null, "a0:2") - ) - ), - new DoubleGreatestPostAggregator( - "a1", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a1:0"), - new FieldAccessPostAggregator(null, "a1:1"), - new FieldAccessPostAggregator(null, "a1:2") - ) - ), - new DoubleGreatestPostAggregator( - "a2", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a2:0"), - new FieldAccessPostAggregator(null, "a2:1"), - new FieldAccessPostAggregator(null, "a2:2") - ) - )) - ) - .context(TIMESERIES_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of(new Object[]{23L, 23.4D, 23.4D}) - ); - } - - @Test - public void testLeastCombinationPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - - testQuery( - "SELECT\n" - + " least(cnt, 10, 10 * 2 + 3),\n" - + " least(m1, 10.0, 10.2 * 2.0 + 3.0),\n" - + " least(m2, 10.0, 10.2 * 2.0 + 3.0)\n" - + " FROM \"foo\"\n", - ImmutableList.of( - Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .granularity(Granularities.ALL) - .aggregators(aggregators( - new LongMinAggregatorFactory("a0:0", "cnt"), - new LongMinAggregatorFactory("a0:1", null, "10", ExprMacroTable.nil()), - new LongMinAggregatorFactory("a0:2", null, "23", ExprMacroTable.nil()), - new DoubleMinAggregatorFactory("a1:0", "m1"), - new DoubleMinAggregatorFactory("a1:1", null, "10.0", ExprMacroTable.nil()), - new DoubleMinAggregatorFactory("a1:2", null, "23.4", ExprMacroTable.nil()), - new DoubleMinAggregatorFactory("a2:0", "m2"), - new DoubleMinAggregatorFactory("a2:1", null, "10.0", ExprMacroTable.nil()), - new DoubleMinAggregatorFactory("a2:2", null, "23.4", ExprMacroTable.nil()) - )) - .postAggregators(ImmutableList.of( - new LongLeastPostAggregator( - "a0", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a0:0"), - new FieldAccessPostAggregator(null, "a0:1"), - new FieldAccessPostAggregator(null, "a0:2") - ) - ), - new DoubleLeastPostAggregator( - "a1", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a1:0"), - new FieldAccessPostAggregator(null, "a1:1"), - new FieldAccessPostAggregator(null, "a1:2") - ) - ), - new DoubleLeastPostAggregator( - "a2", - ImmutableList.of( - new FieldAccessPostAggregator(null, "a2:0"), - new FieldAccessPostAggregator(null, "a2:1"), - new FieldAccessPostAggregator(null, "a2:2") - ) - )) - ) - .context(TIMESERIES_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of(new Object[]{1L, 1.0D, 1.0D}) - ); - } - - @Test - public void testGreatestInvalidPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - expectedException.expect(RuntimeException.class); - expectedException.expectCause(CoreMatchers.instanceOf(ISE.class)); - expectedException.expectCause( - ThrowableMessageMatcher.hasMessage( - CoreMatchers.containsString( - "Cannot create aggregator factory for type[STRING]" - ) - ) - ); - - testQuery("SELECT GREATEST(dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of()); - } - - @Test - public void testLeastInvalidPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - expectedException.expect(RuntimeException.class); - expectedException.expectCause(CoreMatchers.instanceOf(ISE.class)); - expectedException.expectCause( - ThrowableMessageMatcher.hasMessage( - CoreMatchers.containsString( - "Cannot create aggregator factory for type[STRING]" - ) - ) - ); - - testQuery("SELECT LEAST(dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of()); - } - - @Test - public void testGreatestInvalidCombinationPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - expectedException.expect(ValidationException.class); - expectedException.expectCause(CoreMatchers.instanceOf(IllegalArgumentException.class)); - expectedException.expectCause( - ThrowableMessageMatcher.hasMessage( - CoreMatchers.containsString( - "Cannot infer return type for GREATEST; operand types: [INTEGER, VARCHAR]" - ) - ) - ); - - testQuery("SELECT GREATEST(10, dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of()); - } - - @Test - public void testLeastInvalidCombinationPostAggregations() throws Exception - { - // Cannot vectorize due to virtual columns. - cannotVectorize(); - expectedException.expect(ValidationException.class); - expectedException.expectCause(CoreMatchers.instanceOf(IllegalArgumentException.class)); - expectedException.expectCause( - ThrowableMessageMatcher.hasMessage( - CoreMatchers.containsString( - "Cannot infer return type for LEAST; operand types: [INTEGER, VARCHAR]" - ) - ) - ); - - testQuery("SELECT LEAST(10, dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of()); - } - @Test public void testAvgDailyCountDistinct() throws Exception { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index 48dea514481..3136508ee9b 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -133,6 +133,11 @@ class ExpressionTestHelper return rexBuilder.makeIntervalLiteral(v, intervalQualifier); } + RexNode makeLiteral(Double d) + { + return rexBuilder.makeLiteral(d, createSqlType(SqlTypeName.DOUBLE), true); + } + RexNode makeCall(SqlOperator op, RexNode... exprs) { return rexBuilder.makeCall(op, exprs); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java new file mode 100644 index 00000000000..fcd96911aec --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression; + +import com.google.common.collect.ImmutableMap; +import org.apache.calcite.avatica.util.TimeUnit; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.expression.builtin.GreatestOperatorConversion; +import org.junit.Before; +import org.junit.Test; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class GreatestExpressionTest extends ExpressionTestBase +{ + private static final String DOUBLE_KEY = "d"; + private static final double DOUBLE_VALUE = 3.1; + private static final String LONG_KEY = "l"; + private static final long LONG_VALUE = 2L; + private static final String STRING_KEY = "s"; + private static final String STRING_VALUE = "foo"; + private static final RowSignature ROW_SIGNATURE = RowSignature + .builder() + .add(DOUBLE_KEY, ValueType.DOUBLE) + .add(LONG_KEY, ValueType.LONG) + .add(STRING_KEY, ValueType.STRING) + .build(); + private static final Map BINDINGS = ImmutableMap.of( + DOUBLE_KEY, DOUBLE_VALUE, + LONG_KEY, LONG_VALUE, + STRING_KEY, STRING_VALUE + ); + + private GreatestOperatorConversion target; + private ExpressionTestHelper testHelper; + + @Before + public void setUp() + { + target = new GreatestOperatorConversion(); + testHelper = new ExpressionTestHelper(ROW_SIGNATURE, BINDINGS); + } + + @Test + public void testNoArgs() + { + testExpression( + Collections.emptyList(), + buildExpectedExpression(), + null + ); + } + + @Test + public void testAllNull() + { + testExpression( + Arrays.asList( + testHelper.getConstantNull(), + testHelper.getConstantNull() + ), + buildExpectedExpression(null, null), + null + ); + } + + @Test + public void testSomeNull() + { + testExpression( + Arrays.asList( + testHelper.makeInputRef(DOUBLE_KEY), + testHelper.getConstantNull(), + testHelper.makeInputRef(STRING_KEY) + ), + buildExpectedExpression( + testHelper.makeVariable(DOUBLE_KEY), + null, + testHelper.makeVariable(STRING_KEY) + ), + STRING_VALUE + ); + } + + @Test + public void testAllDouble() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(34.1), + testHelper.makeInputRef(DOUBLE_KEY), + testHelper.makeLiteral(5.2), + testHelper.makeLiteral(767.3) + ), + buildExpectedExpression( + 34.1, + testHelper.makeVariable(DOUBLE_KEY), + 5.2, + 767.3 + ), + 767.3 + ); + } + + @Test + public void testAllLong() + { + testExpression( + Arrays.asList( + testHelper.makeInputRef(LONG_KEY), + testHelper.makeLiteral(0) + ), + buildExpectedExpression( + testHelper.makeVariable(LONG_KEY), + 0 + ), + LONG_VALUE + ); + } + + @Test + public void testAllString() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral("B"), + testHelper.makeInputRef(STRING_KEY), + testHelper.makeLiteral("A") + ), + buildExpectedExpression( + "B", + testHelper.makeVariable(STRING_KEY), + "A" + ), + STRING_VALUE + ); + } + + @Test + public void testCoerceString() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(-1), + testHelper.makeInputRef(DOUBLE_KEY), + testHelper.makeLiteral("A") + ), + buildExpectedExpression( + -1, + testHelper.makeVariable(DOUBLE_KEY), + "A" + ), + "A" + ); + } + + @Test + public void testCoerceDouble() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(-1), + testHelper.makeInputRef(DOUBLE_KEY) + ), + buildExpectedExpression( + -1, + testHelper.makeVariable(DOUBLE_KEY) + ), + DOUBLE_VALUE + ); + } + + @Test + public void testDecimal() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(BigDecimal.valueOf(1.2)), + testHelper.makeLiteral(BigDecimal.valueOf(3.4)) + ), + buildExpectedExpression( + 1.2, + 3.4 + ), + 3.4 + ); + } + + @Test + public void testTimestamp() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(DateTimes.utc(1000)), + testHelper.makeLiteral(DateTimes.utc(2000)) + ), + buildExpectedExpression( + 1000, + 2000 + ), + 2000L + ); + } + + @Test + public void testInvalidType() + { + expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); + + testExpression( + Collections.singletonList( + testHelper.makeLiteral( + new BigDecimal(13), // YEAR-MONTH literals value is months + new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) + ) + ), + null, + null + ); + } + + private void testExpression( + List exprs, + final DruidExpression expectedExpression, + final Object expectedResult + ) + { + testHelper.testExpression(target.calciteOperator(), exprs, expectedExpression, expectedResult); + } + + private DruidExpression buildExpectedExpression(Object... args) + { + return testHelper.buildExpectedExpression(target.getDruidFunctionName(), args); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java new file mode 100644 index 00000000000..1405eac207c --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression; + +import com.google.common.collect.ImmutableMap; +import org.apache.calcite.avatica.util.TimeUnit; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.expression.builtin.LeastOperatorConversion; +import org.junit.Before; +import org.junit.Test; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class LeastExpressionTest extends ExpressionTestBase +{ + private static final String DOUBLE_KEY = "d"; + private static final double DOUBLE_VALUE = 3.1; + private static final String LONG_KEY = "l"; + private static final long LONG_VALUE = 2L; + private static final String STRING_KEY = "s"; + private static final String STRING_VALUE = "foo"; + private static final RowSignature ROW_SIGNATURE = RowSignature + .builder() + .add(DOUBLE_KEY, ValueType.DOUBLE) + .add(LONG_KEY, ValueType.LONG) + .add(STRING_KEY, ValueType.STRING) + .build(); + private static final Map BINDINGS = ImmutableMap.of( + DOUBLE_KEY, DOUBLE_VALUE, + LONG_KEY, LONG_VALUE, + STRING_KEY, STRING_VALUE + ); + + private LeastOperatorConversion target; + private ExpressionTestHelper testHelper; + + @Before + public void setUp() + { + target = new LeastOperatorConversion(); + testHelper = new ExpressionTestHelper(ROW_SIGNATURE, BINDINGS); + } + + @Test + public void testNoArgs() + { + testExpression( + Collections.emptyList(), + buildExpectedExpression(), + null + ); + } + + @Test + public void testAllNull() + { + testExpression( + Arrays.asList( + testHelper.getConstantNull(), + testHelper.getConstantNull() + ), + buildExpectedExpression(null, null), + null + ); + } + + @Test + public void testSomeNull() + { + testExpression( + Arrays.asList( + testHelper.makeInputRef(DOUBLE_KEY), + testHelper.getConstantNull(), + testHelper.makeInputRef(STRING_KEY) + ), + buildExpectedExpression( + testHelper.makeVariable(DOUBLE_KEY), + null, + testHelper.makeVariable(STRING_KEY) + ), + String.valueOf(DOUBLE_VALUE) + ); + } + + @Test + public void testAllDouble() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(34.1), + testHelper.makeInputRef(DOUBLE_KEY), + testHelper.makeLiteral(5.2), + testHelper.makeLiteral(767.3) + ), + buildExpectedExpression( + 34.1, + testHelper.makeVariable(DOUBLE_KEY), + 5.2, + 767.3 + ), + 3.1 + ); + } + + @Test + public void testAllLong() + { + testExpression( + Arrays.asList( + testHelper.makeInputRef(LONG_KEY), + testHelper.makeLiteral(0) + ), + buildExpectedExpression( + testHelper.makeVariable(LONG_KEY), + 0 + ), + 0L + ); + } + + @Test + public void testAllString() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral("B"), + testHelper.makeInputRef(STRING_KEY), + testHelper.makeLiteral("A") + ), + buildExpectedExpression( + "B", + testHelper.makeVariable(STRING_KEY), + "A" + ), + "A" + ); + } + + @Test + public void testCoerceString() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(-1), + testHelper.makeInputRef(DOUBLE_KEY), + testHelper.makeLiteral("A") + ), + buildExpectedExpression( + -1, + testHelper.makeVariable(DOUBLE_KEY), + "A" + ), + "-1" + ); + } + + @Test + public void testCoerceDouble() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(-1), + testHelper.makeInputRef(DOUBLE_KEY) + ), + buildExpectedExpression( + -1, + testHelper.makeVariable(DOUBLE_KEY) + ), + -1.0 + ); + } + + @Test + public void testDecimal() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(BigDecimal.valueOf(1.2)), + testHelper.makeLiteral(BigDecimal.valueOf(3.4)) + ), + buildExpectedExpression( + 1.2, + 3.4 + ), + 1.2 + ); + } + + @Test + public void testTimestamp() + { + testExpression( + Arrays.asList( + testHelper.makeLiteral(DateTimes.utc(1000)), + testHelper.makeLiteral(DateTimes.utc(2000)) + ), + buildExpectedExpression( + 1000, + 2000 + ), + 1000L + ); + } + + @Test + public void testInvalidType() + { + expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); + + testExpression( + Collections.singletonList( + testHelper.makeLiteral( + new BigDecimal(13), // YEAR-MONTH literals value is months + new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) + ) + ), + null, + null + ); + } + + private void testExpression( + List exprs, + final DruidExpression expectedExpression, + final Object expectedResult + ) + { + testHelper.testExpression(target.calciteOperator(), exprs, expectedExpression, expectedResult); + } + + private DruidExpression buildExpectedExpression(Object... args) + { + return testHelper.buildExpectedExpression(target.getDruidFunctionName(), args); + } +}