diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java index 59ec2c38d00..7536612a67d 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java @@ -58,6 +58,7 @@ public class Case extends ConditionalFunction { for (IfConditional conditional : conditions) { dataType = DataTypeConversion.commonType(dataType, conditional.dataType()); } + dataType = DataTypeConversion.commonType(dataType, elseResult.dataType()); } } return dataType; diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java index 807be397f91..b4de311c920 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java @@ -6,19 +6,23 @@ package org.elasticsearch.xpack.sql.expression.predicate.conditional; import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.function.scalar.FunctionTestUtils; import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.sql.tree.AbstractNodeTestCase; import org.elasticsearch.xpack.sql.tree.NodeSubclassTests; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.tree.SourceTests; +import org.elasticsearch.xpack.sql.type.DataType; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.sql.expression.function.scalar.FunctionTestUtils.randomIntLiteral; import static org.elasticsearch.xpack.sql.expression.function.scalar.FunctionTestUtils.randomStringLiteral; +import static org.elasticsearch.xpack.sql.tree.Source.EMPTY; import static org.elasticsearch.xpack.sql.tree.SourceTests.randomSource; /** @@ -77,6 +81,42 @@ public class CaseTests extends AbstractNodeTestCase { assertEquals(new Case(c.source(), newChildren), c.replaceChildren(newChildren)); } + public void testDataTypes() { + // CASE WHEN 1 = 1 THEN NULL + // ELSE 'default' + // END + Case c = new Case(EMPTY, Arrays.asList( + new IfConditional(EMPTY, new Equals(EMPTY, Literal.of(EMPTY, 1), Literal.of(EMPTY, 1)), Literal.NULL), + Literal.of(EMPTY, "default"))); + assertEquals(DataType.KEYWORD, c.dataType()); + + // CASE WHEN 1 = 1 THEN 'foo' + // ELSE NULL + // END + c = new Case(EMPTY, Arrays.asList( + new IfConditional(EMPTY, new Equals(EMPTY, Literal.of(EMPTY, 1), Literal.of(EMPTY, 1)), Literal.of(EMPTY, "foo")), + Literal.NULL)); + assertEquals(DataType.KEYWORD, c.dataType()); + + // CASE WHEN 1 = 1 THEN NULL + // ELSE NULL + // END + c = new Case(EMPTY, Arrays.asList( + new IfConditional(EMPTY, new Equals(EMPTY, Literal.of(EMPTY, 1), Literal.of(EMPTY, 1)), Literal.NULL), + Literal.NULL)); + assertEquals(DataType.NULL, c.dataType()); + + // CASE WHEN 1 = 1 THEN NULL + // WHEN 2 = 2 THEN 'foo' + // ELSE NULL + // END + c = new Case(EMPTY, Arrays.asList( + new IfConditional(EMPTY, new Equals(EMPTY, Literal.of(EMPTY, 1), Literal.of(EMPTY, 1)), Literal.NULL), + new IfConditional(EMPTY, new Equals(EMPTY, Literal.of(EMPTY, 2), Literal.of(EMPTY, 2)), Literal.of(EMPTY, "foo")), + Literal.NULL)); + assertEquals(DataType.KEYWORD, c.dataType()); + } + private List mutateChildren(Case c) { boolean removeConditional = randomBoolean(); List expressions = new ArrayList<>(c.children().size());