From fdac9e99fa8bc0334768d412db8b1178094a53bf Mon Sep 17 00:00:00 2001 From: Marios Trivyzas Date: Thu, 28 May 2020 11:30:14 +0200 Subject: [PATCH] SQL: Fix unecessary evaluation for CASE/IIF (#57159) (#57262) Previously, `CASE` and `IIF` when translated to painless scripts (used in GROUP BY, HAVING, WHERE) a custom `caseFunction` registered in the `InternalSqlScriptUtils` was used. This function received and array of arbitrary length: ```[condition1, result1, condition2, result2, ... elseResult]``` Painless doesn't know of the context and therefore is evaluating all conditions and results before invoking the `caseFunction` on them. As a consequence, erroneous result expressions (i.e. division by 0) where always evaluated despite of the guarding condition. Replace the `caseFunction` with painless ` ? : ` expressions to properly guard the result expressions and only evaluate the one for which its guarding condition evaluates to true (or of course the elseResult). As a bonus, this approach includes performance benefits since we avoid unnecessary evaluations of both conditions and result expressions. Fixes: #49672 (cherry picked from commit 9584b345d89f797bfb658212b928b9812804f02f) --- .../src/main/resources/conditionals.csv-spec | 40 +++++++++++++++++++ .../whitelist/InternalSqlScriptUtils.java | 5 --- .../predicate/conditional/Case.java | 24 ++++++++--- .../predicate/conditional/CaseProcessor.java | 12 ------ .../xpack/sql/plugin/sql_whitelist.txt | 1 - .../sql/planner/QueryTranslatorTests.java | 16 ++++---- 6 files changed, 66 insertions(+), 32 deletions(-) diff --git a/x-pack/plugin/sql/qa/server/src/main/resources/conditionals.csv-spec b/x-pack/plugin/sql/qa/server/src/main/resources/conditionals.csv-spec index e6453ad1420..bf72f9958cb 100644 --- a/x-pack/plugin/sql/qa/server/src/main/resources/conditionals.csv-spec +++ b/x-pack/plugin/sql/qa/server/src/main/resources/conditionals.csv-spec @@ -197,6 +197,24 @@ END as lang_skills FROM test_emp GROUP BY lang_skills ORDER BY 2; 10 | zero ; +caseGroupByProtectedDivisionByZero +schema::x:i +SELECT CASE WHEN languages = 1 THEN NULL ELSE ( salary / (languages - 1) ) END AS x FROM test_emp GROUP BY 1 ORDER BY 1 LIMIT 10; + + x +--------------- +null +6331 +6486 +7780 +7974 +8068 +8489 +8935 +9043 +9071 +; + caseGroupByAndHaving schema::count:l|gender:s|languages:byte SELECT count(*) AS count, gender, languages FROM test_emp @@ -353,6 +371,28 @@ IIF(NVL(languages, 0) = 0, 'zero', 10 |zero ; +iifGroupByProtectedDivisionByZero +schema::count:l|x:i +SELECT count(*) AS count, +IIF(languages - 1 = 0, 0, + IIF(languages - 1 = 1, (salary / 10000) / (languages - 1), + IIF(languages - 1 = 2, (salary / 10000) / languages, + IIF(languages - 1 = 3, (salary / 10000) / (languages + 1), + (salary / 10000) / (languages + 2))))) as x FROM test_emp GROUP BY x ORDER BY 2; + + count | x +---------------+--------------- +10 |null +50 |0 +14 |1 +8 |2 +4 |3 +5 |4 +6 |5 +2 |6 +1 |7 +; + iifGroupByAndHaving schema::count:l|gender:s|languages:byte SELECT count(*) AS count, gender, languages FROM test_emp diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java index d7911f85511..ec07becd609 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java @@ -38,7 +38,6 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFu import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape; import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime; import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth; -import org.elasticsearch.xpack.sql.expression.predicate.conditional.CaseProcessor; import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation; import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor; import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation; @@ -67,10 +66,6 @@ public class InternalSqlScriptUtils extends InternalQlScriptUtils { // // Conditional // - public static Object caseFunction(List expressions) { - return CaseProcessor.apply(expressions); - } - public static Object coalesce(List expressions) { return ConditionalOperation.COALESCE.apply(expressions); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java index 362d14b2df0..9a8dfcf87b1 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder; import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate; +import org.elasticsearch.xpack.ql.expression.gen.script.Scripts; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; @@ -19,7 +20,6 @@ import org.elasticsearch.xpack.sql.type.SqlDataTypes; import java.util.ArrayList; import java.util.List; -import java.util.StringJoiner; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder; @@ -161,14 +161,26 @@ public class Case extends ConditionalFunction { } templates.add(asScript(elseResult)); - StringJoiner template = new StringJoiner(",", "{sql}.caseFunction([", "])"); + // Use painless ?: expressions to prevent evaluation of return expression + // if the condition which guards it evaluates to false (e.g. division by 0) + StringBuilder sb = new StringBuilder(); ParamsBuilder params = paramsBuilder(); - - for (ScriptTemplate scriptTemplate : templates) { - template.add(scriptTemplate.template()); + for (int i = 0; i < templates.size(); i++) { + ScriptTemplate scriptTemplate = templates.get(i); + if (i < templates.size() - 1) { + if (i % 2 == 0) { + // painless ? : operator expects primitive boolean, thus we use nullSafeFilter + // to convert object Boolean to primitive boolean (null => false) + sb.append(Scripts.nullSafeFilter(scriptTemplate).template()).append(" ? "); + } else { + sb.append(scriptTemplate.template()).append(" : "); + } + } else { + sb.append(scriptTemplate.template()); + } params.script(scriptTemplate.params()); } - return new ScriptTemplate(formatTemplate(template.toString()), params.build(), dataType()); + return new ScriptTemplate(formatTemplate(sb.toString()), params.build(), dataType()); } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseProcessor.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseProcessor.java index dc4d083d53c..d18b4b2fc24 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseProcessor.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseProcessor.java @@ -50,18 +50,6 @@ public class CaseProcessor implements Processor { return processors.get(processors.size() - 1).process(input); } - public static Object apply(List objects) { - // Check every condition in sequence and if it evaluates to TRUE, - // evaluate and return the result associated with that condition. - for (int i = 0; i < objects.size() - 2; i += 2) { - if (objects.get(i) == Boolean.TRUE) { - return objects.get(i + 1); - } - } - // resort to default value - return objects.get(objects.size() - 1); - } - @Override public boolean equals(Object o) { if (this == o) { diff --git a/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt b/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt index 57059886601..66b5010918e 100644 --- a/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt +++ b/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt @@ -71,7 +71,6 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS # # Conditional # - def caseFunction(java.util.List) def coalesce(java.util.List) def greatest(java.util.List) def least(java.util.List) diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java index 1970d28b426..3fbed84b718 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java @@ -674,9 +674,9 @@ public class QueryTranslatorTests extends ESTestCase { GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); - assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" - + "doc,params.v0),params.v1),params.v2,InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" + - "doc,params.v3),params.v4),params.v5,params.v6])", + assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(doc,params.v0)," + + "params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils." + + "docValue(doc,params.v3),params.v4)) ? params.v5 : params.v6", scriptTemplate.toString()); assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]", scriptTemplate.params().toString()); @@ -1194,9 +1194,9 @@ public class QueryTranslatorTests extends ESTestCase { GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); - assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" + "" - + "doc,params.v0),params.v1),params.v2,InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v3)," + - "params.v4),params.v5,params.v6])", + assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," + + "params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" + + "doc,params.v3),params.v4)) ? params.v5 : params.v6", scriptTemplate.toString()); assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString()); } @@ -1209,8 +1209,8 @@ public class QueryTranslatorTests extends ESTestCase { GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); - assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(" + - "InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2,params.v3])", + assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," + + "params.v1)) ? params.v2 : params.v3", scriptTemplate.toString()); assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString()); }