SQL: Fix unecessary evaluation for CASE/IIF (#57159) (#57262)

Previously, `CASE` and `IIF` when translated to painless scripts
(used in GROUP BY, HAVING, WHERE) a custom `caseFunction`
registered in the `InternalSqlScriptUtils` was used. This function
received and array of arbitrary length:
```[condition1, result1, condition2, result2, ... elseResult]```

Painless doesn't know of the context and therefore is evaluating
all conditions and results before invoking the `caseFunction` on them.
As a consequence, erroneous result expressions (i.e. division by 0)
where always evaluated despite of the guarding condition.

Replace the `caseFunction` with painless `<cond> ? <res1> : <res2>`
expressions to properly guard the result expressions and only evaluate
the one for which its guarding condition evaluates to true (or of course
the elseResult).

As a bonus, this approach includes performance benefits since we avoid
unnecessary evaluations of both conditions and result expressions.

Fixes: #49672
(cherry picked from commit 9584b345d89f797bfb658212b928b9812804f02f)
This commit is contained in:
Marios Trivyzas 2020-05-28 11:30:14 +02:00 committed by GitHub
parent e1cab4feb4
commit fdac9e99fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 32 deletions

View File

@ -197,6 +197,24 @@ END as lang_skills FROM test_emp GROUP BY lang_skills ORDER BY 2;
10 | zero 10 | zero
; ;
caseGroupByProtectedDivisionByZero
schema::x:i
SELECT CASE WHEN languages = 1 THEN NULL ELSE ( salary / (languages - 1) ) END AS x FROM test_emp GROUP BY 1 ORDER BY 1 LIMIT 10;
x
---------------
null
6331
6486
7780
7974
8068
8489
8935
9043
9071
;
caseGroupByAndHaving caseGroupByAndHaving
schema::count:l|gender:s|languages:byte schema::count:l|gender:s|languages:byte
SELECT count(*) AS count, gender, languages FROM test_emp SELECT count(*) AS count, gender, languages FROM test_emp
@ -353,6 +371,28 @@ IIF(NVL(languages, 0) = 0, 'zero',
10 |zero 10 |zero
; ;
iifGroupByProtectedDivisionByZero
schema::count:l|x:i
SELECT count(*) AS count,
IIF(languages - 1 = 0, 0,
IIF(languages - 1 = 1, (salary / 10000) / (languages - 1),
IIF(languages - 1 = 2, (salary / 10000) / languages,
IIF(languages - 1 = 3, (salary / 10000) / (languages + 1),
(salary / 10000) / (languages + 2))))) as x FROM test_emp GROUP BY x ORDER BY 2;
count | x
---------------+---------------
10 |null
50 |0
14 |1
8 |2
4 |3
5 |4
6 |5
2 |6
1 |7
;
iifGroupByAndHaving iifGroupByAndHaving
schema::count:l|gender:s|languages:byte schema::count:l|gender:s|languages:byte
SELECT count(*) AS count, gender, languages FROM test_emp SELECT count(*) AS count, gender, languages FROM test_emp

View File

@ -38,7 +38,6 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFu
import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape; import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape;
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime; import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime;
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth; import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.CaseProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation; import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor; import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation; import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation;
@ -67,10 +66,6 @@ public class InternalSqlScriptUtils extends InternalQlScriptUtils {
// //
// Conditional // Conditional
// //
public static Object caseFunction(List<Object> expressions) {
return CaseProcessor.apply(expressions);
}
public static Object coalesce(List<Object> expressions) { public static Object coalesce(List<Object> expressions) {
return ConditionalOperation.COALESCE.apply(expressions); return ConditionalOperation.COALESCE.apply(expressions);
} }

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder; import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder;
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.type.DataType;
@ -19,7 +20,6 @@ import org.elasticsearch.xpack.sql.type.SqlDataTypes;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.StringJoiner;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder; import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder;
@ -161,14 +161,26 @@ public class Case extends ConditionalFunction {
} }
templates.add(asScript(elseResult)); templates.add(asScript(elseResult));
StringJoiner template = new StringJoiner(",", "{sql}.caseFunction([", "])"); // Use painless ?: expressions to prevent evaluation of return expression
// if the condition which guards it evaluates to false (e.g. division by 0)
StringBuilder sb = new StringBuilder();
ParamsBuilder params = paramsBuilder(); ParamsBuilder params = paramsBuilder();
for (int i = 0; i < templates.size(); i++) {
for (ScriptTemplate scriptTemplate : templates) { ScriptTemplate scriptTemplate = templates.get(i);
template.add(scriptTemplate.template()); if (i < templates.size() - 1) {
if (i % 2 == 0) {
// painless ? : operator expects primitive boolean, thus we use nullSafeFilter
// to convert object Boolean to primitive boolean (null => false)
sb.append(Scripts.nullSafeFilter(scriptTemplate).template()).append(" ? ");
} else {
sb.append(scriptTemplate.template()).append(" : ");
}
} else {
sb.append(scriptTemplate.template());
}
params.script(scriptTemplate.params()); params.script(scriptTemplate.params());
} }
return new ScriptTemplate(formatTemplate(template.toString()), params.build(), dataType()); return new ScriptTemplate(formatTemplate(sb.toString()), params.build(), dataType());
} }
} }

View File

@ -50,18 +50,6 @@ public class CaseProcessor implements Processor {
return processors.get(processors.size() - 1).process(input); return processors.get(processors.size() - 1).process(input);
} }
public static Object apply(List<Object> objects) {
// Check every condition in sequence and if it evaluates to TRUE,
// evaluate and return the result associated with that condition.
for (int i = 0; i < objects.size() - 2; i += 2) {
if (objects.get(i) == Boolean.TRUE) {
return objects.get(i + 1);
}
}
// resort to default value
return objects.get(objects.size() - 1);
}
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) { if (this == o) {

View File

@ -71,7 +71,6 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
# #
# Conditional # Conditional
# #
def caseFunction(java.util.List)
def coalesce(java.util.List) def coalesce(java.util.List)
def greatest(java.util.List) def greatest(java.util.List)
def least(java.util.List) def least(java.util.List)

View File

@ -674,9 +674,9 @@ public class QueryTranslatorTests extends ESTestCase {
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
assertNotNull(groupingContext); assertNotNull(groupingContext);
ScriptTemplate scriptTemplate = groupingContext.tail.script(); ScriptTemplate scriptTemplate = groupingContext.tail.script();
assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(doc,params.v0)," +
+ "doc,params.v0),params.v1),params.v2,InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" + "params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils." +
"doc,params.v3),params.v4),params.v5,params.v6])", "docValue(doc,params.v3),params.v4)) ? params.v5 : params.v6",
scriptTemplate.toString()); scriptTemplate.toString());
assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]", assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]",
scriptTemplate.params().toString()); scriptTemplate.params().toString());
@ -1194,9 +1194,9 @@ public class QueryTranslatorTests extends ESTestCase {
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
assertNotNull(groupingContext); assertNotNull(groupingContext);
ScriptTemplate scriptTemplate = groupingContext.tail.script(); ScriptTemplate scriptTemplate = groupingContext.tail.script();
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" + "" assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
+ "doc,params.v0),params.v1),params.v2,InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v3)," + "params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" +
"params.v4),params.v5,params.v6])", "doc,params.v3),params.v4)) ? params.v5 : params.v6",
scriptTemplate.toString()); scriptTemplate.toString());
assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString()); assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString());
} }
@ -1209,8 +1209,8 @@ public class QueryTranslatorTests extends ESTestCase {
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
assertNotNull(groupingContext); assertNotNull(groupingContext);
ScriptTemplate scriptTemplate = groupingContext.tail.script(); ScriptTemplate scriptTemplate = groupingContext.tail.script();
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(" + assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
"InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2,params.v3])", "params.v1)) ? params.v2 : params.v3",
scriptTemplate.toString()); scriptTemplate.toString());
assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString()); assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString());
} }