SQL: Fix unnecessary evaluation for CASE/IIF (#57159) (#57262)

Previously, when `CASE` and `IIF` were translated to painless scripts
(used in GROUP BY, HAVING, WHERE), a custom `caseFunction`
registered in the `InternalSqlScriptUtils` was used. This function
received an array of arbitrary length:
```[condition1, result1, condition2, result2, ... elseResult]```

Painless doesn't know of the context and therefore evaluates
all conditions and results before invoking the `caseFunction` on them.
As a consequence, erroneous result expressions (e.g. division by 0)
were always evaluated regardless of the guarding condition.

Replace the `caseFunction` with painless `<cond> ? <res1> : <res2>`
expressions to properly guard the result expressions and only evaluate
the one whose guarding condition evaluates to true (or, failing that,
the elseResult).

As a bonus, this approach includes performance benefits since we avoid
unnecessary evaluations of both conditions and result expressions.

Fixes: #49672
(cherry picked from commit 9584b345d89f797bfb658212b928b9812804f02f)
This commit is contained in:
Marios Trivyzas 2020-05-28 11:30:14 +02:00 committed by GitHub
parent e1cab4feb4
commit fdac9e99fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 32 deletions

View File

@ -197,6 +197,24 @@ END as lang_skills FROM test_emp GROUP BY lang_skills ORDER BY 2;
10 | zero
;
caseGroupByProtectedDivisionByZero
schema::x:i
SELECT CASE WHEN languages = 1 THEN NULL ELSE ( salary / (languages - 1) ) END AS x FROM test_emp GROUP BY 1 ORDER BY 1 LIMIT 10;
x
---------------
null
6331
6486
7780
7974
8068
8489
8935
9043
9071
;
caseGroupByAndHaving
schema::count:l|gender:s|languages:byte
SELECT count(*) AS count, gender, languages FROM test_emp
@ -353,6 +371,28 @@ IIF(NVL(languages, 0) = 0, 'zero',
10 |zero
;
iifGroupByProtectedDivisionByZero
schema::count:l|x:i
SELECT count(*) AS count,
IIF(languages - 1 = 0, 0,
IIF(languages - 1 = 1, (salary / 10000) / (languages - 1),
IIF(languages - 1 = 2, (salary / 10000) / languages,
IIF(languages - 1 = 3, (salary / 10000) / (languages + 1),
(salary / 10000) / (languages + 2))))) as x FROM test_emp GROUP BY x ORDER BY 2;
count | x
---------------+---------------
10 |null
50 |0
14 |1
8 |2
4 |3
5 |4
6 |5
2 |6
1 |7
;
iifGroupByAndHaving
schema::count:l|gender:s|languages:byte
SELECT count(*) AS count, gender, languages FROM test_emp

View File

@ -38,7 +38,6 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFu
import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape;
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime;
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.CaseProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation;
@ -67,10 +66,6 @@ public class InternalSqlScriptUtils extends InternalQlScriptUtils {
//
// Conditional
//
public static Object caseFunction(List<Object> expressions) {
return CaseProcessor.apply(expressions);
}
public static Object coalesce(List<Object> expressions) {
return ConditionalOperation.COALESCE.apply(expressions);
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder;
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
@ -19,7 +20,6 @@ import org.elasticsearch.xpack.sql.type.SqlDataTypes;
import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder;
@ -161,14 +161,26 @@ public class Case extends ConditionalFunction {
}
templates.add(asScript(elseResult));
StringJoiner template = new StringJoiner(",", "{sql}.caseFunction([", "])");
// Use painless ?: expressions to prevent evaluation of return expression
// if the condition which guards it evaluates to false (e.g. division by 0)
StringBuilder sb = new StringBuilder();
ParamsBuilder params = paramsBuilder();
for (ScriptTemplate scriptTemplate : templates) {
template.add(scriptTemplate.template());
for (int i = 0; i < templates.size(); i++) {
ScriptTemplate scriptTemplate = templates.get(i);
if (i < templates.size() - 1) {
if (i % 2 == 0) {
// painless ? : operator expects primitive boolean, thus we use nullSafeFilter
// to convert object Boolean to primitive boolean (null => false)
sb.append(Scripts.nullSafeFilter(scriptTemplate).template()).append(" ? ");
} else {
sb.append(scriptTemplate.template()).append(" : ");
}
} else {
sb.append(scriptTemplate.template());
}
params.script(scriptTemplate.params());
}
return new ScriptTemplate(formatTemplate(template.toString()), params.build(), dataType());
return new ScriptTemplate(formatTemplate(sb.toString()), params.build(), dataType());
}
}

View File

@ -50,18 +50,6 @@ public class CaseProcessor implements Processor {
return processors.get(processors.size() - 1).process(input);
}
public static Object apply(List<Object> objects) {
// Check every condition in sequence and if it evaluates to TRUE,
// evaluate and return the result associated with that condition.
for (int i = 0; i < objects.size() - 2; i += 2) {
if (objects.get(i) == Boolean.TRUE) {
return objects.get(i + 1);
}
}
// resort to default value
return objects.get(objects.size() - 1);
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -71,7 +71,6 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
#
# Conditional
#
def caseFunction(java.util.List)
def coalesce(java.util.List)
def greatest(java.util.List)
def least(java.util.List)

View File

@ -674,9 +674,9 @@ public class QueryTranslatorTests extends ESTestCase {
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
assertNotNull(groupingContext);
ScriptTemplate scriptTemplate = groupingContext.tail.script();
assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue("
+ "doc,params.v0),params.v1),params.v2,InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" +
"doc,params.v3),params.v4),params.v5,params.v6])",
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(doc,params.v0)," +
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils." +
"docValue(doc,params.v3),params.v4)) ? params.v5 : params.v6",
scriptTemplate.toString());
assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]",
scriptTemplate.params().toString());
@ -1194,9 +1194,9 @@ public class QueryTranslatorTests extends ESTestCase {
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
assertNotNull(groupingContext);
ScriptTemplate scriptTemplate = groupingContext.tail.script();
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" + ""
+ "doc,params.v0),params.v1),params.v2,InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v3)," +
"params.v4),params.v5,params.v6])",
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" +
"doc,params.v3),params.v4)) ? params.v5 : params.v6",
scriptTemplate.toString());
assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString());
}
@ -1209,8 +1209,8 @@ public class QueryTranslatorTests extends ESTestCase {
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
assertNotNull(groupingContext);
ScriptTemplate scriptTemplate = groupingContext.tail.script();
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(" +
"InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2,params.v3])",
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
"params.v1)) ? params.v2 : params.v3",
scriptTemplate.toString());
assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString());
}