Previously, `CASE` and `IIF` when translated to painless scripts (used in GROUP BY, HAVING, WHERE) a custom `caseFunction` registered in the `InternalSqlScriptUtils` was used. This function received and array of arbitrary length: ```[condition1, result1, condition2, result2, ... elseResult]``` Painless doesn't know of the context and therefore is evaluating all conditions and results before invoking the `caseFunction` on them. As a consequence, erroneous result expressions (i.e. division by 0) where always evaluated despite of the guarding condition. Replace the `caseFunction` with painless `<cond> ? <res1> : <res2>` expressions to properly guard the result expressions and only evaluate the one for which its guarding condition evaluates to true (or of course the elseResult). As a bonus, this approach includes performance benefits since we avoid unnecessary evaluations of both conditions and result expressions. Fixes: #49672 (cherry picked from commit 9584b345d89f797bfb658212b928b9812804f02f)
This commit is contained in:
parent
e1cab4feb4
commit
fdac9e99fa
|
@ -197,6 +197,24 @@ END as lang_skills FROM test_emp GROUP BY lang_skills ORDER BY 2;
|
||||||
10 | zero
|
10 | zero
|
||||||
;
|
;
|
||||||
|
|
||||||
|
caseGroupByProtectedDivisionByZero
|
||||||
|
schema::x:i
|
||||||
|
SELECT CASE WHEN languages = 1 THEN NULL ELSE ( salary / (languages - 1) ) END AS x FROM test_emp GROUP BY 1 ORDER BY 1 LIMIT 10;
|
||||||
|
|
||||||
|
x
|
||||||
|
---------------
|
||||||
|
null
|
||||||
|
6331
|
||||||
|
6486
|
||||||
|
7780
|
||||||
|
7974
|
||||||
|
8068
|
||||||
|
8489
|
||||||
|
8935
|
||||||
|
9043
|
||||||
|
9071
|
||||||
|
;
|
||||||
|
|
||||||
caseGroupByAndHaving
|
caseGroupByAndHaving
|
||||||
schema::count:l|gender:s|languages:byte
|
schema::count:l|gender:s|languages:byte
|
||||||
SELECT count(*) AS count, gender, languages FROM test_emp
|
SELECT count(*) AS count, gender, languages FROM test_emp
|
||||||
|
@ -353,6 +371,28 @@ IIF(NVL(languages, 0) = 0, 'zero',
|
||||||
10 |zero
|
10 |zero
|
||||||
;
|
;
|
||||||
|
|
||||||
|
iifGroupByProtectedDivisionByZero
|
||||||
|
schema::count:l|x:i
|
||||||
|
SELECT count(*) AS count,
|
||||||
|
IIF(languages - 1 = 0, 0,
|
||||||
|
IIF(languages - 1 = 1, (salary / 10000) / (languages - 1),
|
||||||
|
IIF(languages - 1 = 2, (salary / 10000) / languages,
|
||||||
|
IIF(languages - 1 = 3, (salary / 10000) / (languages + 1),
|
||||||
|
(salary / 10000) / (languages + 2))))) as x FROM test_emp GROUP BY x ORDER BY 2;
|
||||||
|
|
||||||
|
count | x
|
||||||
|
---------------+---------------
|
||||||
|
10 |null
|
||||||
|
50 |0
|
||||||
|
14 |1
|
||||||
|
8 |2
|
||||||
|
4 |3
|
||||||
|
5 |4
|
||||||
|
6 |5
|
||||||
|
2 |6
|
||||||
|
1 |7
|
||||||
|
;
|
||||||
|
|
||||||
iifGroupByAndHaving
|
iifGroupByAndHaving
|
||||||
schema::count:l|gender:s|languages:byte
|
schema::count:l|gender:s|languages:byte
|
||||||
SELECT count(*) AS count, gender, languages FROM test_emp
|
SELECT count(*) AS count, gender, languages FROM test_emp
|
||||||
|
|
|
@ -38,7 +38,6 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFu
|
||||||
import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape;
|
import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape;
|
||||||
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime;
|
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime;
|
||||||
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth;
|
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth;
|
||||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.CaseProcessor;
|
|
||||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
|
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
|
||||||
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor;
|
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor;
|
||||||
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation;
|
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation;
|
||||||
|
@ -67,10 +66,6 @@ public class InternalSqlScriptUtils extends InternalQlScriptUtils {
|
||||||
//
|
//
|
||||||
// Conditional
|
// Conditional
|
||||||
//
|
//
|
||||||
public static Object caseFunction(List<Object> expressions) {
|
|
||||||
return CaseProcessor.apply(expressions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Object coalesce(List<Object> expressions) {
|
public static Object coalesce(List<Object> expressions) {
|
||||||
return ConditionalOperation.COALESCE.apply(expressions);
|
return ConditionalOperation.COALESCE.apply(expressions);
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.xpack.ql.expression.Expressions;
|
||||||
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
|
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
|
||||||
import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder;
|
import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder;
|
||||||
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
|
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
|
||||||
|
import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
|
||||||
import org.elasticsearch.xpack.ql.tree.NodeInfo;
|
import org.elasticsearch.xpack.ql.tree.NodeInfo;
|
||||||
import org.elasticsearch.xpack.ql.tree.Source;
|
import org.elasticsearch.xpack.ql.tree.Source;
|
||||||
import org.elasticsearch.xpack.ql.type.DataType;
|
import org.elasticsearch.xpack.ql.type.DataType;
|
||||||
|
@ -19,7 +20,6 @@ import org.elasticsearch.xpack.sql.type.SqlDataTypes;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.StringJoiner;
|
|
||||||
|
|
||||||
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
|
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
|
||||||
import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder;
|
import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder;
|
||||||
|
@ -161,14 +161,26 @@ public class Case extends ConditionalFunction {
|
||||||
}
|
}
|
||||||
templates.add(asScript(elseResult));
|
templates.add(asScript(elseResult));
|
||||||
|
|
||||||
StringJoiner template = new StringJoiner(",", "{sql}.caseFunction([", "])");
|
// Use painless ?: expressions to prevent evaluation of return expression
|
||||||
|
// if the condition which guards it evaluates to false (e.g. division by 0)
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
ParamsBuilder params = paramsBuilder();
|
ParamsBuilder params = paramsBuilder();
|
||||||
|
for (int i = 0; i < templates.size(); i++) {
|
||||||
for (ScriptTemplate scriptTemplate : templates) {
|
ScriptTemplate scriptTemplate = templates.get(i);
|
||||||
template.add(scriptTemplate.template());
|
if (i < templates.size() - 1) {
|
||||||
|
if (i % 2 == 0) {
|
||||||
|
// painless ? : operator expects primitive boolean, thus we use nullSafeFilter
|
||||||
|
// to convert object Boolean to primitive boolean (null => false)
|
||||||
|
sb.append(Scripts.nullSafeFilter(scriptTemplate).template()).append(" ? ");
|
||||||
|
} else {
|
||||||
|
sb.append(scriptTemplate.template()).append(" : ");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sb.append(scriptTemplate.template());
|
||||||
|
}
|
||||||
params.script(scriptTemplate.params());
|
params.script(scriptTemplate.params());
|
||||||
}
|
}
|
||||||
|
|
||||||
return new ScriptTemplate(formatTemplate(template.toString()), params.build(), dataType());
|
return new ScriptTemplate(formatTemplate(sb.toString()), params.build(), dataType());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,18 +50,6 @@ public class CaseProcessor implements Processor {
|
||||||
return processors.get(processors.size() - 1).process(input);
|
return processors.get(processors.size() - 1).process(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Object apply(List<Object> objects) {
|
|
||||||
// Check every condition in sequence and if it evaluates to TRUE,
|
|
||||||
// evaluate and return the result associated with that condition.
|
|
||||||
for (int i = 0; i < objects.size() - 2; i += 2) {
|
|
||||||
if (objects.get(i) == Boolean.TRUE) {
|
|
||||||
return objects.get(i + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// resort to default value
|
|
||||||
return objects.get(objects.size() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) {
|
if (this == o) {
|
||||||
|
|
|
@ -71,7 +71,6 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
|
||||||
#
|
#
|
||||||
# Conditional
|
# Conditional
|
||||||
#
|
#
|
||||||
def caseFunction(java.util.List)
|
|
||||||
def coalesce(java.util.List)
|
def coalesce(java.util.List)
|
||||||
def greatest(java.util.List)
|
def greatest(java.util.List)
|
||||||
def least(java.util.List)
|
def least(java.util.List)
|
||||||
|
|
|
@ -674,9 +674,9 @@ public class QueryTranslatorTests extends ESTestCase {
|
||||||
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
||||||
assertNotNull(groupingContext);
|
assertNotNull(groupingContext);
|
||||||
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
||||||
assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue("
|
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(doc,params.v0)," +
|
||||||
+ "doc,params.v0),params.v1),params.v2,InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" +
|
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils." +
|
||||||
"doc,params.v3),params.v4),params.v5,params.v6])",
|
"docValue(doc,params.v3),params.v4)) ? params.v5 : params.v6",
|
||||||
scriptTemplate.toString());
|
scriptTemplate.toString());
|
||||||
assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]",
|
assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]",
|
||||||
scriptTemplate.params().toString());
|
scriptTemplate.params().toString());
|
||||||
|
@ -1194,9 +1194,9 @@ public class QueryTranslatorTests extends ESTestCase {
|
||||||
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
||||||
assertNotNull(groupingContext);
|
assertNotNull(groupingContext);
|
||||||
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
||||||
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" + ""
|
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
|
||||||
+ "doc,params.v0),params.v1),params.v2,InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v3)," +
|
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" +
|
||||||
"params.v4),params.v5,params.v6])",
|
"doc,params.v3),params.v4)) ? params.v5 : params.v6",
|
||||||
scriptTemplate.toString());
|
scriptTemplate.toString());
|
||||||
assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString());
|
assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString());
|
||||||
}
|
}
|
||||||
|
@ -1209,8 +1209,8 @@ public class QueryTranslatorTests extends ESTestCase {
|
||||||
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
|
||||||
assertNotNull(groupingContext);
|
assertNotNull(groupingContext);
|
||||||
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
ScriptTemplate scriptTemplate = groupingContext.tail.script();
|
||||||
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(" +
|
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
|
||||||
"InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2,params.v3])",
|
"params.v1)) ? params.v2 : params.v3",
|
||||||
scriptTemplate.toString());
|
scriptTemplate.toString());
|
||||||
assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString());
|
assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString());
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue