Match GREATEST/LEAST function behavior to other DBs (#9488)

* Match GREATEST/LEAST function behavior

Change the behavior of the GREATEST / LEAST functions to be similar to
how it is implemented in other databases (as functions instead of
aggregators). The GREATEST/LEAST functions are not in the SQL standard,
but users will expect behavior similar to what other databases provide.

* Match postgres behavior & handle more SQL types

* Fix imports
This commit is contained in:
Chi Cao Minh 2020-03-12 15:10:11 -07:00 committed by GitHub
parent ddc6f87920
commit 6b02991464
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1015 additions and 767 deletions

View File

@ -35,9 +35,14 @@ import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.function.DoubleBinaryOperator;
import java.util.function.LongBinaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -48,7 +53,7 @@ import java.util.stream.Stream;
* Do NOT remove "unused" members in this class. They are used by generated Antlr
*/
@SuppressWarnings("unused")
interface Function
public interface Function
{
/**
* Name of the function.
@ -976,6 +981,172 @@ interface Function
}
}
class GreatestFunc extends ReduceFunc
{
public static final String NAME = "greatest";
public GreatestFunc()
{
super(
Math::max,
Math::max,
BinaryOperator.maxBy(Comparator.naturalOrder())
);
}
@Override
public String name()
{
return NAME;
}
}
class LeastFunc extends ReduceFunc
{
public static final String NAME = "least";
public LeastFunc()
{
super(
Math::min,
Math::min,
BinaryOperator.minBy(Comparator.naturalOrder())
);
}
@Override
public String name()
{
return NAME;
}
}
abstract class ReduceFunc implements Function
{
private final DoubleBinaryOperator doubleReducer;
private final LongBinaryOperator longReducer;
private final BinaryOperator<String> stringReducer;
ReduceFunc(
DoubleBinaryOperator doubleReducer,
LongBinaryOperator longReducer,
BinaryOperator<String> stringReducer
)
{
this.doubleReducer = doubleReducer;
this.longReducer = longReducer;
this.stringReducer = stringReducer;
}
@Override
public void validateArguments(List<Expr> args)
{
// anything goes
}
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
if (args.isEmpty()) {
return ExprEval.of(null);
}
ExprAnalysis exprAnalysis = analyzeExprs(args, bindings);
if (exprAnalysis.exprEvals.isEmpty()) {
// The GREATEST/LEAST functions are not in the SQL standard. Emulate the behavior of postgres (return null if
// all expressions are null, otherwise skip null values) since it is used as a base for a wide number of
// databases. This also matches the behavior the the long/double greatest/least post aggregators. Some other
// databases (e.g., MySQL) return null if any expression is null.
// https://www.postgresql.org/docs/9.5/functions-conditional.html
// https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least
return ExprEval.of(null);
}
Stream<ExprEval<?>> exprEvalStream = exprAnalysis.exprEvals.stream();
switch (exprAnalysis.comparisonType) {
case DOUBLE:
//noinspection OptionalGetWithoutIsPresent (empty list handled earlier)
return ExprEval.of(exprEvalStream.mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble());
case LONG:
//noinspection OptionalGetWithoutIsPresent (empty list handled earlier)
return ExprEval.of(exprEvalStream.mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong());
default:
//noinspection OptionalGetWithoutIsPresent (empty list handled earlier)
return ExprEval.of(exprEvalStream.map(ExprEval::asString).reduce(stringReducer).get());
}
}
/**
* Determines which {@link ExprType} to use to compare non-null evaluated expressions.
*
* @param exprs Expressions to analyze
* @param bindings Bindings for expressions
*
* @return Comparison type and non-null evaluated expressions.
*/
private ExprAnalysis analyzeExprs(List<Expr> exprs, Expr.ObjectBinding bindings)
{
Set<ExprType> presentTypes = EnumSet.noneOf(ExprType.class);
List<ExprEval<?>> exprEvals = new ArrayList<>();
for (Expr expr : exprs) {
ExprEval<?> exprEval = expr.eval(bindings);
ExprType exprType = exprEval.type();
if (isValidType(exprType)) {
presentTypes.add(exprType);
}
if (exprEval.value() != null) {
exprEvals.add(exprEval);
}
}
ExprType comparisonType = getComparisionType(presentTypes);
return new ExprAnalysis(comparisonType, exprEvals);
}
private boolean isValidType(ExprType exprType)
{
switch (exprType) {
case DOUBLE:
case LONG:
case STRING:
return true;
default:
throw new IAE("Function[%s] does not accept %s types", name(), exprType);
}
}
/**
* Implements rules similar to: https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least
*
* @see org.apache.druid.sql.calcite.expression.builtin.ReductionOperatorConversionHelper#TYPE_INFERENCE
*/
private static ExprType getComparisionType(Set<ExprType> exprTypes)
{
if (exprTypes.contains(ExprType.STRING)) {
return ExprType.STRING;
} else if (exprTypes.contains(ExprType.DOUBLE)) {
return ExprType.DOUBLE;
} else {
return ExprType.LONG;
}
}
private static class ExprAnalysis
{
final ExprType comparisonType;
final List<ExprEval<?>> exprEvals;
ExprAnalysis(ExprType comparisonType, List<ExprEval<?>> exprEvals)
{
this.comparisonType = comparisonType;
this.exprEvals = exprEvals;
}
}
}
class NextAfter extends BivariateMathFunction
{
@Override
@ -2390,6 +2561,7 @@ interface Function
throw new RE("Unable to prepend to unknown type %s", arrayExpr.type());
}
private <T> Stream<T> prepend(T val, T[] array)
{
List<T> l = new ArrayList<>(Arrays.asList(array));

View File

@ -26,6 +26,8 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import javax.annotation.Nullable;
public class FunctionTest extends InitializedNullHandlingTest
{
private Expr.ObjectBinding bindings;
@ -68,11 +70,11 @@ public class FunctionTest extends InitializedNullHandlingTest
if (NullHandling.replaceWithDefault()) {
assertExpr("concat(x,' ',nonexistent,' ',y)", "foo 2");
} else {
assertExpr("concat(x,' ',nonexistent,' ',y)", null);
assertArrayExpr("concat(x,' ',nonexistent,' ',y)", null);
}
assertExpr("concat(z)", "3.1");
assertExpr("concat()", null);
assertArrayExpr("concat()", null);
}
@Test
@ -144,9 +146,9 @@ public class FunctionTest extends InitializedNullHandlingTest
assertExpr("lpad(x, 5, 'ab')", "abfoo");
assertExpr("lpad(x, 4, 'ab')", "afoo");
assertExpr("lpad(x, 2, 'ab')", "fo");
assertExpr("lpad(x, 0, 'ab')", null);
assertExpr("lpad(x, 5, null)", null);
assertExpr("lpad(null, 5, x)", null);
assertArrayExpr("lpad(x, 0, 'ab')", null);
assertArrayExpr("lpad(x, 5, null)", null);
assertArrayExpr("lpad(null, 5, x)", null);
}
@Test
@ -155,18 +157,18 @@ public class FunctionTest extends InitializedNullHandlingTest
assertExpr("rpad(x, 5, 'ab')", "fooab");
assertExpr("rpad(x, 4, 'ab')", "fooa");
assertExpr("rpad(x, 2, 'ab')", "fo");
assertExpr("rpad(x, 0, 'ab')", null);
assertExpr("rpad(x, 5, null)", null);
assertExpr("rpad(null, 5, x)", null);
assertArrayExpr("rpad(x, 0, 'ab')", null);
assertArrayExpr("rpad(x, 5, null)", null);
assertArrayExpr("rpad(null, 5, x)", null);
}
@Test
public void testArrayConstructor()
{
assertExpr("array(1, 2, 3, 4)", new Long[]{1L, 2L, 3L, 4L});
assertExpr("array(1, 2, 3, 'bar')", new Long[]{1L, 2L, 3L, null});
assertExpr("array(1.0)", new Double[]{1.0});
assertExpr("array('foo', 'bar')", new String[]{"foo", "bar"});
assertArrayExpr("array(1, 2, 3, 4)", new Long[]{1L, 2L, 3L, 4L});
assertArrayExpr("array(1, 2, 3, 'bar')", new Long[]{1L, 2L, 3L, null});
assertArrayExpr("array(1.0)", new Double[]{1.0});
assertArrayExpr("array('foo', 'bar')", new String[]{"foo", "bar"});
}
@Test
@ -180,7 +182,7 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArrayOffset()
{
assertExpr("array_offset([1, 2, 3], 2)", 3L);
assertExpr("array_offset([1, 2, 3], 3)", null);
assertArrayExpr("array_offset([1, 2, 3], 3)", null);
assertExpr("array_offset(a, 2)", "baz");
}
@ -188,7 +190,7 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArrayOrdinal()
{
assertExpr("array_ordinal([1, 2, 3], 3)", 3L);
assertExpr("array_ordinal([1, 2, 3], 4)", null);
assertArrayExpr("array_ordinal([1, 2, 3], 4)", null);
assertExpr("array_ordinal(a, 3)", "baz");
}
@ -228,20 +230,20 @@ public class FunctionTest extends InitializedNullHandlingTest
@Test
public void testArrayAppend()
{
assertExpr("array_append([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
assertExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, null});
assertExpr("array_append([], 1)", new String[]{"1"});
assertExpr("array_append(<LONG>[], 1)", new Long[]{1L});
assertArrayExpr("array_append([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
assertArrayExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, null});
assertArrayExpr("array_append([], 1)", new String[]{"1"});
assertArrayExpr("array_append(<LONG>[], 1)", new Long[]{1L});
}
@Test
public void testArrayConcat()
{
assertExpr("array_concat([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 2L, 4L, 6L});
assertExpr("array_concat([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
assertExpr("array_concat(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L});
assertExpr("array_concat(map(y -> y * 3, b), [1, 2, 3])", new Long[]{3L, 6L, 9L, 12L, 15L, 1L, 2L, 3L});
assertExpr("array_concat(0, 1)", new Long[]{0L, 1L});
assertArrayExpr("array_concat([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 2L, 4L, 6L});
assertArrayExpr("array_concat([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
assertArrayExpr("array_concat(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L});
assertArrayExpr("array_concat(map(y -> y * 3, b), [1, 2, 3])", new Long[]{3L, 6L, 9L, 12L, 15L, 1L, 2L, 3L});
assertArrayExpr("array_concat(0, 1)", new Long[]{0L, 1L});
}
@Test
@ -255,43 +257,99 @@ public class FunctionTest extends InitializedNullHandlingTest
@Test
public void testStringToArray()
{
assertExpr("string_to_array('1,2,3', ',')", new String[]{"1", "2", "3"});
assertExpr("string_to_array('1', ',')", new String[]{"1"});
assertExpr("string_to_array(array_to_string(a, ','), ',')", new String[]{"foo", "bar", "baz", "foobar"});
assertArrayExpr("string_to_array('1,2,3', ',')", new String[]{"1", "2", "3"});
assertArrayExpr("string_to_array('1', ',')", new String[]{"1"});
assertArrayExpr("string_to_array(array_to_string(a, ','), ',')", new String[]{"foo", "bar", "baz", "foobar"});
}
@Test
public void testArrayCast()
{
assertExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"});
assertExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0});
assertExpr("cast(c, 'LONG_ARRAY')", new Long[]{3L, 4L, 5L});
assertExpr("cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')", new Long[]{1L, 2L, 3L, 4L, 5L});
assertExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L});
assertArrayExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"});
assertArrayExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0});
assertArrayExpr("cast(c, 'LONG_ARRAY')", new Long[]{3L, 4L, 5L});
assertArrayExpr("cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')", new Long[]{1L, 2L, 3L, 4L, 5L});
assertArrayExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L});
}
@Test
public void testArraySlice()
{
assertExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[] {2L, 3L});
assertExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[] {3.2, 4.3});
assertExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[] {null, null});
assertExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[] {});
assertExpr("array_slice([1, 2, 3, 4], 5, 7)", null);
assertExpr("array_slice([1, 2, 3, 4], 2, 1)", null);
assertArrayExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[] {2L, 3L});
assertArrayExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[] {3.2, 4.3});
assertArrayExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[] {null, null});
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[] {});
assertArrayExpr("array_slice([1, 2, 3, 4], 5, 7)", null);
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 1)", null);
}
@Test
public void testArrayPrepend()
{
assertExpr("array_prepend(4, [1, 2, 3])", new Long[]{4L, 1L, 2L, 3L});
assertExpr("array_prepend('bar', [1, 2, 3])", new Long[]{null, 1L, 2L, 3L});
assertExpr("array_prepend(1, [])", new String[]{"1"});
assertExpr("array_prepend(1, <LONG>[])", new Long[]{1L});
assertExpr("array_prepend(1, <DOUBLE>[])", new Double[]{1.0});
assertArrayExpr("array_prepend(4, [1, 2, 3])", new Long[]{4L, 1L, 2L, 3L});
assertArrayExpr("array_prepend('bar', [1, 2, 3])", new Long[]{null, 1L, 2L, 3L});
assertArrayExpr("array_prepend(1, [])", new String[]{"1"});
assertArrayExpr("array_prepend(1, <LONG>[])", new Long[]{1L});
assertArrayExpr("array_prepend(1, <DOUBLE>[])", new Double[]{1.0});
}
private void assertExpr(final String expression, final Object expectedResult)
@Test
public void testGreatest()
{
// Same types
assertExpr("greatest(y, 0)", 2L);
assertExpr("greatest(34.0, z, 5.0, 767.0", 767.0);
assertExpr("greatest('B', x, 'A')", "foo");
// Different types
assertExpr("greatest(-1, z, 'A')", "A");
assertExpr("greatest(-1, z)", 3.1);
assertExpr("greatest(1, 'A')", "A");
// Invalid types
try {
assertExpr("greatest(1, ['A'])", null);
Assert.fail("Did not throw IllegalArgumentException");
}
catch (IllegalArgumentException e) {
Assert.assertEquals("Function[greatest] does not accept STRING_ARRAY types", e.getMessage());
}
// Null handling
assertExpr("greatest()", null);
assertExpr("greatest(null, null)", null);
assertExpr("greatest(1, null, 'A')", "A");
}
@Test
public void testLeast()
{
// Same types
assertExpr("least(y, 0)", 0L);
assertExpr("least(34.0, z, 5.0, 767.0", 3.1);
assertExpr("least('B', x, 'A')", "A");
// Different types
assertExpr("least(-1, z, 'A')", "-1");
assertExpr("least(-1, z)", -1.0);
assertExpr("least(1, 'A')", "1");
// Invalid types
try {
assertExpr("least(1, [2, 3])", null);
Assert.fail("Did not throw IllegalArgumentException");
}
catch (IllegalArgumentException e) {
Assert.assertEquals("Function[least] does not accept LONG_ARRAY types", e.getMessage());
}
// Null handling
assertExpr("least()", null);
assertExpr("least(null, null)", null);
assertExpr("least(1, null, 'A')", "1");
}
private void assertExpr(final String expression, @Nullable final Object expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
Assert.assertEquals(expression, expectedResult, expr.eval(bindings).value());
@ -307,7 +365,7 @@ public class FunctionTest extends InitializedNullHandlingTest
Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
}
private void assertExpr(final String expression, final Object[] expectedResult)
private void assertArrayExpr(final String expression, @Nullable final Object[] expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
Assert.assertArrayEquals(expression, expectedResult, expr.eval(bindings).asArray());

View File

@ -181,6 +181,22 @@ See javadoc of java.lang.Math for detailed explanation for each function.
| all(lambda,arr) | returns 1 if all elements in the array matches the lambda expression, else 0 |
### Reduction functions
Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as
arguments, then the result is `NULL`. The expressions must all be convertible to a common data type, which will be the
type of the result:
* If all arguments are `NULL`, the result is `NULL`. Otherwise, `NULL` arguments are ignored.
* If the arguments comprise a mix of numbers and strings, the arguments are interpreted as strings.
* If all arguments are integer numbers, the arguments are interpreted as longs.
* If all arguments are numbers and at least one argument is a double, the arguments are interpreted as doubles.
| function | description |
| --- | --- |
| greatest([expr1, ...]) | Evaluates zero or more expressions and returns the maximum value based on comparisons as described above. |
| least([expr1, ...]) | Evaluates zero or more expressions and returns the minimum value based on comparisons as described above. |
## IP address functions
For the IPv4 address functions, the `address` argument can either be an IPv4 dotted-decimal string (e.g., "192.168.0.1") or an IP address represented as a long (e.g., 3232235521). The `subnet` argument should be a string formatted as an IPv4 address subnet in CIDR notation (e.g., "192.168.0.0/16").

View File

@ -198,8 +198,6 @@ Only the COUNT aggregation can accept DISTINCT.
|`SUM(expr)`|Sums numbers.|
|`MIN(expr)`|Takes the minimum of numbers.|
|`MAX(expr)`|Takes the maximum of numbers.|
|`LEAST(expr1, [expr2, ...])`|Takes the minimum of numbers across one or more expression(s).|
|`GREATEST(expr1, [expr2, ...])`|Takes the maximum of numbers across one or more expression(s).|
|`AVG(expr)`|Averages numbers.|
|`APPROX_COUNT_DISTINCT(expr)`|Counts distinct values of expr, which can be a regular column or a hyperUnique column. This is always approximate, regardless of the value of "useApproximateCountDistinct". This uses Druid's built-in "cardinality" or "hyperUnique" aggregators. See also `COUNT(DISTINCT expr)`.|
|`APPROX_COUNT_DISTINCT_DS_HLL(expr, [lgK, tgtHllType])`|Counts distinct values of expr, which can be a regular column or an [HLL sketch](../development/extensions-core/datasketches-hll.html) column. The `lgK` and `tgtHllType` parameters are described in the HLL sketch documentation. This is always approximate, regardless of the value of "useApproximateCountDistinct". See also `COUNT(DISTINCT expr)`. The [DataSketches extension](../development/extensions-core/datasketches-extension.html) must be loaded to use this function.|
@ -334,6 +332,22 @@ simplest way to write literal timestamps in other time zones is to use TIME_PARS
|<code>timestamp_expr { + &#124; - } <interval_expr><code>|Add or subtract an amount of time from a timestamp. interval_expr can include interval literals like `INTERVAL '2' HOUR`, and may include interval arithmetic as well. This operator treats days as uniformly 86400 seconds long, and does not take into account daylight savings time. To account for daylight savings time, use TIME_SHIFT instead.|
### Reduction functions
Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as
arguments, then the result is `NULL`. The expressions must all be convertible to a common data type, which will be the
type of the result:
* If all argument are `NULL`, the result is `NULL`. Otherwise, `NULL` arguments are ignored.
* If the arguments comprise a mix of numbers and strings, the arguments are interpreted as strings.
* If all arguments are integer numbers, the arguments are interpreted as longs.
* If all arguments are numbers and at least one argument is a double, the arguments are interpreted as doubles.
|Function|Notes|
|--------|-----|
|`GREATEST([expr1, ...])`|Evaluates zero or more expressions and returns the maximum value based on comparisons as described above.|
|`LEAST([expr1, ...])`|Evaluates zero or more expressions and returns the minimum value based on comparisons as described above.|
### IP address functions
For the IPv4 address functions, the `address` argument can either be an IPv4 dotted-decimal string

View File

@ -1,136 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.DoubleGreatestPostAggregator;
import org.apache.druid.query.aggregation.post.LongGreatestPostAggregator;
import org.apache.druid.segment.column.ValueType;
import java.util.List;
/**
* Calcite integration class for Greatest post aggregators of Long & Double types.
* It applies Max aggregators over the provided fields/expressions & combines their results via Field access post aggregators.
*/
public class GreatestSqlAggregator extends MultiColumnSqlAggregator
{
private static final SqlAggFunction FUNCTION_INSTANCE = new GreatestSqlAggFunction();
private static final String NAME = "GREATEST";
@Override
public SqlAggFunction calciteFunction()
{
return FUNCTION_INSTANCE;
}
@Override
AggregatorFactory createAggregatorFactory(
ValueType valueType,
String prefixedName,
FieldInfo fieldInfo,
ExprMacroTable macroTable
)
{
final AggregatorFactory aggregatorFactory;
switch (valueType) {
case LONG:
aggregatorFactory = new LongMaxAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable);
break;
case FLOAT:
case DOUBLE:
aggregatorFactory = new DoubleMaxAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable);
break;
default:
throw new ISE("Cannot create aggregator factory for type[%s]", valueType);
}
return aggregatorFactory;
}
@Override
PostAggregator createFinalPostAggregator(
ValueType valueType,
String name,
List<PostAggregator> postAggregators
)
{
final PostAggregator finalPostAggregator;
switch (valueType) {
case LONG:
finalPostAggregator = new LongGreatestPostAggregator(name, postAggregators);
break;
case FLOAT:
case DOUBLE:
finalPostAggregator = new DoubleGreatestPostAggregator(name, postAggregators);
break;
default:
throw new ISE("Cannot create aggregator factory for type[%s]", valueType);
}
return finalPostAggregator;
}
/**
* Calcite SQL function definition
*/
private static class GreatestSqlAggFunction extends SqlAggFunction
{
GreatestSqlAggFunction()
{
/*
* The constructor params are explained as follows,
* name: SQL function name
* sqlIdentifier: null for built-in functions
* kind: SqlKind.GREATEST
* returnTypeInference: biggest operand type & nullable if any of the operands is nullable
* operandTypeInference: same as return type
* operandTypeChecker: variadic function with at least one argument
* funcType: System
* requiresOrder: No
* requiresOver: No
* requiresGroupOrder: Not allowed
*/
super(
NAME,
null,
SqlKind.GREATEST,
ReturnTypes.cascade(ReturnTypes.LEAST_RESTRICTIVE, SqlTypeTransforms.TO_NULLABLE),
InferTypes.RETURN_TYPE,
OperandTypes.ONE_OR_MORE,
SqlFunctionCategory.SYSTEM,
false,
false,
Optionality.FORBIDDEN
);
}
}
}

View File

@ -1,136 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.LongMinAggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.DoubleLeastPostAggregator;
import org.apache.druid.query.aggregation.post.LongLeastPostAggregator;
import org.apache.druid.segment.column.ValueType;
import java.util.List;
/**
* Calcite integration class for Least post aggregators of Long & Double types.
* It applies Min aggregators over the provided fields/expressions & combines their results via Field access post aggregators.
*/
public class LeastSqlAggregator extends MultiColumnSqlAggregator
{
private static final SqlAggFunction FUNCTION_INSTANCE = new LeastSqlAggFunction();
private static final String NAME = "LEAST";
@Override
public SqlAggFunction calciteFunction()
{
return FUNCTION_INSTANCE;
}
@Override
AggregatorFactory createAggregatorFactory(
ValueType valueType,
String prefixedName,
FieldInfo fieldInfo,
ExprMacroTable macroTable
)
{
final AggregatorFactory aggregatorFactory;
switch (valueType) {
case LONG:
aggregatorFactory = new LongMinAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable);
break;
case FLOAT:
case DOUBLE:
aggregatorFactory = new DoubleMinAggregatorFactory(prefixedName, fieldInfo.fieldName, fieldInfo.expression, macroTable);
break;
default:
throw new ISE("Cannot create aggregator factory for type[%s]", valueType);
}
return aggregatorFactory;
}
@Override
PostAggregator createFinalPostAggregator(
ValueType valueType,
String name,
List<PostAggregator> postAggregators
)
{
final PostAggregator finalPostAggregator;
switch (valueType) {
case LONG:
finalPostAggregator = new LongLeastPostAggregator(name, postAggregators);
break;
case FLOAT:
case DOUBLE:
finalPostAggregator = new DoubleLeastPostAggregator(name, postAggregators);
break;
default:
throw new ISE("Cannot create aggregator factory for type[%s]", valueType);
}
return finalPostAggregator;
}
/**
* Calcite SQL function definition
*/
private static class LeastSqlAggFunction extends SqlAggFunction
{
LeastSqlAggFunction()
{
/*
* The constructor params are explained as follows,
* name: SQL function name
* sqlIdentifier: null for built-in functions
* kind: SqlKind.LEAST
* returnTypeInference: biggest operand type & nullable if any of the operands is nullable
* operandTypeInference: same as return type
* operandTypeChecker: variadic function with at least one argument
* funcType: System
* requiresOrder: No
* requiresOver: No
* requiresGroupOrder: Not allowed
*/
super(
NAME,
null,
SqlKind.LEAST,
ReturnTypes.cascade(ReturnTypes.LEAST_RESTRICTIVE, SqlTypeTransforms.TO_NULLABLE),
InferTypes.RETURN_TYPE,
OperandTypes.ONE_OR_MORE,
SqlFunctionCategory.SYSTEM,
false,
false,
Optionality.FORBIDDEN
);
}
}
}

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Function;
import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
public class GreatestOperatorConversion extends DirectOperatorConversion
{
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(Function.GreatestFunc.NAME))
.operandTypeChecker(OperandTypes.VARIADIC)
.returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE)
.build();
public GreatestOperatorConversion()
{
super(SQL_FUNCTION, Function.GreatestFunc.NAME);
}
@Override
public SqlOperator calciteOperator()
{
return SQL_FUNCTION;
}
}

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Function;
import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
public class LeastOperatorConversion extends DirectOperatorConversion
{
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(Function.LeastFunc.NAME))
.operandTypeChecker(OperandTypes.VARIADIC)
.returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE)
.build();
public LeastOperatorConversion()
{
super(SQL_FUNCTION, Function.LeastFunc.NAME);
}
@Override
public SqlOperator calciteOperator()
{
return SQL_FUNCTION;
}
}

View File

@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.planner.Calcites;
class ReductionOperatorConversionHelper
{
private ReductionOperatorConversionHelper()
{
}
/**
* Implements type precedence rules similar to:
* https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least
*
* @see org.apache.druid.math.expr.Function.ReduceFunc#apply
* @see org.apache.druid.math.expr.Function.ReduceFunc#getComparisionType
*/
static final SqlReturnTypeInference TYPE_INFERENCE =
opBinding -> {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
final int n = opBinding.getOperandCount();
if (n == 0) {
return typeFactory.createSqlType(SqlTypeName.NULL);
}
SqlTypeName returnSqlTypeName = SqlTypeName.NULL;
boolean hasDouble = false;
for (int i = 0; i < n; i++) {
RelDataType type = opBinding.getOperandType(i);
SqlTypeName sqlTypeName = type.getSqlTypeName();
ValueType valueType = Calcites.getValueTypeForSqlTypeName(sqlTypeName);
// Return types are listed in order of preference:
if (valueType == ValueType.STRING) {
returnSqlTypeName = sqlTypeName;
break;
} else if (valueType == ValueType.DOUBLE || valueType == ValueType.FLOAT) {
returnSqlTypeName = SqlTypeName.DOUBLE;
hasDouble = true;
} else if (valueType == ValueType.LONG && !hasDouble) {
returnSqlTypeName = SqlTypeName.BIGINT;
} else if (sqlTypeName != SqlTypeName.NULL) {
throw new IAE("Argument %d has invalid type: %s", i, sqlTypeName);
}
}
return typeFactory.createSqlType(returnSqlTypeName);
};
}

View File

@ -37,8 +37,6 @@ import org.apache.druid.sql.calcite.aggregation.builtin.ApproxCountDistinctSqlAg
import org.apache.druid.sql.calcite.aggregation.builtin.AvgSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.CountSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.GreatestSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.LeastSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.MaxSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.MinSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.SumSqlAggregator;
@ -66,11 +64,13 @@ import org.apache.druid.sql.calcite.expression.builtin.ConcatOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.DateTruncOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ExtractOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.FloorOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.GreatestOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.IPv4AddressMatchOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.IPv4AddressParseOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.IPv4AddressStringifyOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.LPadOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.LTrimOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.LeastOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.LeftOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.LikeOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.MillisToTimestampOperatorConversion;
@ -126,8 +126,6 @@ public class DruidOperatorTable implements SqlOperatorTable
.add(EarliestLatestAnySqlAggregator.ANY_VALUE)
.add(new MinSqlAggregator())
.add(new MaxSqlAggregator())
.add(new GreatestSqlAggregator())
.add(new LeastSqlAggregator())
.add(new SumSqlAggregator())
.add(new SumZeroSqlAggregator())
.build();
@ -219,6 +217,12 @@ public class DruidOperatorTable implements SqlOperatorTable
.add(new StringToMultiValueStringOperatorConversion())
.build();
private static final List<SqlOperatorConversion> REDUCTION_OPERATOR_CONVERSIONS =
ImmutableList.<SqlOperatorConversion>builder()
.add(new GreatestOperatorConversion())
.add(new LeastOperatorConversion())
.build();
private static final List<SqlOperatorConversion> IPV4ADDRESS_OPERATOR_CONVERSIONS =
ImmutableList.<SqlOperatorConversion>builder()
.add(new IPv4AddressMatchOperatorConversion())
@ -282,6 +286,7 @@ public class DruidOperatorTable implements SqlOperatorTable
.addAll(VALUE_COERCION_OPERATOR_CONVERSIONS)
.addAll(ARRAY_OPERATOR_CONVERSIONS)
.addAll(MULTIVALUE_STRING_OPERATOR_CONVERSIONS)
.addAll(REDUCTION_OPERATOR_CONVERSIONS)
.addAll(IPV4ADDRESS_OPERATOR_CONVERSIONS)
.build();

View File

@ -26,7 +26,6 @@ import org.apache.calcite.tools.ValidationException;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.JodaUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
@ -39,7 +38,6 @@ import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
@ -63,12 +61,8 @@ import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory;
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.DoubleGreatestPostAggregator;
import org.apache.druid.query.aggregation.post.DoubleLeastPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.query.aggregation.post.LongGreatestPostAggregator;
import org.apache.druid.query.aggregation.post.LongLeastPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.ExtractionDimensionSpec;
import org.apache.druid.query.extraction.RegexDimExtractionFn;
@ -6269,444 +6263,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testGreatestLongAndDoubleWithGroupBy() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
testQuery(
"SELECT * FROM ("
+ " SELECT greatest(cntl1, cntl2), greatest(cntd1, cntd2) FROM (\n"
+ " SELECT TIME_FLOOR(__time, 'P1D') AS t,\n"
+ " count(1) AS cntl1, 10 AS cntl2,\n"
+ " (1.2 + count(1)) AS cntd1, 10.2 AS cntd2\n"
+ " FROM \"foo\"\n"
+ " GROUP BY 1\n"
+ " )"
+ ")\n",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(
expressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG
)
)
.setDimensions(dimensions(new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
)))
.setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new LongMaxAggregatorFactory("_a0:0", "a0"),
new LongMaxAggregatorFactory("_a0:1", null, "10", ExprMacroTable.nil()),
new DoubleMaxAggregatorFactory("_a1:0", null, "(1.2 + \"a0\")", ExprMacroTable.nil()),
new DoubleMaxAggregatorFactory("_a1:1", null, "10.2", ExprMacroTable.nil())
))
.setPostAggregatorSpecs(
ImmutableList.of(
new LongGreatestPostAggregator(
"_a0",
ImmutableList.of(
new FieldAccessPostAggregator(null, "_a0:0"),
new FieldAccessPostAggregator(null, "_a0:1")
)
),
new DoubleGreatestPostAggregator(
"_a1",
ImmutableList.of(
new FieldAccessPostAggregator(null, "_a1:0"),
new FieldAccessPostAggregator(null, "_a1:1")
)
)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{10L, 10.2D})
);
}
@Test
public void testLeastLongAndDoubleWithGroupBy() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
testQuery(
"SELECT * FROM ("
+ " SELECT least(cntl1, cntl2), least(cntd1, cntd2) FROM (\n"
+ " SELECT TIME_FLOOR(__time, 'P1D') AS t,\n"
+ " count(1) AS cntl1, 10 AS cntl2,\n"
+ " (1.2 + count(1)) AS cntd1, 10.2 AS cntd2\n"
+ " FROM \"foo\"\n"
+ " GROUP BY 1\n"
+ " )"
+ ")\n",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(
expressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG
)
)
.setDimensions(dimensions(new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
)))
.setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new LongMinAggregatorFactory("_a0:0", "a0"),
new LongMinAggregatorFactory("_a0:1", null, "10", ExprMacroTable.nil()),
new DoubleMinAggregatorFactory("_a1:0", null, "(1.2 + \"a0\")", ExprMacroTable.nil()),
new DoubleMinAggregatorFactory("_a1:1", null, "10.2", ExprMacroTable.nil())
))
.setPostAggregatorSpecs(
ImmutableList.of(
new LongLeastPostAggregator(
"_a0",
ImmutableList.of(
new FieldAccessPostAggregator(null, "_a0:0"),
new FieldAccessPostAggregator(null, "_a0:1")
)
),
new DoubleLeastPostAggregator(
"_a1",
ImmutableList.of(
new FieldAccessPostAggregator(null, "_a1:0"),
new FieldAccessPostAggregator(null, "_a1:1")
)
)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{1L, 2.2D})
);
}
@Test
public void testGreatestSingleColumnPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
testQuery(
"SELECT\n"
+ " greatest(cnt), greatest(m1), greatest(m2)\n"
+ " FROM \"foo\"\n",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(
new LongMaxAggregatorFactory("a0:0", "cnt"),
new DoubleMaxAggregatorFactory("a1:0", "m1"),
new DoubleMaxAggregatorFactory("a2:0", "m2")
))
.postAggregators(ImmutableList.of(
new LongGreatestPostAggregator(
"a0",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a0:0")
)
),
new DoubleGreatestPostAggregator(
"a1",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a1:0")
)
),
new DoubleGreatestPostAggregator(
"a2",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a2:0")
)
)
)
)
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{1L, 6.0F, 6.0D})
);
}
@Test
public void testLeastSingleColumnPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
testQuery(
"SELECT\n"
+ " least(cnt), least(m1), least(m2)\n"
+ " FROM \"foo\"\n",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(
new LongMinAggregatorFactory("a0:0", "cnt"),
new DoubleMinAggregatorFactory("a1:0", "m1"),
new DoubleMinAggregatorFactory("a2:0", "m2")
))
.postAggregators(ImmutableList.of(
new LongLeastPostAggregator(
"a0",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a0:0")
)
),
new DoubleLeastPostAggregator(
"a1",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a1:0")
)
),
new DoubleLeastPostAggregator(
"a2",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a2:0")
)
)
)
)
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{1L, 1.0F, 1.0D})
);
}
@Test
public void testGreatestCombinationPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
testQuery(
"SELECT\n"
+ " greatest(cnt, 10, 10 * 2 + 3),\n"
+ " greatest(m1, 10.0, 10.2 * 2.0 + 3.0),\n"
+ " greatest(m2, 10.0, 10.2 * 2.0 + 3.0)\n"
+ " FROM \"foo\"\n",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(
new LongMaxAggregatorFactory("a0:0", "cnt"),
new LongMaxAggregatorFactory("a0:1", null, "10", ExprMacroTable.nil()),
new LongMaxAggregatorFactory("a0:2", null, "23", ExprMacroTable.nil()),
new DoubleMaxAggregatorFactory("a1:0", "m1"),
new DoubleMaxAggregatorFactory("a1:1", null, "10.0", ExprMacroTable.nil()),
new DoubleMaxAggregatorFactory("a1:2", null, "23.4", ExprMacroTable.nil()),
new DoubleMaxAggregatorFactory("a2:0", "m2"),
new DoubleMaxAggregatorFactory("a2:1", null, "10.0", ExprMacroTable.nil()),
new DoubleMaxAggregatorFactory("a2:2", null, "23.4", ExprMacroTable.nil())
))
.postAggregators(ImmutableList.of(
new LongGreatestPostAggregator(
"a0",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a0:0"),
new FieldAccessPostAggregator(null, "a0:1"),
new FieldAccessPostAggregator(null, "a0:2")
)
),
new DoubleGreatestPostAggregator(
"a1",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a1:0"),
new FieldAccessPostAggregator(null, "a1:1"),
new FieldAccessPostAggregator(null, "a1:2")
)
),
new DoubleGreatestPostAggregator(
"a2",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a2:0"),
new FieldAccessPostAggregator(null, "a2:1"),
new FieldAccessPostAggregator(null, "a2:2")
)
))
)
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{23L, 23.4D, 23.4D})
);
}
@Test
public void testLeastCombinationPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
testQuery(
"SELECT\n"
+ " least(cnt, 10, 10 * 2 + 3),\n"
+ " least(m1, 10.0, 10.2 * 2.0 + 3.0),\n"
+ " least(m2, 10.0, 10.2 * 2.0 + 3.0)\n"
+ " FROM \"foo\"\n",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(
new LongMinAggregatorFactory("a0:0", "cnt"),
new LongMinAggregatorFactory("a0:1", null, "10", ExprMacroTable.nil()),
new LongMinAggregatorFactory("a0:2", null, "23", ExprMacroTable.nil()),
new DoubleMinAggregatorFactory("a1:0", "m1"),
new DoubleMinAggregatorFactory("a1:1", null, "10.0", ExprMacroTable.nil()),
new DoubleMinAggregatorFactory("a1:2", null, "23.4", ExprMacroTable.nil()),
new DoubleMinAggregatorFactory("a2:0", "m2"),
new DoubleMinAggregatorFactory("a2:1", null, "10.0", ExprMacroTable.nil()),
new DoubleMinAggregatorFactory("a2:2", null, "23.4", ExprMacroTable.nil())
))
.postAggregators(ImmutableList.of(
new LongLeastPostAggregator(
"a0",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a0:0"),
new FieldAccessPostAggregator(null, "a0:1"),
new FieldAccessPostAggregator(null, "a0:2")
)
),
new DoubleLeastPostAggregator(
"a1",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a1:0"),
new FieldAccessPostAggregator(null, "a1:1"),
new FieldAccessPostAggregator(null, "a1:2")
)
),
new DoubleLeastPostAggregator(
"a2",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a2:0"),
new FieldAccessPostAggregator(null, "a2:1"),
new FieldAccessPostAggregator(null, "a2:2")
)
))
)
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{1L, 1.0D, 1.0D})
);
}
@Test
public void testGreatestInvalidPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
expectedException.expect(RuntimeException.class);
expectedException.expectCause(CoreMatchers.instanceOf(ISE.class));
expectedException.expectCause(
ThrowableMessageMatcher.hasMessage(
CoreMatchers.containsString(
"Cannot create aggregator factory for type[STRING]"
)
)
);
testQuery("SELECT GREATEST(dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of());
}
@Test
public void testLeastInvalidPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
expectedException.expect(RuntimeException.class);
expectedException.expectCause(CoreMatchers.instanceOf(ISE.class));
expectedException.expectCause(
ThrowableMessageMatcher.hasMessage(
CoreMatchers.containsString(
"Cannot create aggregator factory for type[STRING]"
)
)
);
testQuery("SELECT LEAST(dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of());
}
@Test
public void testGreatestInvalidCombinationPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
expectedException.expect(ValidationException.class);
expectedException.expectCause(CoreMatchers.instanceOf(IllegalArgumentException.class));
expectedException.expectCause(
ThrowableMessageMatcher.hasMessage(
CoreMatchers.containsString(
"Cannot infer return type for GREATEST; operand types: [INTEGER, VARCHAR]"
)
)
);
testQuery("SELECT GREATEST(10, dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of());
}
@Test
public void testLeastInvalidCombinationPostAggregations() throws Exception
{
// Cannot vectorize due to virtual columns.
cannotVectorize();
expectedException.expect(ValidationException.class);
expectedException.expectCause(CoreMatchers.instanceOf(IllegalArgumentException.class));
expectedException.expectCause(
ThrowableMessageMatcher.hasMessage(
CoreMatchers.containsString(
"Cannot infer return type for LEAST; operand types: [INTEGER, VARCHAR]"
)
)
);
testQuery("SELECT LEAST(10, dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of());
}
@Test
public void testAvgDailyCountDistinct() throws Exception
{

View File

@ -133,6 +133,11 @@ class ExpressionTestHelper
return rexBuilder.makeIntervalLiteral(v, intervalQualifier);
}
RexNode makeLiteral(Double d)
{
return rexBuilder.makeLiteral(d, createSqlType(SqlTypeName.DOUBLE), true);
}
RexNode makeCall(SqlOperator op, RexNode... exprs)
{
return rexBuilder.makeCall(op, exprs);

View File

@ -0,0 +1,261 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.expression.builtin.GreatestOperatorConversion;
import org.junit.Before;
import org.junit.Test;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class GreatestExpressionTest extends ExpressionTestBase
{
private static final String DOUBLE_KEY = "d";
private static final double DOUBLE_VALUE = 3.1;
private static final String LONG_KEY = "l";
private static final long LONG_VALUE = 2L;
private static final String STRING_KEY = "s";
private static final String STRING_VALUE = "foo";
private static final RowSignature ROW_SIGNATURE = RowSignature
.builder()
.add(DOUBLE_KEY, ValueType.DOUBLE)
.add(LONG_KEY, ValueType.LONG)
.add(STRING_KEY, ValueType.STRING)
.build();
private static final Map<String, Object> BINDINGS = ImmutableMap.of(
DOUBLE_KEY, DOUBLE_VALUE,
LONG_KEY, LONG_VALUE,
STRING_KEY, STRING_VALUE
);
private GreatestOperatorConversion target;
private ExpressionTestHelper testHelper;
@Before
public void setUp()
{
target = new GreatestOperatorConversion();
testHelper = new ExpressionTestHelper(ROW_SIGNATURE, BINDINGS);
}
@Test
public void testNoArgs()
{
testExpression(
Collections.emptyList(),
buildExpectedExpression(),
null
);
}
@Test
public void testAllNull()
{
testExpression(
Arrays.asList(
testHelper.getConstantNull(),
testHelper.getConstantNull()
),
buildExpectedExpression(null, null),
null
);
}
@Test
public void testSomeNull()
{
testExpression(
Arrays.asList(
testHelper.makeInputRef(DOUBLE_KEY),
testHelper.getConstantNull(),
testHelper.makeInputRef(STRING_KEY)
),
buildExpectedExpression(
testHelper.makeVariable(DOUBLE_KEY),
null,
testHelper.makeVariable(STRING_KEY)
),
STRING_VALUE
);
}
@Test
public void testAllDouble()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(34.1),
testHelper.makeInputRef(DOUBLE_KEY),
testHelper.makeLiteral(5.2),
testHelper.makeLiteral(767.3)
),
buildExpectedExpression(
34.1,
testHelper.makeVariable(DOUBLE_KEY),
5.2,
767.3
),
767.3
);
}
@Test
public void testAllLong()
{
testExpression(
Arrays.asList(
testHelper.makeInputRef(LONG_KEY),
testHelper.makeLiteral(0)
),
buildExpectedExpression(
testHelper.makeVariable(LONG_KEY),
0
),
LONG_VALUE
);
}
@Test
public void testAllString()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral("B"),
testHelper.makeInputRef(STRING_KEY),
testHelper.makeLiteral("A")
),
buildExpectedExpression(
"B",
testHelper.makeVariable(STRING_KEY),
"A"
),
STRING_VALUE
);
}
@Test
public void testCoerceString()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(-1),
testHelper.makeInputRef(DOUBLE_KEY),
testHelper.makeLiteral("A")
),
buildExpectedExpression(
-1,
testHelper.makeVariable(DOUBLE_KEY),
"A"
),
"A"
);
}
@Test
public void testCoerceDouble()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(-1),
testHelper.makeInputRef(DOUBLE_KEY)
),
buildExpectedExpression(
-1,
testHelper.makeVariable(DOUBLE_KEY)
),
DOUBLE_VALUE
);
}
@Test
public void testDecimal()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(BigDecimal.valueOf(1.2)),
testHelper.makeLiteral(BigDecimal.valueOf(3.4))
),
buildExpectedExpression(
1.2,
3.4
),
3.4
);
}
@Test
public void testTimestamp()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(DateTimes.utc(1000)),
testHelper.makeLiteral(DateTimes.utc(2000))
),
buildExpectedExpression(
1000,
2000
),
2000L
);
}
@Test
public void testInvalidType()
{
expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH");
testExpression(
Collections.singletonList(
testHelper.makeLiteral(
new BigDecimal(13), // YEAR-MONTH literals value is months
new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
)
),
null,
null
);
}
private void testExpression(
List<? extends RexNode> exprs,
final DruidExpression expectedExpression,
final Object expectedResult
)
{
testHelper.testExpression(target.calciteOperator(), exprs, expectedExpression, expectedResult);
}
private DruidExpression buildExpectedExpression(Object... args)
{
return testHelper.buildExpectedExpression(target.getDruidFunctionName(), args);
}
}

View File

@ -0,0 +1,261 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.expression.builtin.LeastOperatorConversion;
import org.junit.Before;
import org.junit.Test;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class LeastExpressionTest extends ExpressionTestBase
{
private static final String DOUBLE_KEY = "d";
private static final double DOUBLE_VALUE = 3.1;
private static final String LONG_KEY = "l";
private static final long LONG_VALUE = 2L;
private static final String STRING_KEY = "s";
private static final String STRING_VALUE = "foo";
private static final RowSignature ROW_SIGNATURE = RowSignature
.builder()
.add(DOUBLE_KEY, ValueType.DOUBLE)
.add(LONG_KEY, ValueType.LONG)
.add(STRING_KEY, ValueType.STRING)
.build();
private static final Map<String, Object> BINDINGS = ImmutableMap.of(
DOUBLE_KEY, DOUBLE_VALUE,
LONG_KEY, LONG_VALUE,
STRING_KEY, STRING_VALUE
);
private LeastOperatorConversion target;
private ExpressionTestHelper testHelper;
@Before
public void setUp()
{
target = new LeastOperatorConversion();
testHelper = new ExpressionTestHelper(ROW_SIGNATURE, BINDINGS);
}
@Test
public void testNoArgs()
{
testExpression(
Collections.emptyList(),
buildExpectedExpression(),
null
);
}
@Test
public void testAllNull()
{
testExpression(
Arrays.asList(
testHelper.getConstantNull(),
testHelper.getConstantNull()
),
buildExpectedExpression(null, null),
null
);
}
@Test
public void testSomeNull()
{
testExpression(
Arrays.asList(
testHelper.makeInputRef(DOUBLE_KEY),
testHelper.getConstantNull(),
testHelper.makeInputRef(STRING_KEY)
),
buildExpectedExpression(
testHelper.makeVariable(DOUBLE_KEY),
null,
testHelper.makeVariable(STRING_KEY)
),
String.valueOf(DOUBLE_VALUE)
);
}
@Test
public void testAllDouble()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(34.1),
testHelper.makeInputRef(DOUBLE_KEY),
testHelper.makeLiteral(5.2),
testHelper.makeLiteral(767.3)
),
buildExpectedExpression(
34.1,
testHelper.makeVariable(DOUBLE_KEY),
5.2,
767.3
),
3.1
);
}
@Test
public void testAllLong()
{
testExpression(
Arrays.asList(
testHelper.makeInputRef(LONG_KEY),
testHelper.makeLiteral(0)
),
buildExpectedExpression(
testHelper.makeVariable(LONG_KEY),
0
),
0L
);
}
@Test
public void testAllString()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral("B"),
testHelper.makeInputRef(STRING_KEY),
testHelper.makeLiteral("A")
),
buildExpectedExpression(
"B",
testHelper.makeVariable(STRING_KEY),
"A"
),
"A"
);
}
@Test
public void testCoerceString()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(-1),
testHelper.makeInputRef(DOUBLE_KEY),
testHelper.makeLiteral("A")
),
buildExpectedExpression(
-1,
testHelper.makeVariable(DOUBLE_KEY),
"A"
),
"-1"
);
}
@Test
public void testCoerceDouble()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(-1),
testHelper.makeInputRef(DOUBLE_KEY)
),
buildExpectedExpression(
-1,
testHelper.makeVariable(DOUBLE_KEY)
),
-1.0
);
}
@Test
public void testDecimal()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(BigDecimal.valueOf(1.2)),
testHelper.makeLiteral(BigDecimal.valueOf(3.4))
),
buildExpectedExpression(
1.2,
3.4
),
1.2
);
}
@Test
public void testTimestamp()
{
testExpression(
Arrays.asList(
testHelper.makeLiteral(DateTimes.utc(1000)),
testHelper.makeLiteral(DateTimes.utc(2000))
),
buildExpectedExpression(
1000,
2000
),
1000L
);
}
@Test
public void testInvalidType()
{
expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH");
testExpression(
Collections.singletonList(
testHelper.makeLiteral(
new BigDecimal(13), // YEAR-MONTH literals value is months
new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
)
),
null,
null
);
}
private void testExpression(
List<? extends RexNode> exprs,
final DruidExpression expectedExpression,
final Object expectedResult
)
{
testHelper.testExpression(target.calciteOperator(), exprs, expectedExpression, expectedResult);
}
private DruidExpression buildExpectedExpression(Object... args)
{
return testHelper.buildExpectedExpression(target.getDruidFunctionName(), args);
}
}