From eb1113ba78c7e4d33759e4911abe28ee1f199a1c Mon Sep 17 00:00:00 2001 From: Marios Trivyzas Date: Tue, 2 Oct 2018 02:06:51 +0300 Subject: [PATCH] SQL: Fix function resolution (#34137) Remove CamelCase to CAMEL_CASE conversion when resolving a function. Only convert user input to upper case and then try to match with aliases or primary names. Keep the internal conversion FunctionName to FUNCTION__NAME which provides flexibility when registering functions by their class name. Fixes: #34114 --- .../xpack/sql/analysis/analyzer/Analyzer.java | 10 +- .../expression/function/FunctionRegistry.java | 19 ++-- .../function/FunctionRegistryTests.java | 107 +++++++++++++----- .../xpack/sql/tree/NodeSubclassTests.java | 3 +- 4 files changed, 96 insertions(+), 43 deletions(-) diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java index 6de93853f02..c13160c9335 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java @@ -771,9 +771,9 @@ public class Analyzer extends RuleExecutor { return uf; } - String normalizedName = functionRegistry.concreteFunctionName(name); + String functionName = functionRegistry.resolveAlias(name); - List list = getList(seen, normalizedName); + List list = getList(seen, functionName); // first try to resolve from seen functions if (!list.isEmpty()) { for (Function seenFunction : list) { @@ -784,11 +784,11 @@ public class Analyzer extends RuleExecutor { } // not seen before, use the registry - if (!functionRegistry.functionExists(name)) { - return uf.missing(normalizedName, functionRegistry.listFunctions()); + if (!functionRegistry.functionExists(functionName)) { + return uf.missing(functionName, functionRegistry.listFunctions()); } // TODO: look into Generator for significant terms, etc.. - FunctionDefinition def = functionRegistry.resolveFunction(normalizedName); + FunctionDefinition def = functionRegistry.resolveFunction(functionName); Function f = uf.buildResolved(timeZone, def); list.add(f); diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java index b5d1b5c8167..caafd8294c6 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java @@ -90,6 +90,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.TimeZone; import java.util.function.BiFunction; @@ -211,21 +212,23 @@ public class FunctionRegistry { } } - public FunctionDefinition resolveFunction(String name) { - FunctionDefinition def = defs.get(normalize(name)); + public FunctionDefinition resolveFunction(String functionName) { + FunctionDefinition def = defs.get(functionName); if (def == null) { - throw new SqlIllegalArgumentException("Cannot find function {}; this should have been caught during analysis", name); + throw new SqlIllegalArgumentException( + "Cannot find function {}; this should have been caught during analysis", + functionName); } return def; } - public String concreteFunctionName(String alias) { - String normalized = normalize(alias); - return aliases.getOrDefault(normalized, normalized); + public String resolveAlias(String alias) { + String upperCase = alias.toUpperCase(Locale.ROOT); + return aliases.getOrDefault(upperCase, upperCase); } - public boolean functionExists(String name) { - return defs.containsKey(normalize(name)); + public boolean functionExists(String functionName) { + return defs.containsKey(functionName); } public Collection listFunctions() { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistryTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistryTests.java index bef75e3dc32..0ca75ee05d9 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistryTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistryTests.java @@ -6,31 +6,35 @@ package org.elasticsearch.xpack.sql.expression.function; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.sql.tree.Location; -import org.elasticsearch.xpack.sql.tree.LocationTests; -import org.elasticsearch.xpack.sql.tree.NodeInfo; -import org.elasticsearch.xpack.sql.type.DataType; +import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.parser.ParsingException; +import org.elasticsearch.xpack.sql.tree.Location; +import org.elasticsearch.xpack.sql.tree.LocationTests; +import org.elasticsearch.xpack.sql.tree.NodeInfo; +import org.elasticsearch.xpack.sql.type.DataType; + import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.TimeZone; +import static java.util.Collections.emptyList; import static org.elasticsearch.xpack.sql.expression.function.FunctionRegistry.def; import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.DISTINCT; import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.EXTRACT; import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.STANDARD; import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; -import static java.util.Collections.emptyList; public class FunctionRegistryTests extends ESTestCase { public void testNoArgFunction() { UnresolvedFunction ur = uf(STANDARD); - FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, Dummy::new))); + FunctionRegistry r = new FunctionRegistry(Collections.singletonList(def(DummyFunction.class, DummyFunction::new))); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); @@ -47,9 +51,10 @@ public class FunctionRegistryTests extends ESTestCase { public void testUnaryFunction() { UnresolvedFunction ur = uf(STANDARD, mock(Expression.class)); - FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e) -> { + FunctionRegistry r = new FunctionRegistry(Collections.singletonList( + def(DummyFunction.class, (Location l, Expression e) -> { assertSame(e, ur.children().get(0)); - return new Dummy(l); + return new DummyFunction(l); }))); FunctionDefinition def = r.resolveFunction(ur.name()); assertFalse(def.datetime()); @@ -74,11 +79,12 @@ public class FunctionRegistryTests extends ESTestCase { public void testUnaryDistinctAwareFunction() { boolean urIsDistinct = randomBoolean(); UnresolvedFunction ur = uf(urIsDistinct ? DISTINCT : STANDARD, mock(Expression.class)); - FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e, boolean distinct) -> { - assertEquals(urIsDistinct, distinct); - assertSame(e, ur.children().get(0)); - return new Dummy(l); - }))); + FunctionRegistry r = new FunctionRegistry(Collections.singletonList( + def(DummyFunction.class, (Location l, Expression e, boolean distinct) -> { + assertEquals(urIsDistinct, distinct); + assertSame(e, ur.children().get(0)); + return new DummyFunction(l); + }))); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); assertFalse(def.datetime()); @@ -98,11 +104,12 @@ public class FunctionRegistryTests extends ESTestCase { boolean urIsExtract = randomBoolean(); UnresolvedFunction ur = uf(urIsExtract ? EXTRACT : STANDARD, mock(Expression.class)); TimeZone providedTimeZone = randomTimeZone(); - FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e, TimeZone tz) -> { - assertEquals(providedTimeZone, tz); - assertSame(e, ur.children().get(0)); - return new Dummy(l); - }))); + FunctionRegistry r = new FunctionRegistry(Collections.singletonList( + def(DummyFunction.class, (Location l, Expression e, TimeZone tz) -> { + assertEquals(providedTimeZone, tz); + assertSame(e, ur.children().get(0)); + return new DummyFunction(l); + }))); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.location(), ur.buildResolved(providedTimeZone, def).location()); assertTrue(def.datetime()); @@ -125,11 +132,12 @@ public class FunctionRegistryTests extends ESTestCase { public void testBinaryFunction() { UnresolvedFunction ur = uf(STANDARD, mock(Expression.class), mock(Expression.class)); - FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression lhs, Expression rhs) -> { - assertSame(lhs, ur.children().get(0)); - assertSame(rhs, ur.children().get(1)); - return new Dummy(l); - }))); + FunctionRegistry r = new FunctionRegistry(Collections.singletonList( + def(DummyFunction.class, (Location l, Expression lhs, Expression rhs) -> { + assertSame(lhs, ur.children().get(0)); + assertSame(rhs, ur.children().get(1)); + return new DummyFunction(l); + }))); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); assertFalse(def.datetime()); @@ -156,17 +164,60 @@ public class FunctionRegistryTests extends ESTestCase { assertThat(e.getMessage(), endsWith("expects exactly two arguments")); } - private UnresolvedFunction uf(UnresolvedFunction.ResolutionType resolutionType, Expression... children) { - return new UnresolvedFunction(LocationTests.randomLocation(), "dummy", resolutionType, Arrays.asList(children)); + public void testFunctionResolving() { + UnresolvedFunction ur = uf(STANDARD, mock(Expression.class)); + FunctionRegistry r = new FunctionRegistry( + Collections.singletonList(def(DummyFunction.class, (Location l, Expression e) -> { + assertSame(e, ur.children().get(0)); + return new DummyFunction(l); + }, "DUMMY_FUNC"))); + + // Resolve by primary name + FunctionDefinition def = r.resolveFunction(r.resolveAlias("DuMMy_FuncTIon")); + assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); + + def = r.resolveFunction(r.resolveAlias("Dummy_Function")); + assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); + + def = r.resolveFunction(r.resolveAlias("dummy_function")); + assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); + + def = r.resolveFunction(r.resolveAlias("DUMMY_FUNCTION")); + assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); + + // Resolve by alias + def = r.resolveFunction(r.resolveAlias("DumMy_FunC")); + assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); + + def = r.resolveFunction(r.resolveAlias("dummy_func")); + assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); + + def = r.resolveFunction(r.resolveAlias("DUMMY_FUNC")); + assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); + + // Not resolved + SqlIllegalArgumentException e = expectThrows(SqlIllegalArgumentException.class, + () -> r.resolveFunction(r.resolveAlias("DummyFunction"))); + assertThat(e.getMessage(), + is("Cannot find function DUMMYFUNCTION; this should have been caught during analysis")); + + e = expectThrows(SqlIllegalArgumentException.class, + () -> r.resolveFunction(r.resolveAlias("dummyFunction"))); + assertThat(e.getMessage(), + is("Cannot find function DUMMYFUNCTION; this should have been caught during analysis")); } - public static class Dummy extends ScalarFunction { - public Dummy(Location location) { + private UnresolvedFunction uf(UnresolvedFunction.ResolutionType resolutionType, Expression... children) { + return new UnresolvedFunction(LocationTests.randomLocation(), "DUMMY_FUNCTION", resolutionType, Arrays.asList(children)); + } + + public static class DummyFunction extends ScalarFunction { + public DummyFunction(Location location) { super(location, emptyList()); } @Override - protected NodeInfo info() { + protected NodeInfo info() { return NodeInfo.create(this); } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/tree/NodeSubclassTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/tree/NodeSubclassTests.java index fdc4f5db990..90fd7392960 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/tree/NodeSubclassTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/tree/NodeSubclassTests.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.sql.tree; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; - import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.PathUtils; import org.elasticsearch.test.ESTestCase; @@ -418,7 +417,7 @@ public class NodeSubclassTests> extends ESTestCas } } else if (toBuildClass == ChildrenAreAProperty.class) { /* - * While any subclass of Dummy will do here we want to prevent + * While any subclass of DummyFunction will do here we want to prevent * stack overflow so we use the one without children. */ if (argClass == Dummy.class) {