SQL: Fix function resolution (#34137)

Remove CamelCase to CAMEL_CASE conversion when resolving
a function. Only convert user input to upper case and then
try to match with aliases or primary names.

Keep the internal conversion FunctionName to FUNCTION__NAME
which provides flexibility when registering functions by their class
name.

Fixes: #34114
This commit is contained in:
Marios Trivyzas 2018-10-02 02:06:51 +03:00 committed by GitHub
parent ad3218b4ab
commit eb1113ba78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 43 deletions

View File

@ -771,9 +771,9 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
return uf; return uf;
} }
String normalizedName = functionRegistry.concreteFunctionName(name); String functionName = functionRegistry.resolveAlias(name);
List<Function> list = getList(seen, normalizedName); List<Function> list = getList(seen, functionName);
// first try to resolve from seen functions // first try to resolve from seen functions
if (!list.isEmpty()) { if (!list.isEmpty()) {
for (Function seenFunction : list) { for (Function seenFunction : list) {
@ -784,11 +784,11 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
} }
// not seen before, use the registry // not seen before, use the registry
if (!functionRegistry.functionExists(name)) { if (!functionRegistry.functionExists(functionName)) {
return uf.missing(normalizedName, functionRegistry.listFunctions()); return uf.missing(functionName, functionRegistry.listFunctions());
} }
// TODO: look into Generator for significant terms, etc.. // TODO: look into Generator for significant terms, etc..
FunctionDefinition def = functionRegistry.resolveFunction(normalizedName); FunctionDefinition def = functionRegistry.resolveFunction(functionName);
Function f = uf.buildResolved(timeZone, def); Function f = uf.buildResolved(timeZone, def);
list.add(f); list.add(f);

View File

@ -90,6 +90,7 @@ import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.TimeZone; import java.util.TimeZone;
import java.util.function.BiFunction; import java.util.function.BiFunction;
@ -211,21 +212,23 @@ public class FunctionRegistry {
} }
} }
public FunctionDefinition resolveFunction(String name) { public FunctionDefinition resolveFunction(String functionName) {
FunctionDefinition def = defs.get(normalize(name)); FunctionDefinition def = defs.get(functionName);
if (def == null) { if (def == null) {
throw new SqlIllegalArgumentException("Cannot find function {}; this should have been caught during analysis", name); throw new SqlIllegalArgumentException(
"Cannot find function {}; this should have been caught during analysis",
functionName);
} }
return def; return def;
} }
public String concreteFunctionName(String alias) { public String resolveAlias(String alias) {
String normalized = normalize(alias); String upperCase = alias.toUpperCase(Locale.ROOT);
return aliases.getOrDefault(normalized, normalized); return aliases.getOrDefault(upperCase, upperCase);
} }
public boolean functionExists(String name) { public boolean functionExists(String functionName) {
return defs.containsKey(normalize(name)); return defs.containsKey(functionName);
} }
public Collection<FunctionDefinition> listFunctions() { public Collection<FunctionDefinition> listFunctions() {

View File

@ -6,31 +6,35 @@
package org.elasticsearch.xpack.sql.expression.function; package org.elasticsearch.xpack.sql.expression.function;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.sql.tree.Location; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.tree.LocationTests;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.parser.ParsingException; import org.elasticsearch.xpack.sql.parser.ParsingException;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.LocationTests;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.TimeZone; import java.util.TimeZone;
import static java.util.Collections.emptyList;
import static org.elasticsearch.xpack.sql.expression.function.FunctionRegistry.def; import static org.elasticsearch.xpack.sql.expression.function.FunctionRegistry.def;
import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.DISTINCT; import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.DISTINCT;
import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.EXTRACT; import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.EXTRACT;
import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.STANDARD; import static org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction.ResolutionType.STANDARD;
import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static java.util.Collections.emptyList;
public class FunctionRegistryTests extends ESTestCase { public class FunctionRegistryTests extends ESTestCase {
public void testNoArgFunction() { public void testNoArgFunction() {
UnresolvedFunction ur = uf(STANDARD); UnresolvedFunction ur = uf(STANDARD);
FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, Dummy::new))); FunctionRegistry r = new FunctionRegistry(Collections.singletonList(def(DummyFunction.class, DummyFunction::new)));
FunctionDefinition def = r.resolveFunction(ur.name()); FunctionDefinition def = r.resolveFunction(ur.name());
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
@ -47,9 +51,10 @@ public class FunctionRegistryTests extends ESTestCase {
public void testUnaryFunction() { public void testUnaryFunction() {
UnresolvedFunction ur = uf(STANDARD, mock(Expression.class)); UnresolvedFunction ur = uf(STANDARD, mock(Expression.class));
FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e) -> { FunctionRegistry r = new FunctionRegistry(Collections.singletonList(
def(DummyFunction.class, (Location l, Expression e) -> {
assertSame(e, ur.children().get(0)); assertSame(e, ur.children().get(0));
return new Dummy(l); return new DummyFunction(l);
}))); })));
FunctionDefinition def = r.resolveFunction(ur.name()); FunctionDefinition def = r.resolveFunction(ur.name());
assertFalse(def.datetime()); assertFalse(def.datetime());
@ -74,10 +79,11 @@ public class FunctionRegistryTests extends ESTestCase {
public void testUnaryDistinctAwareFunction() { public void testUnaryDistinctAwareFunction() {
boolean urIsDistinct = randomBoolean(); boolean urIsDistinct = randomBoolean();
UnresolvedFunction ur = uf(urIsDistinct ? DISTINCT : STANDARD, mock(Expression.class)); UnresolvedFunction ur = uf(urIsDistinct ? DISTINCT : STANDARD, mock(Expression.class));
FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e, boolean distinct) -> { FunctionRegistry r = new FunctionRegistry(Collections.singletonList(
def(DummyFunction.class, (Location l, Expression e, boolean distinct) -> {
assertEquals(urIsDistinct, distinct); assertEquals(urIsDistinct, distinct);
assertSame(e, ur.children().get(0)); assertSame(e, ur.children().get(0));
return new Dummy(l); return new DummyFunction(l);
}))); })));
FunctionDefinition def = r.resolveFunction(ur.name()); FunctionDefinition def = r.resolveFunction(ur.name());
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
@ -98,10 +104,11 @@ public class FunctionRegistryTests extends ESTestCase {
boolean urIsExtract = randomBoolean(); boolean urIsExtract = randomBoolean();
UnresolvedFunction ur = uf(urIsExtract ? EXTRACT : STANDARD, mock(Expression.class)); UnresolvedFunction ur = uf(urIsExtract ? EXTRACT : STANDARD, mock(Expression.class));
TimeZone providedTimeZone = randomTimeZone(); TimeZone providedTimeZone = randomTimeZone();
FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e, TimeZone tz) -> { FunctionRegistry r = new FunctionRegistry(Collections.singletonList(
def(DummyFunction.class, (Location l, Expression e, TimeZone tz) -> {
assertEquals(providedTimeZone, tz); assertEquals(providedTimeZone, tz);
assertSame(e, ur.children().get(0)); assertSame(e, ur.children().get(0));
return new Dummy(l); return new DummyFunction(l);
}))); })));
FunctionDefinition def = r.resolveFunction(ur.name()); FunctionDefinition def = r.resolveFunction(ur.name());
assertEquals(ur.location(), ur.buildResolved(providedTimeZone, def).location()); assertEquals(ur.location(), ur.buildResolved(providedTimeZone, def).location());
@ -125,10 +132,11 @@ public class FunctionRegistryTests extends ESTestCase {
public void testBinaryFunction() { public void testBinaryFunction() {
UnresolvedFunction ur = uf(STANDARD, mock(Expression.class), mock(Expression.class)); UnresolvedFunction ur = uf(STANDARD, mock(Expression.class), mock(Expression.class));
FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression lhs, Expression rhs) -> { FunctionRegistry r = new FunctionRegistry(Collections.singletonList(
def(DummyFunction.class, (Location l, Expression lhs, Expression rhs) -> {
assertSame(lhs, ur.children().get(0)); assertSame(lhs, ur.children().get(0));
assertSame(rhs, ur.children().get(1)); assertSame(rhs, ur.children().get(1));
return new Dummy(l); return new DummyFunction(l);
}))); })));
FunctionDefinition def = r.resolveFunction(ur.name()); FunctionDefinition def = r.resolveFunction(ur.name());
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location()); assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
@ -156,17 +164,60 @@ public class FunctionRegistryTests extends ESTestCase {
assertThat(e.getMessage(), endsWith("expects exactly two arguments")); assertThat(e.getMessage(), endsWith("expects exactly two arguments"));
} }
private UnresolvedFunction uf(UnresolvedFunction.ResolutionType resolutionType, Expression... children) { public void testFunctionResolving() {
return new UnresolvedFunction(LocationTests.randomLocation(), "dummy", resolutionType, Arrays.asList(children)); UnresolvedFunction ur = uf(STANDARD, mock(Expression.class));
FunctionRegistry r = new FunctionRegistry(
Collections.singletonList(def(DummyFunction.class, (Location l, Expression e) -> {
assertSame(e, ur.children().get(0));
return new DummyFunction(l);
}, "DUMMY_FUNC")));
// Resolve by primary name
FunctionDefinition def = r.resolveFunction(r.resolveAlias("DuMMy_FuncTIon"));
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
def = r.resolveFunction(r.resolveAlias("Dummy_Function"));
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
def = r.resolveFunction(r.resolveAlias("dummy_function"));
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
def = r.resolveFunction(r.resolveAlias("DUMMY_FUNCTION"));
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
// Resolve by alias
def = r.resolveFunction(r.resolveAlias("DumMy_FunC"));
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
def = r.resolveFunction(r.resolveAlias("dummy_func"));
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
def = r.resolveFunction(r.resolveAlias("DUMMY_FUNC"));
assertEquals(ur.location(), ur.buildResolved(randomTimeZone(), def).location());
// Not resolved
SqlIllegalArgumentException e = expectThrows(SqlIllegalArgumentException.class,
() -> r.resolveFunction(r.resolveAlias("DummyFunction")));
assertThat(e.getMessage(),
is("Cannot find function DUMMYFUNCTION; this should have been caught during analysis"));
e = expectThrows(SqlIllegalArgumentException.class,
() -> r.resolveFunction(r.resolveAlias("dummyFunction")));
assertThat(e.getMessage(),
is("Cannot find function DUMMYFUNCTION; this should have been caught during analysis"));
} }
public static class Dummy extends ScalarFunction { private UnresolvedFunction uf(UnresolvedFunction.ResolutionType resolutionType, Expression... children) {
public Dummy(Location location) { return new UnresolvedFunction(LocationTests.randomLocation(), "DUMMY_FUNCTION", resolutionType, Arrays.asList(children));
}
public static class DummyFunction extends ScalarFunction {
public DummyFunction(Location location) {
super(location, emptyList()); super(location, emptyList());
} }
@Override @Override
protected NodeInfo<Dummy> info() { protected NodeInfo<DummyFunction> info() {
return NodeInfo.create(this); return NodeInfo.create(this);
} }

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.sql.tree; package org.elasticsearch.xpack.sql.tree;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.PathUtils; import org.elasticsearch.common.io.PathUtils;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
@ -418,7 +417,7 @@ public class NodeSubclassTests<T extends B, B extends Node<B>> extends ESTestCas
} }
} else if (toBuildClass == ChildrenAreAProperty.class) { } else if (toBuildClass == ChildrenAreAProperty.class) {
/* /*
* While any subclass of Dummy will do here we want to prevent * While any subclass of DummyFunction will do here we want to prevent
* stack overflow so we use the one without children. * stack overflow so we use the one without children.
*/ */
if (argClass == Dummy.class) { if (argClass == Dummy.class) {