fix issue with SQL planner and null array constants (#12971)

This commit is contained in:
Clint Wylie 2022-08-26 04:44:17 -07:00 committed by GitHub
parent acb09ff18b
commit 4bdf9815c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 162 additions and 48 deletions

View File

@ -347,12 +347,13 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testStringToArray()
{
  // A delimited literal is split into one element per token.
  final String[] threeTokens = {"1", "2", "3"};
  assertArrayExpr("string_to_array('1,2,3', ',')", threeTokens);

  // A null input string is expected to evaluate to a null array.
  assertArrayExpr("string_to_array(null, ',')", null);

  // Input that does not contain the delimiter yields a single-element array.
  final String[] singleToken = {"1"};
  assertArrayExpr("string_to_array('1', ',')", singleToken);

  // Round trip: join the binding "a" (defined outside this method) with ',' and split it again.
  final String[] roundTripped = {"foo", "bar", "baz", "foobar"};
  assertArrayExpr("string_to_array(array_to_string(a, ','), ',')", roundTripped);
}
@Test
public void testArrayCast()
public void testArrayCastLegacy()
{
assertArrayExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"});
assertArrayExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0});
@ -364,6 +365,19 @@ public class FunctionTest extends InitializedNullHandlingTest
assertArrayExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L});
}
@Test
public void testArrayCast()
{
  // Casts written with the ARRAY<...> type syntax rather than the *_ARRAY names.

  // Numeric literal array to string and double arrays.
  final String[] expectedStrings = {"1", "2", "3"};
  assertArrayExpr("cast([1, 2, 3], 'ARRAY<STRING>')", expectedStrings);

  final Double[] expectedDoubles = {1.0, 2.0, 3.0};
  assertArrayExpr("cast([1, 2, 3], 'ARRAY<DOUBLE>')", expectedDoubles);

  // Binding "c" (defined outside this method) cast to a long array.
  final Long[] expectedFromBinding = {3L, 4L, 5L};
  assertArrayExpr("cast(c, 'ARRAY<LONG>')", expectedFromBinding);

  // Binding "b" joined to a string, split back into an array, then cast to a long array.
  final Long[] expectedFromRoundTrip = {1L, 2L, 3L, 4L, 5L};
  assertArrayExpr(
      "cast(string_to_array(array_to_string(b, ','), ','), 'ARRAY<LONG>')",
      expectedFromRoundTrip
  );

  // Decimal-formatted strings cast to a long array.
  final Long[] expectedFromDecimalStrings = {1L, 2L, 3L};
  assertArrayExpr("cast(['1.0', '2.0', '3.0'], 'ARRAY<LONG>')", expectedFromDecimalStrings);
}
@Test
public void testArraySlice()
{

View File

@ -318,8 +318,8 @@ public class OperatorConversions
* operator should never, ever, return null.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName)
{
@ -335,8 +335,8 @@ public class OperatorConversions
* Sets the return type of the operator to "typeName", marked as nullable.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNullable(final SqlTypeName typeName)
{
@ -352,8 +352,8 @@ public class OperatorConversions
* Sets the return type of the operator to "typeName", marked as nullable if any of its operands are nullable.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeCascadeNullable(final SqlTypeName typeName)
{
@ -366,10 +366,10 @@ public class OperatorConversions
* Sets the return type of the operator to an array type with elements of "typeName", marked as nullable.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
* {@link #returnTypeArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be
* used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNullableArray(final SqlTypeName elementTypeName)
public OperatorBuilder returnTypeArrayWithNullableElements(final SqlTypeName elementTypeName)
{
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -379,13 +379,33 @@ public class OperatorConversions
return this;
}
/**
 * Sets the return type of the operator to an array type with elements of "elementTypeName", where the array type
 * itself is marked as nullable.
 *
 * The array type is created through {@code Calcites.createSqlArrayTypeWithNullability} with its nullability flag set
 * to true (presumably the element nullability; confirm against {@code Calcites}), and the result is then passed
 * through {@link SqlTypeTransforms#FORCE_NULLABLE}.
 *
 * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)},
 * {@link #returnTypeArrayWithNullableElements}, {@link #returnTypeNullableArrayWithNullableElements}, or
 * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
 * cannot be mixed; you must call exactly one.
 *
 * @param elementTypeName SQL type of the array elements
 * @return this builder
 * @throws IllegalStateException if a return type has already been set on this builder
 */
public OperatorBuilder returnTypeNullableArrayWithNullableElements(final SqlTypeName elementTypeName)
{
  // Same guard as returnTypeArrayWithNullableElements: refuse to silently overwrite an earlier return type choice.
  Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");

  this.returnTypeInference = ReturnTypes.cascade(
      opBinding -> Calcites.createSqlArrayTypeWithNullability(
          opBinding.getTypeFactory(),
          elementTypeName,
          true
      ),
      // Make the resulting array type nullable regardless of operand nullability.
      SqlTypeTransforms.FORCE_NULLABLE
  );
  return this;
}
/**
* Provides customized return type inference logic.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeInference(final SqlReturnTypeInference returnTypeInference)
{

View File

@ -38,7 +38,7 @@ public class MultiValueStringToArrayOperatorConversion extends DirectOperatorCon
.operatorBuilder("MV_TO_ARRAY")
.operandTypeChecker(OperandTypes.family(SqlTypeFamily.STRING))
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNullableArray(SqlTypeName.VARCHAR)
.returnTypeNullableArrayWithNullableElements(SqlTypeName.VARCHAR)
.build();
public MultiValueStringToArrayOperatorConversion()

View File

@ -165,7 +165,7 @@ public class NestedDataOperatorConversions
.operatorBuilder("JSON_PATHS")
.operandTypeChecker(OperandTypes.ANY)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.returnTypeNullableArray(SqlTypeName.VARCHAR)
.returnTypeArrayWithNullableElements(SqlTypeName.VARCHAR)
.build();
@Override
@ -207,7 +207,7 @@ public class NestedDataOperatorConversions
)
)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.returnTypeNullableArray(SqlTypeName.VARCHAR)
.returnTypeNullableArrayWithNullableElements(SqlTypeName.VARCHAR)
.build();
@Override

View File

@ -40,7 +40,7 @@ public class StringToArrayOperatorConversion extends DirectOperatorConversion
)
)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNullableArray(SqlTypeName.VARCHAR)
.returnTypeNullableArrayWithNullableElements(SqlTypeName.VARCHAR)
.build();
public StringToArrayOperatorConversion()

View File

@ -34,9 +34,9 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
* A Calcite {@code RexExecutor} that reduces Calcite expressions by evaluating them using Druid's own built-in
@ -130,7 +130,7 @@ public class DruidRexExecutor implements RexExecutor
double exprResultDouble = exprResult.asDouble();
if (Double.isNaN(exprResultDouble) || Double.isInfinite(exprResultDouble)) {
String expression = druidExpression.getExpression();
throw new UnsupportedSQLQueryException("'%s' evaluates to '%s' that is not supported in SQL. You can either cast the expression as bigint ('cast(%s as bigint)') or char ('cast(%s as char)') or change the expression itself",
throw new UnsupportedSQLQueryException("'%s' evaluates to '%s' that is not supported in SQL. You can either cast the expression as BIGINT ('CAST(%s as BIGINT)') or VARCHAR ('CAST(%s as VARCHAR)') or change the expression itself",
expression,
Double.toString(exprResultDouble),
expression,
@ -142,40 +142,42 @@ public class DruidRexExecutor implements RexExecutor
}
} else if (sqlTypeName == SqlTypeName.ARRAY) {
assert exprResult.isArray();
if (SqlTypeName.NUMERIC_TYPES.contains(constExp.getType().getComponentType().getSqlTypeName())) {
final Object[] array = exprResult.asArray();
if (array == null) {
literal = rexBuilder.makeNullLiteral(constExp.getType());
} else if (SqlTypeName.NUMERIC_TYPES.contains(constExp.getType().getComponentType().getSqlTypeName())) {
if (exprResult.type().getElementType().is(ExprType.LONG)) {
List<BigDecimal> resultAsBigDecimalList = Arrays.stream(exprResult.asArray())
.map(val -> {
final Number longVal = (Number) val;
if (longVal == null) {
return null;
}
return BigDecimal.valueOf(longVal.longValue());
})
.collect(Collectors.toList());
List<BigDecimal> resultAsBigDecimalList = new ArrayList<>(array.length);
for (Object val : array) {
final Number longVal = (Number) val;
if (longVal == null) {
resultAsBigDecimalList.add(null);
} else {
resultAsBigDecimalList.add(BigDecimal.valueOf(longVal.longValue()));
}
}
literal = rexBuilder.makeLiteral(resultAsBigDecimalList, constExp.getType(), true);
} else {
List<BigDecimal> resultAsBigDecimalList = Arrays.stream(exprResult.asArray()).map(
val -> {
final Number doubleVal = (Number) val;
if (doubleVal == null) {
return null;
}
if (Double.isNaN(doubleVal.doubleValue()) || Double.isInfinite(doubleVal.doubleValue())) {
String expression = druidExpression.getExpression();
throw new UnsupportedSQLQueryException(
"'%s' contains an element that evaluates to '%s' which is not supported in SQL. You can either cast the element in the array to bigint or char or change the expression itself",
expression,
Double.toString(doubleVal.doubleValue())
);
}
return BigDecimal.valueOf(doubleVal.doubleValue());
}
).collect(Collectors.toList());
List<BigDecimal> resultAsBigDecimalList = new ArrayList<>(array.length);
for (Object val : array) {
final Number doubleVal = (Number) val;
if (doubleVal == null) {
resultAsBigDecimalList.add(null);
} else if (Double.isNaN(doubleVal.doubleValue()) || Double.isInfinite(doubleVal.doubleValue())) {
String expression = druidExpression.getExpression();
throw new UnsupportedSQLQueryException(
"'%s' contains an element that evaluates to '%s' which is not supported in SQL. You can either cast the element in the ARRAY to BIGINT or VARCHAR or change the expression itself",
expression,
Double.toString(doubleVal.doubleValue())
);
} else {
resultAsBigDecimalList.add(BigDecimal.valueOf(doubleVal.doubleValue()));
}
}
literal = rexBuilder.makeLiteral(resultAsBigDecimalList, constExp.getType(), true);
}
} else {
literal = rexBuilder.makeLiteral(Arrays.asList(exprResult.asArray()), constExp.getType(), true);
literal = rexBuilder.makeLiteral(Arrays.asList(array), constExp.getType(), true);
}
} else if (sqlTypeName == SqlTypeName.OTHER) {
// complex constant is not reducible, so just leave it as an expression

View File

@ -2506,4 +2506,82 @@ public class CalciteNestedDataQueryTest extends BaseCalciteQueryTest
);
}
/**
 * JSON_PATHS applied to a string column and to constant non-JSON arguments (number, string, decimal, null).
 * The expected native query has only two virtual columns: "v0" for the column input and "v1", the constant
 * expression {@code array('$')}.
 */
@Test
public void testJsonPathsNonJsonInput()
{
  final String sql =
      "SELECT JSON_PATHS(string), JSON_PATHS(1234), JSON_PATHS('1234'), JSON_PATHS(1.1), JSON_PATHS(null)\n"
      + "FROM druid.nested";

  final ScanQuery expectedQuery =
      Druids.newScanQueryBuilder()
            .dataSource(DATA_SOURCE)
            .intervals(querySegmentSpec(Filtration.eternity()))
            .virtualColumns(
                expressionVirtualColumn("v0", "json_paths(\"string\")", ColumnType.STRING_ARRAY),
                expressionVirtualColumn("v1", "array('$')", ColumnType.STRING_ARRAY)
            )
            .columns("v0", "v1")
            .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
            .legacy(false)
            .build();

  // Seven identical rows are expected; every output column is the single-path array ["$"].
  final ImmutableList.Builder<Object[]> expectedRows = ImmutableList.builder();
  for (int i = 0; i < 7; i++) {
    expectedRows.add(new Object[]{"[\"$\"]", "[\"$\"]", "[\"$\"]", "[\"$\"]", "[\"$\"]"});
  }

  final RowSignature expectedSignature =
      RowSignature.builder()
                  .add("EXPR$0", ColumnType.STRING_ARRAY)
                  .add("EXPR$1", ColumnType.STRING_ARRAY)
                  .add("EXPR$2", ColumnType.STRING_ARRAY)
                  .add("EXPR$3", ColumnType.STRING_ARRAY)
                  .add("EXPR$4", ColumnType.STRING_ARRAY)
                  .build();

  testQuery(
      sql,
      ImmutableList.of(expectedQuery),
      expectedRows.build(),
      expectedSignature
  );
}
/**
 * JSON_KEYS at path '$' applied to a string column and to constant non-JSON arguments (number, string, decimal,
 * null). The expected native query has only two virtual columns: "v0" for the column input and "v1", the constant
 * expression {@code null} typed as a string array.
 */
@Test
public void testJsonKeysNonJsonInput()
{
  final String sql =
      "SELECT JSON_KEYS(string, '$'), JSON_KEYS(1234, '$'), JSON_KEYS('1234', '$'), JSON_KEYS(1.1, '$'), JSON_KEYS(null, '$')\n"
      + "FROM druid.nested";

  final ScanQuery expectedQuery =
      Druids.newScanQueryBuilder()
            .dataSource(DATA_SOURCE)
            .intervals(querySegmentSpec(Filtration.eternity()))
            .virtualColumns(
                expressionVirtualColumn("v0", "json_keys(\"string\",'$')", ColumnType.STRING_ARRAY),
                expressionVirtualColumn("v1", "null", ColumnType.STRING_ARRAY)
            )
            .columns("v0", "v1")
            .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
            .legacy(false)
            .build();

  // Seven rows are expected, each with a null array in all five output columns.
  final ImmutableList.Builder<Object[]> expectedRows = ImmutableList.builder();
  for (int i = 0; i < 7; i++) {
    expectedRows.add(new Object[]{null, null, null, null, null});
  }

  final RowSignature expectedSignature =
      RowSignature.builder()
                  .add("EXPR$0", ColumnType.STRING_ARRAY)
                  .add("EXPR$1", ColumnType.STRING_ARRAY)
                  .add("EXPR$2", ColumnType.STRING_ARRAY)
                  .add("EXPR$3", ColumnType.STRING_ARRAY)
                  .add("EXPR$4", ColumnType.STRING_ARRAY)
                  .build();

  testQuery(
      sql,
      ImmutableList.of(expectedQuery),
      expectedRows.build(),
      expectedSignature
  );
}
}

View File

@ -313,7 +313,7 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest
public void testSelectConstantExpressionEquivalentToNaN()
{
expectedException.expectMessage(
"'(log10(0) - log10(0))' evaluates to 'NaN' that is not supported in SQL. You can either cast the expression as bigint ('cast((log10(0) - log10(0)) as bigint)') or char ('cast((log10(0) - log10(0)) as char)') or change the expression itself");
"'(log10(0) - log10(0))' evaluates to 'NaN' that is not supported in SQL. You can either cast the expression as BIGINT ('CAST((log10(0) - log10(0)) as BIGINT)') or VARCHAR ('CAST((log10(0) - log10(0)) as VARCHAR)') or change the expression itself");
testQuery(
"SELECT log10(0) - log10(0), dim1 FROM foo LIMIT 1",
ImmutableList.of(),
@ -325,7 +325,7 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest
public void testSelectConstantExpressionEquivalentToInfinity()
{
expectedException.expectMessage(
"'log10(0)' evaluates to '-Infinity' that is not supported in SQL. You can either cast the expression as bigint ('cast(log10(0) as bigint)') or char ('cast(log10(0) as char)') or change the expression itself");
"'log10(0)' evaluates to '-Infinity' that is not supported in SQL. You can either cast the expression as BIGINT ('CAST(log10(0) as BIGINT)') or VARCHAR ('CAST(log10(0) as VARCHAR)') or change the expression itself");
testQuery(
"SELECT log10(0), dim1 FROM foo LIMIT 1",
ImmutableList.of(),