Allow casted literal values in SQL functions accepting literals (Part 2) (#15316)

This commit is contained in:
Laksh Singla 2023-11-03 21:22:19 +05:30 committed by GitHub
parent f39a778f7d
commit 0cc8839a60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 438 additions and 10 deletions

View File

@ -45,6 +45,12 @@ java.util.Random#<init>() @ Use ThreadLocalRandom.current() or the constructor w
java.lang.Math#random() @ Use ThreadLocalRandom.current()
java.util.regex.Pattern#matches(java.lang.String,java.lang.CharSequence) @ Use String.startsWith(), endsWith(), contains(), or compile and cache a Pattern explicitly
org.apache.calcite.sql.type.OperandTypes#LITERAL @ LITERAL type checker throws when literals with CAST are passed. Use org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker instead.
org.apache.calcite.sql.type.OperandTypes#BOOLEAN_LITERAL @ Create a type checker like org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers#POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#ARRAY_BOOLEAN_LITERAL @ Create a type checker like org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers#POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#POSITIVE_INTEGER_LITERAL @ Use org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers#POSITIVE_INTEGER_LITERAL instead
org.apache.calcite.sql.type.OperandTypes#UNIT_INTERVAL_NUMERIC_LITERAL @ Create a type checker like org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers#POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#NUMERIC_UNIT_INTERVAL_NUMERIC_LITERAL @ Create a type checker like org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers#POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#NULLABLE_LITERAL @ Create an instance of org.apache.calcite.sql.type.CastedLiteralOperandTypeChecker that allows nulls and use that instead
org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead
org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory()
org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead

View File

@ -26,6 +26,7 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
@ -156,7 +157,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
OperandTypes.sequence(
"'" + name + "(column, size)'",
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
),
@ -164,8 +165,8 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
OperandTypes.sequence(
"'" + name + "(column, size, scale)'",
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
),
@ -173,8 +174,8 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
OperandTypes.sequence(
"'" + name + "(column, size, scale, strictNumberParsing)'",
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL,
OperandTypes.POSITIVE_INTEGER_LITERAL,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL,
OperandTypes.BOOLEAN
),
OperandTypes.family(

View File

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
//CHECKSTYLE.OFF: PackageName - Must be in Calcite
package org.apache.calcite.sql.type;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.util.Static;
import org.apache.calcite.util.Util;
/**
 * Like {@link LiteralOperandTypeChecker}, but also allows casted literals.
 * <p>
 * "Casted literals" are expressions such as {@code CAST(100 AS INTEGER)}. While it doesn't make sense to cast a
 * literal that the user themselves enter, it is important to add a broader validation to allow these literals because
 * Calcite's JDBC driver doesn't allow the wildcards (?) to work without a cast, and there is no workaround for it.
 * <p>
 * This makes sure that the functions using the literal operand type checker can work around the JDBC restriction
 * without being marked as invalid SQL input.
 */
public class CastedLiteralOperandTypeChecker implements SqlSingleOperandTypeChecker
{
  /**
   * Shared checker that accepts (possibly casted) literals and rejects NULL literals. Declared {@code final} so the
   * shared instance cannot be reassigned by other code.
   */
  public static final SqlSingleOperandTypeChecker LITERAL = new CastedLiteralOperandTypeChecker(false);

  /** Whether a NULL literal (including a casted NULL) is accepted by this checker. */
  private final boolean allowNull;

  /**
   * @param allowNull if true, NULL literals pass the check; otherwise they are rejected
   */
  CastedLiteralOperandTypeChecker(boolean allowNull)
  {
    this.allowNull = allowNull;
  }

  /**
   * Checks that {@code node} is a literal, a literal wrapped in casts, or a literal chain.
   *
   * @param callBinding    binding used to build the validation error and to look up the operator name
   * @param node           operand being validated
   * @param iFormalOperand position of the operand; unused by this checker
   * @param throwOnFailure if true, a validation error is thrown instead of returning false
   *
   * @return true if the operand is acceptable, false otherwise (only when {@code throwOnFailure} is false)
   */
  @Override
  public boolean checkSingleOperandType(
      SqlCallBinding callBinding,
      SqlNode node,
      int iFormalOperand,
      boolean throwOnFailure
  )
  {
    Util.discard(iFormalOperand);

    // NULL handling comes first; the "true" flag makes the null check look through casts as well.
    if (SqlUtil.isNullLiteral(node, true)) {
      if (allowNull) {
        return true;
      }
      if (throwOnFailure) {
        throw callBinding.newError(
            Static.RESOURCE.argumentMustNotBeNull(
                callBinding.getOperator().getName()));
      }
      return false;
    }

    // The following line of code is the only difference between the OperandTypes.LITERAL and this type checker:
    // SqlUtil.isLiteral is called with allowCast = true.
    if (!SqlUtil.isLiteral(node, true) && !SqlUtil.isLiteralChain(node)) {
      if (throwOnFailure) {
        throw callBinding.newError(
            Static.RESOURCE.argumentMustBeLiteral(
                callBinding.getOperator().getName()));
      }
      return false;
    }

    return true;
  }

  /**
   * Returns the fixed signature string used for this checker, regardless of the operator.
   */
  @Override
  public String getAllowedSignatures(SqlOperator op, String opName)
  {
    return "<LITERAL>";
  }
}

View File

@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
//CHECKSTYLE.OFF: PackageName - Must be in Calcite
package org.apache.calcite.sql.type;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.util.Static;
import org.apache.druid.error.DruidException;
import java.math.BigDecimal;
/**
 * Operand type checkers that accept literals wrapped in one or more {@code CAST} calls, e.g.
 * {@code CAST(100 AS INTEGER)}, in addition to plain literals.
 */
public class CastedLiteralOperandTypeCheckers
{
  /**
   * Accepts non-null literals, including literals wrapped in casts.
   */
  public static final SqlSingleOperandTypeChecker LITERAL = new CastedLiteralOperandTypeChecker(false);

  /**
   * Blatantly copied from {@link OperandTypes#POSITIVE_INTEGER_LITERAL}, however the reference to the {@link #LITERAL}
   * is the one which accepts casted literals.
   * <p>
   * The value inspected is the innermost literal underneath the casts. It must not be negative (zero is accepted),
   * must not have a fractional part, and must not exceed {@link Integer#MAX_VALUE}. Unlike the Calcite original, the
   * fractional check here rejects every non-integral value (e.g. {@code CAST(1.5 AS INTEGER)}), because a cast lets
   * a non-integral literal pass the INTEGER family check.
   */
  public static final SqlSingleOperandTypeChecker POSITIVE_INTEGER_LITERAL =
      new FamilyOperandTypeChecker(
          ImmutableList.of(SqlTypeFamily.INTEGER),
          i -> false
      )
      {
        @Override
        public boolean checkSingleOperandType(
            SqlCallBinding callBinding,
            SqlNode operand,
            int iFormalOperand,
            SqlTypeFamily family,
            boolean throwOnFailure
        )
        {
          // This LITERAL refers to the above implementation, the one which allows casted literals
          if (!LITERAL.checkSingleOperandType(
              callBinding,
              operand,
              iFormalOperand,
              throwOnFailure
          )) {
            return false;
          }

          // Family check of the operand, delegated to FamilyOperandTypeChecker.
          if (!super.checkSingleOperandType(
              callBinding,
              operand,
              iFormalOperand,
              family,
              throwOnFailure
          )) {
            return false;
          }

          // Range checks are done on the literal found underneath the casts, not on the cast result.
          final SqlLiteral arg = fetchPrimitiveLiteralFromCasts(operand);
          final BigDecimal value = arg.getValueAs(BigDecimal.class);
          if (value.compareTo(BigDecimal.ZERO) < 0
              || hasFractionalPart(value)) {
            if (throwOnFailure) {
              throw callBinding.newError(
                  Static.RESOURCE.argumentMustBePositiveInteger(
                      callBinding.getOperator().getName()));
            }
            return false;
          }
          if (value.compareTo(BigDecimal.valueOf(Integer.MAX_VALUE)) > 0) {
            if (throwOnFailure) {
              throw callBinding.newError(
                  Static.RESOURCE.numberLiteralOutOfRange(value.toString()));
            }
            return false;
          }
          return true;
        }

        /**
         * Returns whether a number has any fractional part.
         * <p>
         * The previous {@code precision - scale <= 0} test only recognized non-zero values with magnitude below one,
         * so values such as 1.5 were treated as integral while 0.0 was treated as fractional. Stripping trailing
         * zeros and looking at the remaining scale handles both cases.
         */
        private boolean hasFractionalPart(BigDecimal bd)
        {
          // The signum guard keeps any representation of zero (0, 0.0, 0.00, ...) integral.
          return bd.signum() != 0 && bd.stripTrailingZeros().scale() > 0;
        }
      };

  /**
   * Returns whether a node is a literal, optionally looking through casts.
   *
   * @param node      node to inspect, must not be null
   * @param allowCast if true, casts of literals, MAP/ARRAY value constructors made only of literals, and DEFAULT
   *                  are also treated as literals
   */
  public static boolean isLiteral(SqlNode node, boolean allowCast)
  {
    assert node != null;
    if (node instanceof SqlLiteral) {
      return true;
    }
    if (!allowCast) {
      return false;
    }
    switch (node.getKind()) {
      case CAST:
        // "CAST(e AS type)" is literal if "e" is literal
        return isLiteral(((SqlCall) node).operand(0), true);
      case MAP_VALUE_CONSTRUCTOR:
      case ARRAY_VALUE_CONSTRUCTOR:
        return ((SqlCall) node).getOperandList().stream()
                               .allMatch(o -> isLiteral(o, true));
      case DEFAULT:
        return true; // DEFAULT is always NULL
      default:
        return false;
    }
  }

  /**
   * Fetches primitive literals from the casts, including NULL literal.
   * It throws if the entered node isn't a primitive literal, which can be cast multiple times.
   *
   * Therefore, it would fail on the following types:
   * 1. Nodes that are not of the form CAST(....(CAST LITERAL AS TYPE).....)
   * 2. ARRAY and MAP literals. This won't be required since we are only using this method in the type checker for
   * primitive types
   */
  private static SqlLiteral fetchPrimitiveLiteralFromCasts(SqlNode node)
  {
    if (node == null) {
      throw DruidException.defensive("'node' cannot be null");
    }
    if (node instanceof SqlLiteral) {
      return (SqlLiteral) node;
    }

    switch (node.getKind()) {
      case CAST:
        // Unwrap one level of cast and keep looking for the underlying literal.
        return fetchPrimitiveLiteralFromCasts(((SqlCall) node).operand(0));
      case DEFAULT:
        return SqlLiteral.createNull(SqlParserPos.ZERO);
      default:
        throw DruidException.defensive("Expected a literal or a cast on the literal. Found [%s] instead", node.getKind());
    }
  }
}

View File

@ -26,6 +26,7 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
@ -156,7 +157,7 @@ public class ArrayConcatSqlAggregator implements SqlAggregator
OperandTypes.sequence(
StringUtils.format("'%s(expr, maxSizeBytes)'", NAME),
OperandTypes.ARRAY,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,

View File

@ -28,6 +28,7 @@ import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
@ -179,7 +180,11 @@ public class ArraySqlAggregator implements SqlAggregator
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(StringUtils.format("'%s(expr, maxSizeBytes)'", NAME), OperandTypes.ANY, OperandTypes.POSITIVE_INTEGER_LITERAL),
OperandTypes.sequence(
StringUtils.format("'%s(expr, maxSizeBytes)'", NAME),
OperandTypes.ANY,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),

View File

@ -29,6 +29,7 @@ import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
@ -251,7 +252,7 @@ public class StringSqlAggregator implements SqlAggregator
StringUtils.format("'%s(expr, separator, maxSizeBytes)'", name),
OperandTypes.ANY,
OperandTypes.STRING,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.NUMERIC)
)

View File

@ -2708,6 +2708,106 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testArrayAggArraysWithMaxSizeBytes()
{
  // Produces nested array - ARRAY<ARRAY<LONG>>, which frame writers don't support. A way to get this query
  // to run would be to use nested columns.
  msqIncompatible();
  cannotVectorize();

  // maxSizeBytes is given once as a plain literal and once as a casted literal.
  final String query =
      "SELECT ARRAY_AGG(ARRAY[l1, l2], 10000), "
      + "ARRAY_AGG(DISTINCT ARRAY[l1, l2], CAST(10000 AS INTEGER)) FROM numfoo";

  final ExpressionLambdaAggregatorFactory plainAgg = new ExpressionLambdaAggregatorFactory(
      "a0",
      ImmutableSet.of("v0"),
      "__acc",
      "ARRAY<ARRAY<LONG>>[]",
      "ARRAY<ARRAY<LONG>>[]",
      true,
      true,
      false,
      "array_append(\"__acc\", \"v0\")",
      "array_concat(\"__acc\", \"a0\")",
      null,
      null,
      new HumanReadableBytes(10000),
      TestExprMacroTable.INSTANCE
  );

  final ExpressionLambdaAggregatorFactory distinctAgg = new ExpressionLambdaAggregatorFactory(
      "a1",
      ImmutableSet.of("v0"),
      "__acc",
      "ARRAY<ARRAY<LONG>>[]",
      "ARRAY<ARRAY<LONG>>[]",
      true,
      true,
      false,
      "array_set_add(\"__acc\", \"v0\")",
      "array_set_add_all(\"__acc\", \"a1\")",
      null,
      null,
      new HumanReadableBytes(10000),
      TestExprMacroTable.INSTANCE
  );

  testQuery(
      query,
      QUERY_CONTEXT_NO_STRINGIFY_ARRAY,
      ImmutableList.of(
          Druids.newTimeseriesQueryBuilder()
                .dataSource(CalciteTests.DATASOURCE3)
                .intervals(querySegmentSpec(Filtration.eternity()))
                .granularity(Granularities.ALL)
                .virtualColumns(
                    expressionVirtualColumn("v0", "array(\"l1\",\"l2\")", ColumnType.LONG_ARRAY)
                )
                .aggregators(aggregators(plainAgg, distinctAgg))
                .context(QUERY_CONTEXT_NO_STRINGIFY_ARRAY)
                .build()
      ),
      (sql, queryResults) -> {
        // ordering is not stable in array_agg and array_concat_agg
        final Object[] expectedRow;
        if (useDefault) {
          expectedRow = new Object[]{
              Arrays.asList(
                  Arrays.asList(7L, 0L),
                  Arrays.asList(325323L, 325323L),
                  Arrays.asList(0L, 0L),
                  Arrays.asList(0L, 0L),
                  Arrays.asList(0L, 0L),
                  Arrays.asList(0L, 0L)
              ),
              Arrays.asList(
                  Arrays.asList(0L, 0L),
                  Arrays.asList(7L, 0L),
                  Arrays.asList(325323L, 325323L)
              )
          };
        } else {
          expectedRow = new Object[]{
              Arrays.asList(
                  Arrays.asList(7L, null),
                  Arrays.asList(325323L, 325323L),
                  Arrays.asList(0L, 0L),
                  Arrays.asList(null, null),
                  Arrays.asList(null, null),
                  Arrays.asList(null, null)
              ),
              Arrays.asList(
                  Arrays.asList(null, null),
                  Arrays.asList(0L, 0L),
                  Arrays.asList(7L, null),
                  Arrays.asList(325323L, 325323L)
              )
          };
        }
        final List<Object[]> expected = ImmutableList.of(expectedRow);
        assertResultsDeepEquals(sql, expected, queryResults.results);
      }
  );
}
@Test
public void testArrayConcatAggArrays()
{
@ -2769,6 +2869,69 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testArrayConcatAggArraysWithMaxSizeBytes()
{
  cannotVectorize();

  // maxSizeBytes is given once as a plain literal and once as a casted literal.
  final String query =
      "SELECT ARRAY_CONCAT_AGG(ARRAY[l1, l2], 10000), ARRAY_CONCAT_AGG(DISTINCT ARRAY[l1, l2], CAST(10000 AS INTEGER)) "
      + "FROM numfoo";

  final ExpressionLambdaAggregatorFactory concatAgg = new ExpressionLambdaAggregatorFactory(
      "a0",
      ImmutableSet.of("v0"),
      "__acc",
      "ARRAY<LONG>[]",
      "ARRAY<LONG>[]",
      true,
      false,
      false,
      "array_concat(\"__acc\", \"v0\")",
      "array_concat(\"__acc\", \"a0\")",
      null,
      null,
      new HumanReadableBytes(10000),
      TestExprMacroTable.INSTANCE
  );

  final ExpressionLambdaAggregatorFactory distinctConcatAgg = new ExpressionLambdaAggregatorFactory(
      "a1",
      ImmutableSet.of("v0"),
      "__acc",
      "ARRAY<LONG>[]",
      "ARRAY<LONG>[]",
      true,
      false,
      false,
      "array_set_add_all(\"__acc\", \"v0\")",
      "array_set_add_all(\"__acc\", \"a1\")",
      null,
      null,
      new HumanReadableBytes(10000),
      TestExprMacroTable.INSTANCE
  );

  // Expected output depends on the null-handling mode selected by useDefault.
  final Object[] expectedRow;
  if (useDefault) {
    expectedRow = new Object[]{"[7,0,325323,325323,0,0,0,0,0,0,0,0]", "[0,7,325323]"};
  } else {
    expectedRow = new Object[]{"[7,null,325323,325323,0,0,null,null,null,null,null,null]", "[null,0,7,325323]"};
  }

  testQuery(
      query,
      ImmutableList.of(
          Druids.newTimeseriesQueryBuilder()
                .dataSource(CalciteTests.DATASOURCE3)
                .intervals(querySegmentSpec(Filtration.eternity()))
                .granularity(Granularities.ALL)
                .virtualColumns(
                    expressionVirtualColumn("v0", "array(\"l1\",\"l2\")", ColumnType.LONG_ARRAY)
                )
                .aggregators(aggregators(concatAgg, distinctConcatAgg))
                .context(QUERY_CONTEXT_DEFAULT)
                .build()
      ),
      ImmutableList.of(expectedRow)
  );
}
@Test
public void testArrayAggArrayColumns()
@ -3031,7 +3194,7 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
{
cannotVectorize();
testQuery(
"SELECT ARRAY_AGG(l1, 128), ARRAY_AGG(DISTINCT l1, 128) FROM numfoo",
"SELECT ARRAY_AGG(l1, 128), ARRAY_AGG(DISTINCT l1, CAST(128 AS INTEGER)) FROM numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)

View File

@ -13478,7 +13478,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
{
cannotVectorize();
testQuery(
"SELECT STRING_AGG(l1, ',', 128), STRING_AGG(DISTINCT l1, ',', 128) FROM numfoo",
"SELECT STRING_AGG(l1, ',', 128), STRING_AGG(DISTINCT l1, ',', CAST(128 AS INTEGER)) FROM numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)