Previously, `CASE` and `IIF` when translated to painless scripts (used in GROUP BY, HAVING, WHERE) a custom `caseFunction` registered in the `InternalSqlScriptUtils` was used. This function received and array of arbitrary length: ```[condition1, result1, condition2, result2, ... elseResult]``` Painless doesn't know of the context and therefore is evaluating all conditions and results before invoking the `caseFunction` on them. As a consequence, erroneous result expressions (i.e. division by 0) where always evaluated despite of the guarding condition. Replace the `caseFunction` with painless `<cond> ? <res1> : <res2>` expressions to properly guard the result expressions and only evaluate the one for which its guarding condition evaluates to true (or of course the elseResult). As a bonus, this approach includes performance benefits since we avoid unnecessary evaluations of both conditions and result expressions. Fixes: #49672 (cherry picked from commit 9584b345d89f797bfb658212b928b9812804f02f)
This commit is contained in:
parent
e1cab4feb4
commit
fdac9e99fa
|
@ -197,6 +197,24 @@ END as lang_skills FROM test_emp GROUP BY lang_skills ORDER BY 2;
|
|||
10 | zero
|
||||
;
|
||||
|
||||
caseGroupByProtectedDivisionByZero
|
||||
schema::x:i
|
||||
SELECT CASE WHEN languages = 1 THEN NULL ELSE ( salary / (languages - 1) ) END AS x FROM test_emp GROUP BY 1 ORDER BY 1 LIMIT 10;
|
||||
|
||||
x
|
||||
---------------
|
||||
null
|
||||
6331
|
||||
6486
|
||||
7780
|
||||
7974
|
||||
8068
|
||||
8489
|
||||
8935
|
||||
9043
|
||||
9071
|
||||
;
|
||||
|
||||
caseGroupByAndHaving
|
||||
schema::count:l|gender:s|languages:byte
|
||||
SELECT count(*) AS count, gender, languages FROM test_emp
|
||||
|
@ -353,6 +371,28 @@ IIF(NVL(languages, 0) = 0, 'zero',
|
|||
10 |zero
|
||||
;
|
||||
|
||||
iifGroupByProtectedDivisionByZero
|
||||
schema::count:l|x:i
|
||||
SELECT count(*) AS count,
|
||||
IIF(languages - 1 = 0, 0,
|
||||
IIF(languages - 1 = 1, (salary / 10000) / (languages - 1),
|
||||
IIF(languages - 1 = 2, (salary / 10000) / languages,
|
||||
IIF(languages - 1 = 3, (salary / 10000) / (languages + 1),
|
||||
(salary / 10000) / (languages + 2))))) as x FROM test_emp GROUP BY x ORDER BY 2;
|
||||
|
||||
count | x
|
||||
---------------+---------------
|
||||
10 |null
|
||||
50 |0
|
||||
14 |1
|
||||
8 |2
|
||||
4 |3
|
||||
5 |4
|
||||
6 |5
|
||||
2 |6
|
||||
1 |7
|
||||
;
|
||||
|
||||
iifGroupByAndHaving
|
||||
schema::count:l|gender:s|languages:byte
|
||||
SELECT count(*) AS count, gender, languages FROM test_emp
|
||||
|
|
|
@ -38,7 +38,6 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFu
|
|||
import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape;
|
||||
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime;
|
||||
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.CaseProcessor;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor;
|
||||
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation;
|
||||
|
@ -67,10 +66,6 @@ public class InternalSqlScriptUtils extends InternalQlScriptUtils {
|
|||
//
|
||||
// Conditional
|
||||
//
|
||||
public static Object caseFunction(List<Object> expressions) {
|
||||
return CaseProcessor.apply(expressions);
|
||||
}
|
||||
|
||||
public static Object coalesce(List<Object> expressions) {
|
||||
return ConditionalOperation.COALESCE.apply(expressions);
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.xpack.ql.expression.Expressions;
|
|||
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
|
||||
import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder;
|
||||
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
|
||||
import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
|
||||
import org.elasticsearch.xpack.ql.tree.NodeInfo;
|
||||
import org.elasticsearch.xpack.ql.tree.Source;
|
||||
import org.elasticsearch.xpack.ql.type.DataType;
|
||||
|
@ -19,7 +20,6 @@ import org.elasticsearch.xpack.sql.type.SqlDataTypes;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.StringJoiner;
|
||||
|
||||
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
|
||||
import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder;
|
||||
|
@ -161,14 +161,26 @@ public class Case extends ConditionalFunction {
|
|||
}
|
||||
templates.add(asScript(elseResult));
|
||||
|
||||
StringJoiner template = new StringJoiner(",", "{sql}.caseFunction([", "])");
|
||||
// Use painless ?: expressions to prevent evaluation of return expression
|
||||
// if the condition which guards it evaluates to false (e.g. division by 0)
|
||||
StringBuilder sb = new StringBuilder();
|
||||
ParamsBuilder params = paramsBuilder();
|
||||
|
||||
for (ScriptTemplate scriptTemplate : templates) {
|
||||
template.add(scriptTemplate.template());
|
||||
for (int i = 0; i < templates.size(); i++) {
|
||||
ScriptTemplate scriptTemplate = templates.get(i);
|
||||
if (i < templates.size() - 1) {
|
||||
if (i % 2 == 0) {
|
||||
// painless ? : operator expects primitive boolean, thus we use nullSafeFilter
|
||||
// to convert object Boolean to primitive boolean (null => false)
|
||||
sb.append(Scripts.nullSafeFilter(scriptTemplate).template()).append(" ? ");
|
||||
} else {
|
||||
sb.append(scriptTemplate.template()).append(" : ");
|
||||
}
|
||||
} else {
|
||||
sb.append(scriptTemplate.template());
|
||||
}
|
||||
params.script(scriptTemplate.params());
|
||||
}
|
||||
|
||||
return new ScriptTemplate(formatTemplate(template.toString()), params.build(), dataType());
|
||||
return new ScriptTemplate(formatTemplate(sb.toString()), params.build(), dataType());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,18 +50,6 @@ public class CaseProcessor implements Processor {
|
|||
return processors.get(processors.size() - 1).process(input);
|
||||
}
|
||||
|
||||
public static Object apply(List<Object> objects) {
|
||||
// Check every condition in sequence and if it evaluates to TRUE,
|
||||
// evaluate and return the result associated with that condition.
|
||||
for (int i = 0; i < objects.size() - 2; i += 2) {
|
||||
if (objects.get(i) == Boolean.TRUE) {
|
||||
return objects.get(i + 1);
|
||||
}
|
||||
}
|
||||
// resort to default value
|
||||
return objects.get(objects.size() - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
|
|
@ -71,7 +71,6 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
|
|||
#
|
||||
# Conditional
|
||||
#
|
||||
def caseFunction(java.util.List)
|
||||
def coalesce(java.util.List)
|
||||
def greatest(java.util.List)
|
||||
def least(java.util.List)
|
||||
|
|
|
@ -674,9 +674,9 @@ public class QueryTranslatorTests extends ESTestCase {
|
|||
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
||||
assertNotNull(groupingContext);
|
||||
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
||||
assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue("
|
||||
+ "doc,params.v0),params.v1),params.v2,InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" +
|
||||
"doc,params.v3),params.v4),params.v5,params.v6])",
|
||||
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(doc,params.v0)," +
|
||||
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils." +
|
||||
"docValue(doc,params.v3),params.v4)) ? params.v5 : params.v6",
|
||||
scriptTemplate.toString());
|
||||
assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]",
|
||||
scriptTemplate.params().toString());
|
||||
|
@ -1194,9 +1194,9 @@ public class QueryTranslatorTests extends ESTestCase {
|
|||
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
||||
assertNotNull(groupingContext);
|
||||
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
||||
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" + ""
|
||||
+ "doc,params.v0),params.v1),params.v2,InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v3)," +
|
||||
"params.v4),params.v5,params.v6])",
|
||||
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
|
||||
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" +
|
||||
"doc,params.v3),params.v4)) ? params.v5 : params.v6",
|
||||
scriptTemplate.toString());
|
||||
assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString());
|
||||
}
|
||||
|
@ -1209,8 +1209,8 @@ public class QueryTranslatorTests extends ESTestCase {
|
|||
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
||||
assertNotNull(groupingContext);
|
||||
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
||||
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(" +
|
||||
"InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2,params.v3])",
|
||||
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
|
||||
"params.v1)) ? params.v2 : params.v3",
|
||||
scriptTemplate.toString());
|
||||
assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue