diff --git a/sql/server/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java b/sql/server/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java index f73cf446efe..447365dc261 100644 --- a/sql/server/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java +++ b/sql/server/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java @@ -12,7 +12,6 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer; import org.elasticsearch.xpack.sql.analysis.analyzer.PreAnalyzer; import org.elasticsearch.xpack.sql.analysis.index.IndexResolver; import org.elasticsearch.xpack.sql.execution.search.SourceGenerator; -import org.elasticsearch.xpack.sql.expression.function.DefaultFunctionRegistry; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.optimizer.Optimizer; import org.elasticsearch.xpack.sql.parser.SqlParser; @@ -38,7 +37,7 @@ public class PlanExecutor { public PlanExecutor(Client client, IndexResolver indexResolver) { this.client = client; this.indexResolver = indexResolver; - this.functionRegistry = new DefaultFunctionRegistry(); + this.functionRegistry = new FunctionRegistry(); this.preAnalyzer = new PreAnalyzer(); this.optimizer = new Optimizer(); diff --git a/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/AbstractFunctionRegistry.java b/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/AbstractFunctionRegistry.java deleted file mode 100644 index 1464226ab6e..00000000000 --- a/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/AbstractFunctionRegistry.java +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression.function; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.parser.ParsingException; -import org.elasticsearch.xpack.sql.tree.Location; -import org.elasticsearch.xpack.sql.util.StringUtils; -import org.joda.time.DateTimeZone; - -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiFunction; -import java.util.regex.Pattern; - -import static java.util.Collections.emptyList; -import static java.util.Collections.unmodifiableList; -import static java.util.stream.Collectors.toList; - -abstract class AbstractFunctionRegistry implements FunctionRegistry { - private final Map defs = new LinkedHashMap<>(); - private final Map aliases; - - protected AbstractFunctionRegistry(List functions) { - this.aliases = new HashMap<>(); - for (FunctionDefinition f : functions) { - defs.put(f.name(), f); - for (String alias : f.aliases()) { - Object old = aliases.put(alias, f.name()); - if (old != null) { - throw new IllegalArgumentException("alias [" + alias + "] is used by [" + old + "] and [" + f.name() + "]"); - } - defs.put(alias, f); - } - } - } - - @Override - public Function resolveFunction(UnresolvedFunction ur, DateTimeZone timeZone) { - FunctionDefinition def = defs.get(normalize(ur.name())); - if (def == null) { - throw new SqlIllegalArgumentException("Cannot find function %s; this should have been caught during analysis", ur.name()); - } - return def.builder().apply(ur, timeZone); - } - - @Override - public String concreteFunctionName(String alias) { - String normalized = normalize(alias); - return aliases.getOrDefault(normalized, normalized); - } - - @Override - public boolean functionExists(String name) { - return defs.containsKey(normalize(name)); - } - - @Override - public Collection listFunctions() { - return defs.entrySet().stream() - .map(e -> new FunctionDefinition(e.getKey(), emptyList(), e.getValue().clazz(), e.getValue().builder())) - .collect(toList()); - } - - @Override - public Collection listFunctions(String pattern) { - Pattern p = Strings.hasText(pattern) ? StringUtils.likeRegex(normalize(pattern)) : null; - return defs.entrySet().stream() - .filter(e -> p == null || p.matcher(e.getKey()).matches()) - .map(e -> new FunctionDefinition(e.getKey(), emptyList(), e.getValue().clazz(), e.getValue().builder())) - .collect(toList()); - } - - /** - * Build a {@linkplain FunctionDefinition} for a no-argument function that - * is not aware of time zone and does not support {@code DISTINCT}. - */ - protected static FunctionDefinition def(Class function, - java.util.function.Function ctorRef, String... aliases) { - FunctionBuilder builder = (location, children, distinct, tz) -> { - if (false == children.isEmpty()) { - throw new IllegalArgumentException("expects only a single argument"); - } - if (distinct) { - throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.apply(location); - }; - return def(function, builder, aliases); - } - - /** - * Build a {@linkplain FunctionDefinition} for a unary function that is not - * aware of time zone and does not support {@code DISTINCT}. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - protected static FunctionDefinition def(Class function, - BiFunction ctorRef, String... aliases) { - FunctionBuilder builder = (location, children, distinct, tz) -> { - if (children.size() != 1) { - throw new IllegalArgumentException("expects only a single argument"); - } - if (distinct) { - throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.apply(location, children.get(0)); - }; - return def(function, builder, aliases); - } - - /** - * Build a {@linkplain FunctionDefinition} for a unary function that is not - * aware of time zone but does support {@code DISTINCT}. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - protected static FunctionDefinition def(Class function, - DistinctAwareUnaryFunctionBuilder ctorRef, String... aliases) { - FunctionBuilder builder = (location, children, distinct, tz) -> { - if (children.size() != 1) { - throw new IllegalArgumentException("expects only a single argument"); - } - return ctorRef.build(location, children.get(0), distinct); - }; - return def(function, builder, aliases); - } - protected interface DistinctAwareUnaryFunctionBuilder { - T build(Location location, Expression target, boolean distinct); - } - - /** - * Build a {@linkplain FunctionDefinition} for a unary function that is - * aware of time zone and does not support {@code DISTINCT}. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - protected static FunctionDefinition def(Class function, - TimeZoneAwareUnaryFunctionBuilder ctorRef, String... aliases) { - FunctionBuilder builder = (location, children, distinct, tz) -> { - if (children.size() != 1) { - throw new IllegalArgumentException("expects only a single argument"); - } - if (distinct) { - throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.build(location, children.get(0), tz); - }; - return def(function, builder, aliases); - } - protected interface TimeZoneAwareUnaryFunctionBuilder { - T build(Location location, Expression target, DateTimeZone tz); - } - - /** - * Build a {@linkplain FunctionDefinition} for a binary function that is - * not aware of time zone and does not support {@code DISTINCT}. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - protected static FunctionDefinition def(Class function, - BinaryFunctionBuilder ctorRef, String... aliases) { - FunctionBuilder builder = (location, children, distinct, tz) -> { - if (children.size() != 2) { - throw new IllegalArgumentException("expects only a single argument"); - } - if (distinct) { - throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.build(location, children.get(0), children.get(1)); - }; - return def(function, builder, aliases); - } - protected interface BinaryFunctionBuilder { - T build(Location location, Expression lhs, Expression rhs); - } - - private static FunctionDefinition def(Class function, FunctionBuilder builder, String... aliases) { - String primaryName = normalize(function.getSimpleName()); - BiFunction realBuilder = (uf, tz) -> { - try { - return builder.build(uf.location(), uf.children(), uf.distinct(), tz); - } catch (IllegalArgumentException e) { - throw new ParsingException("error builder [" + primaryName + "]: " + e.getMessage(), e, - uf.location().getLineNumber(), uf.location().getColumnNumber()); - } - }; - return new FunctionDefinition(primaryName, unmodifiableList(Arrays.asList(aliases)), function, realBuilder); - } - private interface FunctionBuilder { - Function build(Location location, List children, boolean distinct, DateTimeZone tz); - } - - protected static String normalize(String name) { - // translate CamelCase to camel_case - return StringUtils.camelCaseToUnderscore(name); - } -} diff --git a/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/DefaultFunctionRegistry.java b/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/DefaultFunctionRegistry.java deleted file mode 100644 index 4ef83f8ae3a..00000000000 --- a/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/DefaultFunctionRegistry.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression.function; - -import org.elasticsearch.xpack.sql.expression.function.Score; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Correlation; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Count; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Covariance; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Kurtosis; -import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixCount; -import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixMean; -import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixVariance; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Max; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Mean; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Min; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile; -import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Skewness; -import org.elasticsearch.xpack.sql.expression.function.aggregate.StddevPop; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum; -import org.elasticsearch.xpack.sql.expression.function.aggregate.SumOfSquares; -import org.elasticsearch.xpack.sql.expression.function.aggregate.VarPop; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfMonth; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfWeek; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfYear; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.HourOfDay; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MinuteOfDay; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MinuteOfHour; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MonthOfYear; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.SecondOfMinute; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.Year; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.ACos; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.ASin; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.ATan; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Abs; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Cbrt; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Ceil; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Cos; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Cosh; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Degrees; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.E; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Exp; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Expm1; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Floor; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Log; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Log10; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Pi; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Radians; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Round; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Sin; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Sinh; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Sqrt; -import org.elasticsearch.xpack.sql.expression.function.scalar.math.Tan; -import java.util.Arrays; -import java.util.List; - -import static java.util.Collections.unmodifiableList; - -public class DefaultFunctionRegistry extends AbstractFunctionRegistry { - private static final List FUNCTIONS = unmodifiableList(Arrays.asList( - // Aggregate functions - def(Avg.class, Avg::new), - def(Count.class, Count::new), - def(Max.class, Max::new), - def(Min.class, Min::new), - def(Sum.class, Sum::new), - // Statistics - def(Mean.class, Mean::new), - def(StddevPop.class, StddevPop::new), - def(VarPop.class, VarPop::new), - def(Percentile.class, Percentile::new), - def(PercentileRank.class, PercentileRank::new), - def(SumOfSquares.class, SumOfSquares::new), - // Matrix aggs - def(MatrixCount.class, MatrixCount::new), - def(MatrixMean.class, MatrixMean::new), - def(MatrixVariance.class, MatrixVariance::new), - def(Skewness.class, Skewness::new), - def(Kurtosis.class, Kurtosis::new), - def(Covariance.class, Covariance::new), - def(Correlation.class, Correlation::new), - // Scalar functions - // Date - def(DayOfMonth.class, DayOfMonth::new, "DAY", "DOM"), - def(DayOfWeek.class, DayOfWeek::new, "DOW"), - def(DayOfYear.class, DayOfYear::new, "DOY"), - def(HourOfDay.class, HourOfDay::new, "HOUR"), - def(MinuteOfDay.class, MinuteOfDay::new), - def(MinuteOfHour.class, MinuteOfHour::new, "MINUTE"), - def(SecondOfMinute.class, SecondOfMinute::new, "SECOND"), - def(MonthOfYear.class, MonthOfYear::new, "MONTH"), - def(Year.class, Year::new), - // Math - def(Abs.class, Abs::new), - def(ACos.class, ACos::new), - def(ASin.class, ASin::new), - def(ATan.class, ATan::new), - def(Cbrt.class, Cbrt::new), - def(Ceil.class, Ceil::new), - def(Cos.class, Cos::new), - def(Cosh.class, Cosh::new), - def(Degrees.class, Degrees::new), - def(E.class, E::new), - def(Exp.class, Exp::new), - def(Expm1.class, Expm1::new), - def(Floor.class, Floor::new), - def(Log.class, Log::new), - def(Log10.class, Log10::new), - def(Pi.class, Pi::new), - def(Radians.class, Radians::new), - def(Round.class, Round::new), - def(Sin.class, Sin::new), - def(Sinh.class, Sinh::new), - def(Sqrt.class, Sqrt::new), - def(Tan.class, Tan::new), - // Special - def(Score.class, Score::new))); - - public DefaultFunctionRegistry() { - super(FUNCTIONS); - } -} diff --git a/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java b/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java index a41a3b5ce5c..99fbd9fa45c 100644 --- a/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java +++ b/sql/server/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java @@ -5,20 +5,314 @@ */ package org.elasticsearch.xpack.sql.expression.function; +import org.elasticsearch.common.Strings; + +import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; +import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.function.Score; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Correlation; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Covariance; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Kurtosis; +import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixCount; +import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixMean; +import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixVariance; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Max; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Mean; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Min; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile; +import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Skewness; +import org.elasticsearch.xpack.sql.expression.function.aggregate.StddevPop; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.sql.expression.function.aggregate.SumOfSquares; +import org.elasticsearch.xpack.sql.expression.function.aggregate.VarPop; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfMonth; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfWeek; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfYear; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.HourOfDay; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MinuteOfDay; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MinuteOfHour; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MonthOfYear; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.SecondOfMinute; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.Year; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.ACos; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.ASin; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.ATan; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Abs; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Cbrt; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Ceil; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Cos; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Cosh; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Degrees; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.E; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Exp; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Expm1; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Floor; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Log; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Log10; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Pi; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Radians; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Round; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Sin; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Sinh; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Sqrt; +import org.elasticsearch.xpack.sql.expression.function.scalar.math.Tan; +import org.elasticsearch.xpack.sql.parser.ParsingException; +import org.elasticsearch.xpack.sql.tree.Location; +import org.elasticsearch.xpack.sql.util.StringUtils; import org.joda.time.DateTimeZone; +import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.regex.Pattern; -public interface FunctionRegistry { +import static java.util.Collections.emptyList; +import static java.util.Collections.unmodifiableList; +import static java.util.stream.Collectors.toList; - Function resolveFunction(UnresolvedFunction ur, DateTimeZone timeZone); +public class FunctionRegistry { + private static final List DEFAULT_FUNCTIONS = unmodifiableList(Arrays.asList( + // Aggregate functions + def(Avg.class, Avg::new), + def(Count.class, Count::new), + def(Max.class, Max::new), + def(Min.class, Min::new), + def(Sum.class, Sum::new), + // Statistics + def(Mean.class, Mean::new), + def(StddevPop.class, StddevPop::new), + def(VarPop.class, VarPop::new), + def(Percentile.class, Percentile::new), + def(PercentileRank.class, PercentileRank::new), + def(SumOfSquares.class, SumOfSquares::new), + // Matrix aggs + def(MatrixCount.class, MatrixCount::new), + def(MatrixMean.class, MatrixMean::new), + def(MatrixVariance.class, MatrixVariance::new), + def(Skewness.class, Skewness::new), + def(Kurtosis.class, Kurtosis::new), + def(Covariance.class, Covariance::new), + def(Correlation.class, Correlation::new), + // Scalar functions + // Date + def(DayOfMonth.class, DayOfMonth::new, "DAY", "DOM"), + def(DayOfWeek.class, DayOfWeek::new, "DOW"), + def(DayOfYear.class, DayOfYear::new, "DOY"), + def(HourOfDay.class, HourOfDay::new, "HOUR"), + def(MinuteOfDay.class, MinuteOfDay::new), + def(MinuteOfHour.class, MinuteOfHour::new, "MINUTE"), + def(SecondOfMinute.class, SecondOfMinute::new, "SECOND"), + def(MonthOfYear.class, MonthOfYear::new, "MONTH"), + def(Year.class, Year::new), + // Math + def(Abs.class, Abs::new), + def(ACos.class, ACos::new), + def(ASin.class, ASin::new), + def(ATan.class, ATan::new), + def(Cbrt.class, Cbrt::new), + def(Ceil.class, Ceil::new), + def(Cos.class, Cos::new), + def(Cosh.class, Cosh::new), + def(Degrees.class, Degrees::new), + def(E.class, E::new), + def(Exp.class, Exp::new), + def(Expm1.class, Expm1::new), + def(Floor.class, Floor::new), + def(Log.class, Log::new), + def(Log10.class, Log10::new), + def(Pi.class, Pi::new), + def(Radians.class, Radians::new), + def(Round.class, Round::new), + def(Sin.class, Sin::new), + def(Sinh.class, Sinh::new), + def(Sqrt.class, Sqrt::new), + def(Tan.class, Tan::new), + // Special + def(Score.class, Score::new))); - String concreteFunctionName(String alias); + private final Map defs = new LinkedHashMap<>(); + private final Map aliases; - boolean functionExists(String name); + /** + * Constructor to build with the default list of functions. + */ + public FunctionRegistry() { + this(DEFAULT_FUNCTIONS); + } - Collection listFunctions(); + /** + * Constructor specifying alternate functions for testing. + */ + FunctionRegistry(List functions) { + this.aliases = new HashMap<>(); + for (FunctionDefinition f : functions) { + defs.put(f.name(), f); + for (String alias : f.aliases()) { + Object old = aliases.put(alias, f.name()); + if (old != null) { + throw new IllegalArgumentException("alias [" + alias + "] is used by [" + old + "] and [" + f.name() + "]"); + } + defs.put(alias, f); + } + } + } - Collection listFunctions(String pattern); + public Function resolveFunction(UnresolvedFunction ur, DateTimeZone timeZone) { + FunctionDefinition def = defs.get(normalize(ur.name())); + if (def == null) { + throw new SqlIllegalArgumentException("Cannot find function %s; this should have been caught during analysis", ur.name()); + } + return def.builder().apply(ur, timeZone); + } + public String concreteFunctionName(String alias) { + String normalized = normalize(alias); + return aliases.getOrDefault(normalized, normalized); + } + + public boolean functionExists(String name) { + return defs.containsKey(normalize(name)); + } + + public Collection listFunctions() { + return defs.entrySet().stream() + .map(e -> new FunctionDefinition(e.getKey(), emptyList(), e.getValue().clazz(), e.getValue().builder())) + .collect(toList()); + } + + public Collection listFunctions(String pattern) { + Pattern p = Strings.hasText(pattern) ? StringUtils.likeRegex(normalize(pattern)) : null; + return defs.entrySet().stream() + .filter(e -> p == null || p.matcher(e.getKey()).matches()) + .map(e -> new FunctionDefinition(e.getKey(), emptyList(), e.getValue().clazz(), e.getValue().builder())) + .collect(toList()); + } + + /** + * Build a {@linkplain FunctionDefinition} for a no-argument function that + * is not aware of time zone and does not support {@code DISTINCT}. + */ + static FunctionDefinition def(Class function, + java.util.function.Function ctorRef, String... aliases) { + FunctionBuilder builder = (location, children, distinct, tz) -> { + if (false == children.isEmpty()) { + throw new IllegalArgumentException("expects no arguments"); + } + if (distinct) { + throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); + } + return ctorRef.apply(location); + }; + return def(function, builder, aliases); + } + + /** + * Build a {@linkplain FunctionDefinition} for a unary function that is not + * aware of time zone and does not support {@code DISTINCT}. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + static FunctionDefinition def(Class function, + BiFunction ctorRef, String... aliases) { + FunctionBuilder builder = (location, children, distinct, tz) -> { + if (children.size() != 1) { + throw new IllegalArgumentException("expects exactly one argument"); + } + if (distinct) { + throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); + } + return ctorRef.apply(location, children.get(0)); + }; + return def(function, builder, aliases); + } + + /** + * Build a {@linkplain FunctionDefinition} for a unary function that is not + * aware of time zone but does support {@code DISTINCT}. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + static FunctionDefinition def(Class function, + DistinctAwareUnaryFunctionBuilder ctorRef, String... aliases) { + FunctionBuilder builder = (location, children, distinct, tz) -> { + if (children.size() != 1) { + throw new IllegalArgumentException("expects exactly one argument"); + } + return ctorRef.build(location, children.get(0), distinct); + }; + return def(function, builder, aliases); + } + interface DistinctAwareUnaryFunctionBuilder { + T build(Location location, Expression target, boolean distinct); + } + + /** + * Build a {@linkplain FunctionDefinition} for a unary function that is + * aware of time zone and does not support {@code DISTINCT}. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + static FunctionDefinition def(Class function, + TimeZoneAwareUnaryFunctionBuilder ctorRef, String... aliases) { + FunctionBuilder builder = (location, children, distinct, tz) -> { + if (children.size() != 1) { + throw new IllegalArgumentException("expects exactly one argument"); + } + if (distinct) { + throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); + } + return ctorRef.build(location, children.get(0), tz); + }; + return def(function, builder, aliases); + } + interface TimeZoneAwareUnaryFunctionBuilder { + T build(Location location, Expression target, DateTimeZone tz); + } + + /** + * Build a {@linkplain FunctionDefinition} for a binary function that is + * not aware of time zone and does not support {@code DISTINCT}. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + static FunctionDefinition def(Class function, + BinaryFunctionBuilder ctorRef, String... aliases) { + FunctionBuilder builder = (location, children, distinct, tz) -> { + if (children.size() != 2) { + throw new IllegalArgumentException("expects exactly two arguments"); + } + if (distinct) { + throw new IllegalArgumentException("does not support DISTINCT yet it was specified"); + } + return ctorRef.build(location, children.get(0), children.get(1)); + }; + return def(function, builder, aliases); + } + interface BinaryFunctionBuilder { + T build(Location location, Expression lhs, Expression rhs); + } + + private static FunctionDefinition def(Class function, FunctionBuilder builder, String... aliases) { + String primaryName = normalize(function.getSimpleName()); + BiFunction realBuilder = (uf, tz) -> { + try { + return builder.build(uf.location(), uf.children(), uf.distinct(), tz); + } catch (IllegalArgumentException e) { + throw new ParsingException("error building [" + primaryName + "]: " + e.getMessage(), e, + uf.location().getLineNumber(), uf.location().getColumnNumber()); + } + }; + return new FunctionDefinition(primaryName, unmodifiableList(Arrays.asList(aliases)), function, realBuilder); + } + private interface FunctionBuilder { + Function build(Location location, List children, boolean distinct, DateTimeZone tz); + } + + private static String normalize(String name) { + // translate CamelCase to camel_case + return StringUtils.camelCaseToUnderscore(name); + } } diff --git a/sql/server/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java b/sql/server/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java index fb992b4d9c5..693fd25e632 100644 --- a/sql/server/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java +++ b/sql/server/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.PreAnalyzer.PreAnalysis; import org.elasticsearch.xpack.sql.analysis.index.GetIndexResult; import org.elasticsearch.xpack.sql.analysis.index.IndexResolver; import org.elasticsearch.xpack.sql.analysis.index.MappingException; -import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.optimizer.Optimizer; import org.elasticsearch.xpack.sql.parser.SqlParser; diff --git a/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java b/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java index 64ac2216fd6..f211975103b 100644 --- a/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java +++ b/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.expression.NamedExpression; -import org.elasticsearch.xpack.sql.expression.function.DefaultFunctionRegistry; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.parser.SqlParser; import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan; @@ -45,7 +44,7 @@ public class FieldAttributeTests extends ESTestCase { public FieldAttributeTests() { parser = new SqlParser(DateTimeZone.UTC); - functionRegistry = new DefaultFunctionRegistry(); + functionRegistry = new FunctionRegistry(); Map mapping = TypesTests.loadMapping("mapping-multi-field-variation.json"); @@ -142,4 +141,4 @@ public class FieldAttributeTests extends ESTestCase { assertThat(names, not(hasItem("unsupported"))); assertThat(names, hasItems("bool", "text", "keyword", "int")); } -} \ No newline at end of file +} diff --git a/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java b/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java index 0083f89d8ff..b8a47c96b35 100644 --- a/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java +++ b/sql/server/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java @@ -9,7 +9,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.sql.analysis.AnalysisException; import org.elasticsearch.xpack.sql.analysis.index.EsIndex; import org.elasticsearch.xpack.sql.analysis.index.GetIndexResult; -import org.elasticsearch.xpack.sql.expression.function.DefaultFunctionRegistry; +import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.parser.SqlParser; import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.type.TypesTests; @@ -27,7 +27,7 @@ public class VerifierErrorMessagesTests extends ESTestCase { } private String verify(GetIndexResult getIndexResult, String sql) { - Analyzer analyzer = new Analyzer(new DefaultFunctionRegistry(), getIndexResult, DateTimeZone.UTC); + Analyzer analyzer = new Analyzer(new FunctionRegistry(), getIndexResult, DateTimeZone.UTC); AnalysisException e = expectThrows(AnalysisException.class, () -> analyzer.analyze(parser.createStatement(sql), true)); assertTrue(e.getMessage().startsWith("Found ")); String header = "Found 1 problem(s)\nline "; @@ -117,4 +117,4 @@ public class VerifierErrorMessagesTests extends ESTestCase { assertEquals("1:8: Cannot use field [unsupported], its type [ip_range] is unsupported", verify("SELECT unsupported FROM test")); } -} \ No newline at end of file +} diff --git a/sql/server/src/test/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistryTests.java b/sql/server/src/test/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistryTests.java new file mode 100644 index 00000000000..dbdc2d3c3bd --- /dev/null +++ b/sql/server/src/test/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistryTests.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.sql.expression.function; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.sql.tree.Location; +import org.elasticsearch.xpack.sql.tree.LocationTests; +import org.elasticsearch.xpack.sql.type.DataType; +import org.joda.time.DateTimeZone; +import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; +import org.elasticsearch.xpack.sql.expression.function.scalar.processor.definition.ProcessorDefinition; +import org.elasticsearch.xpack.sql.expression.function.scalar.script.ScriptTemplate; +import org.elasticsearch.xpack.sql.parser.ParsingException; +import java.util.Arrays; + +import static org.elasticsearch.xpack.sql.expression.function.FunctionRegistry.def; +import static org.hamcrest.Matchers.endsWith; +import static org.mockito.Mockito.mock; +import static java.util.Collections.emptyList; + +public class FunctionRegistryTests extends ESTestCase { + public void testNoArgFunction() { + UnresolvedFunction ur = uf(false); + FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, Dummy::new))); + assertEquals(ur.location(), r.resolveFunction(ur, randomDateTimeZone()).location()); + + // Distinct isn't supported + ParsingException e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(true), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("does not support DISTINCT yet it was specified")); + + // Any children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false, mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects no arguments")); + } + + public void testUnaryFunction() { + UnresolvedFunction ur = uf(false, mock(Expression.class)); + FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e) -> { + assertSame(e, ur.children().get(0)); + return new Dummy(l); + }))); + assertEquals(ur.location(), r.resolveFunction(ur, randomDateTimeZone()).location()); + + // Distinct isn't supported + ParsingException e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(true, mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("does not support DISTINCT yet it was specified")); + + // No children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly one argument")); + + // Multiple children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false, mock(Expression.class), mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly one argument")); + } + + public void testUnaryDistinctAwareFunction() { + UnresolvedFunction ur = uf(randomBoolean(), mock(Expression.class)); + FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e, boolean distinct) -> { + assertEquals(ur.distinct(), distinct); + assertSame(e, ur.children().get(0)); + return new Dummy(l); + }))); + assertEquals(ur.location(), r.resolveFunction(ur, randomDateTimeZone()).location()); + + // No children aren't supported + ParsingException e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly one argument")); + + // Multiple children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false, mock(Expression.class), mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly one argument")); + } + + public void testTimeZoneAwareFunction() { + UnresolvedFunction ur = uf(false, mock(Expression.class)); + DateTimeZone providedTimeZone = randomDateTimeZone(); + FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression e, DateTimeZone tz) -> { + assertEquals(providedTimeZone, tz); + assertSame(e, ur.children().get(0)); + return new Dummy(l); + }))); + assertEquals(ur.location(), r.resolveFunction(ur, providedTimeZone).location()); + + // Distinct isn't supported + ParsingException e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(true, mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("does not support DISTINCT yet it was specified")); + + // No children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly one argument")); + + // Multiple children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false, mock(Expression.class), mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly one argument")); + } + + public void testBinaryFunction() { + UnresolvedFunction ur = uf(false, mock(Expression.class), mock(Expression.class)); + FunctionRegistry r = new FunctionRegistry(Arrays.asList(def(Dummy.class, (Location l, Expression lhs, Expression rhs) -> { + assertSame(lhs, ur.children().get(0)); + assertSame(rhs, ur.children().get(1)); + return new Dummy(l); + }))); + assertEquals(ur.location(), r.resolveFunction(ur, randomDateTimeZone()).location()); + + // Distinct isn't supported + ParsingException e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(true, mock(Expression.class), mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("does not support DISTINCT yet it was specified")); + + // No children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly two arguments")); + + // One child isn't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false, mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly two arguments")); + + // Many children aren't supported + e = expectThrows(ParsingException.class, () -> + r.resolveFunction(uf(false, mock(Expression.class), mock(Expression.class), mock(Expression.class)), randomDateTimeZone())); + assertThat(e.getMessage(), endsWith("expects exactly two arguments")); + } + + private UnresolvedFunction uf(boolean distinct, Expression... children) { + return new UnresolvedFunction(LocationTests.randomLocation(), "dummy", distinct, Arrays.asList(children)); + } + + private static class Dummy extends ScalarFunction { + private Dummy(Location location) { + super(location, emptyList()); + } + + + @Override + public DataType dataType() { + return null; + } + + @Override + public ScriptTemplate asScript() { + return null; + } + + @Override + protected ProcessorDefinition makeProcessorDefinition() { + return null; + } + } +} diff --git a/sql/server/src/test/java/org/elasticsearch/xpack/sql/planner/VerifierErrorMessagesTests.java b/sql/server/src/test/java/org/elasticsearch/xpack/sql/planner/VerifierErrorMessagesTests.java index 62ad8171497..e0921f27fea 100644 --- a/sql/server/src/test/java/org/elasticsearch/xpack/sql/planner/VerifierErrorMessagesTests.java +++ b/sql/server/src/test/java/org/elasticsearch/xpack/sql/planner/VerifierErrorMessagesTests.java @@ -9,7 +9,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer; import org.elasticsearch.xpack.sql.analysis.index.EsIndex; import org.elasticsearch.xpack.sql.analysis.index.GetIndexResult; -import org.elasticsearch.xpack.sql.expression.function.DefaultFunctionRegistry; +import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.optimizer.Optimizer; import org.elasticsearch.xpack.sql.parser.SqlParser; import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan; @@ -34,7 +34,7 @@ public class VerifierErrorMessagesTests extends ESTestCase { mapping.put("keyword", DataTypes.KEYWORD); EsIndex test = new EsIndex("test", mapping); GetIndexResult getIndexResult = GetIndexResult.valid(test); - Analyzer analyzer = new Analyzer(new DefaultFunctionRegistry(), getIndexResult, DateTimeZone.UTC); + Analyzer analyzer = new Analyzer(new FunctionRegistry(), getIndexResult, DateTimeZone.UTC); LogicalPlan plan = optimizer.optimize(analyzer.analyze(parser.createStatement(sql), true)); PlanningException e = expectThrows(PlanningException.class, () -> planner.mapPlan(plan, true)); assertTrue(e.getMessage().startsWith("Found "));