SQL: Implement data type verification for conditionals (#35916)
Add special verifier rule to check that the arguments of conditional functions are of the same or compatible types. This way the user gets a descriptive error message with line number and column indicating where is the offending argument. Closes: #35907
This commit is contained in:
parent
110c4fdd65
commit
c91ef1105d
|
@ -19,6 +19,7 @@ import org.elasticsearch.xpack.sql.expression.function.FunctionAttribute;
|
|||
import org.elasticsearch.xpack.sql.expression.function.Functions;
|
||||
import org.elasticsearch.xpack.sql.expression.function.Score;
|
||||
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalFunction;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
|
||||
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
|
||||
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
|
||||
|
@ -169,7 +170,7 @@ public final class Verifier {
|
|||
for (Attribute a : p.intputSet()) {
|
||||
String nameCandidate = useQualifier ? a.qualifiedName() : a.name();
|
||||
// add only primitives (object types would only result in another error)
|
||||
if (!(a.dataType() == DataType.UNSUPPORTED) && a.dataType().isPrimitive()) {
|
||||
if ((a.dataType() != DataType.UNSUPPORTED) && a.dataType().isPrimitive()) {
|
||||
potentialMatches.add(nameCandidate);
|
||||
}
|
||||
}
|
||||
|
@ -220,6 +221,7 @@ public final class Verifier {
|
|||
Set<Failure> localFailures = new LinkedHashSet<>();
|
||||
|
||||
validateInExpression(p, localFailures);
|
||||
validateConditional(p, localFailures);
|
||||
|
||||
if (!groupingFailures.contains(p)) {
|
||||
checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures);
|
||||
|
@ -282,14 +284,13 @@ public final class Verifier {
|
|||
*/
|
||||
private static boolean checkGroupBy(LogicalPlan p, Set<Failure> localFailures,
|
||||
Map<String, Function> resolvedFunctions, Set<LogicalPlan> groupingFailures) {
|
||||
return checkGroupByAgg(p, localFailures, groupingFailures, resolvedFunctions)
|
||||
&& checkGroupByOrder(p, localFailures, groupingFailures, resolvedFunctions)
|
||||
return checkGroupByAgg(p, localFailures, resolvedFunctions)
|
||||
&& checkGroupByOrder(p, localFailures, groupingFailures)
|
||||
&& checkGroupByHaving(p, localFailures, groupingFailures, resolvedFunctions);
|
||||
}
|
||||
|
||||
// check whether an orderBy failed or if it occurs on a non-key
|
||||
private static boolean checkGroupByOrder(LogicalPlan p, Set<Failure> localFailures,
|
||||
Set<LogicalPlan> groupingFailures, Map<String, Function> functions) {
|
||||
private static boolean checkGroupByOrder(LogicalPlan p, Set<Failure> localFailures, Set<LogicalPlan> groupingFailures) {
|
||||
if (p instanceof OrderBy) {
|
||||
OrderBy o = (OrderBy) p;
|
||||
LogicalPlan child = o.child();
|
||||
|
@ -432,8 +433,7 @@ public final class Verifier {
|
|||
|
||||
|
||||
// check whether plain columns specified in an agg are mentioned in the group-by
|
||||
private static boolean checkGroupByAgg(LogicalPlan p, Set<Failure> localFailures,
|
||||
Set<LogicalPlan> groupingFailures, Map<String, Function> functions) {
|
||||
private static boolean checkGroupByAgg(LogicalPlan p, Set<Failure> localFailures, Map<String, Function> functions) {
|
||||
if (p instanceof Aggregate) {
|
||||
Aggregate a = (Aggregate) p;
|
||||
|
||||
|
@ -578,7 +578,7 @@ public final class Verifier {
|
|||
e.forEachUp((In in) -> {
|
||||
DataType dt = in.value().dataType();
|
||||
for (Expression value : in.list()) {
|
||||
if (areTypesCompatible(in.value().dataType(), value.dataType()) == false) {
|
||||
if (areTypesCompatible(dt, value.dataType()) == false) {
|
||||
localFailures.add(fail(value, "expected data type [%s], value provided is of type [%s]",
|
||||
dt, value.dataType()));
|
||||
return;
|
||||
|
@ -588,6 +588,28 @@ public final class Verifier {
|
|||
In.class));
|
||||
}
|
||||
|
||||
private static void validateConditional(LogicalPlan p, Set<Failure> localFailures) {
|
||||
p.forEachExpressions(e ->
|
||||
e.forEachUp((ConditionalFunction cf) -> {
|
||||
DataType dt = DataType.NULL;
|
||||
|
||||
for (Expression child : cf.children()) {
|
||||
if (dt == DataType.NULL) {
|
||||
if (Expressions.isNull(child) == false) {
|
||||
dt = child.dataType();
|
||||
}
|
||||
} else {
|
||||
if (areTypesCompatible(dt, child.dataType()) == false) {
|
||||
localFailures.add(fail(child, "expected data type [%s], value provided is of type [%s]",
|
||||
dt, child.dataType()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
ConditionalFunction.class));
|
||||
}
|
||||
|
||||
private static boolean areTypesCompatible(DataType left, DataType right) {
|
||||
if (left == right) {
|
||||
return true;
|
||||
|
@ -598,4 +620,4 @@ public final class Verifier {
|
|||
(left.isNumeric() && right.isNumeric());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,11 @@ import org.elasticsearch.xpack.sql.analysis.index.EsIndex;
|
|||
import org.elasticsearch.xpack.sql.analysis.index.IndexResolution;
|
||||
import org.elasticsearch.xpack.sql.analysis.index.IndexResolverTests;
|
||||
import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Coalesce;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Greatest;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.IfNull;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Least;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIf;
|
||||
import org.elasticsearch.xpack.sql.parser.SqlParser;
|
||||
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
|
||||
import org.elasticsearch.xpack.sql.stats.Metrics;
|
||||
|
@ -423,4 +428,32 @@ public class VerifierErrorMessagesTests extends ESTestCase {
|
|||
+ "[integer] in [basic], [long] in [incompatible]",
|
||||
incompatibleError("SELECT languages FROM \"*\" ORDER BY SIGN(ABS(emp_no))"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testConditionalWithDifferentDataTypes_SelectClause() {
|
||||
@SuppressWarnings("unchecked")
|
||||
String function = randomFrom(IfNull.class, NullIf.class).getSimpleName();
|
||||
assertEquals("1:" + (22 + function.length()) +
|
||||
": expected data type [INTEGER], value provided is of type [KEYWORD]",
|
||||
error("SELECT 1 = 1 OR " + function + "(3, '4') > 1"));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
String arbirtraryArgsfunction = randomFrom(Coalesce.class, Greatest.class, Least.class).getSimpleName();
|
||||
assertEquals("1:" + (34 + arbirtraryArgsfunction.length()) +
|
||||
": expected data type [INTEGER], value provided is of type [KEYWORD]",
|
||||
error("SELECT 1 = 1 OR " + arbirtraryArgsfunction + "(null, null, 3, '4') > 1"));
|
||||
}
|
||||
|
||||
public void testConditionalWithDifferentDataTypes_WhereClause() {
|
||||
@SuppressWarnings("unchecked")
|
||||
String function = randomFrom(IfNull.class, NullIf.class).getSimpleName();
|
||||
assertEquals("1:" + (34 + function.length()) +
|
||||
": expected data type [KEYWORD], value provided is of type [INTEGER]",
|
||||
error("SELECT * FROM test WHERE " + function + "('foo', 4) > 1"));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
String arbirtraryArgsfunction = randomFrom(Coalesce.class, Greatest.class, Least.class).getSimpleName();
|
||||
assertEquals("1:" + (46 + arbirtraryArgsfunction.length()) +
|
||||
": expected data type [KEYWORD], value provided is of type [INTEGER]",
|
||||
error("SELECT * FROM test WHERE " + arbirtraryArgsfunction + "(null, null, 'foo', 4) > 1"));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue