diff --git a/x-pack/plugin/sql/qa/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/CliExplainIT.java b/x-pack/plugin/sql/qa/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/CliExplainIT.java index af52e50348c..c7069f18e02 100644 --- a/x-pack/plugin/sql/qa/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/CliExplainIT.java +++ b/x-pack/plugin/sql/qa/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/CliExplainIT.java @@ -19,19 +19,19 @@ public class CliExplainIT extends CliIntegrationTestCase { assertThat(command("EXPLAIN (PLAN PARSED) SELECT * FROM test"), containsString("plan")); assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("With[{}]")); - assertThat(readLine(), startsWith("\\_Project[[?*]]")); + assertThat(readLine(), startsWith("\\_Project[[?* AS ?]]")); assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]")); assertEquals("", readLine()); assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test"), containsString("plan")); assertThat(readLine(), startsWith("----------")); - assertThat(readLine(), startsWith("Project[[test_field{f}#")); + assertThat(readLine(), startsWith("Project[[test.test_field{f}#")); assertThat(readLine(), startsWith("\\_EsRelation[test][test_field{f}#")); assertEquals("", readLine()); assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT * FROM test"), containsString("plan")); assertThat(readLine(), startsWith("----------")); - assertThat(readLine(), startsWith("Project[[test_field{f}#")); + assertThat(readLine(), startsWith("Project[[test.test_field{f}#")); assertThat(readLine(), startsWith("\\_EsRelation[test][test_field{f}#")); assertEquals("", readLine()); @@ -63,23 +63,23 @@ public class CliExplainIT extends CliIntegrationTestCase { assertThat(command("EXPLAIN (PLAN PARSED) SELECT * FROM test WHERE i = 2"), containsString("plan")); assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("With[{}]")); - assertThat(readLine(), startsWith("\\_Project[[?*]]")); - assertThat(readLine(), startsWith(" \\_Filter[Equals[?i,2")); + assertThat(readLine(), startsWith("\\_Project[[?* AS ?]]")); + assertThat(readLine(), startsWith(" \\_Filter[?i == 2[INTEGER]]")); assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]")); assertEquals("", readLine()); assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test WHERE i = 2"), containsString("plan")); assertThat(readLine(), startsWith("----------")); - assertThat(readLine(), startsWith("Project[[i{f}#")); - assertThat(readLine(), startsWith("\\_Filter[Equals[i")); + assertThat(readLine(), startsWith("Project[[test.i{f}#")); + assertThat(readLine(), startsWith("\\_Filter[test.i{f}#")); assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#")); assertEquals("", readLine()); assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT * FROM test WHERE i = 2"), containsString("plan")); assertThat(readLine(), startsWith("----------")); - assertThat(readLine(), startsWith("Project[[i{f}#")); - assertThat(readLine(), startsWith("\\_Filter[Equals[i")); + assertThat(readLine(), startsWith("Project[[test.i{f}#")); + assertThat(readLine(), startsWith("\\_Filter[test.i{f}")); assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#")); assertEquals("", readLine()); @@ -119,20 +119,20 @@ public class CliExplainIT extends CliIntegrationTestCase { assertThat(command("EXPLAIN (PLAN PARSED) SELECT COUNT(*) FROM test"), containsString("plan")); assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("With[{}]")); - assertThat(readLine(), startsWith("\\_Project[[?COUNT[?*]]]")); + assertThat(readLine(), startsWith("\\_Project[[?COUNT[?*] AS ?]]")); assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]")); assertEquals("", readLine()); assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT COUNT(*) FROM test"), containsString("plan")); assertThat(readLine(), startsWith("----------")); - assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1")); + assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)")); assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#")); assertEquals("", readLine()); assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT COUNT(*) FROM test"), containsString("plan")); assertThat(readLine(), startsWith("----------")); - assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1")); + assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)")); assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#")); assertEquals("", readLine()); diff --git a/x-pack/plugin/sql/qa/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java b/x-pack/plugin/sql/qa/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java index d5cd695dd27..9457bfdf929 100644 --- a/x-pack/plugin/sql/qa/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java +++ b/x-pack/plugin/sql/qa/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java @@ -422,36 +422,36 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err boolean columnar = randomBoolean(); String expected = ""; if (columnar) { - expected = "{\n" + - " \"columns\" : [\n" + - " {\n" + - " \"name\" : \"test1\",\n" + - " \"type\" : \"text\"\n" + - " }\n" + - " ],\n" + - " \"values\" : [\n" + - " [\n" + - " \"test1\",\n" + - " \"test2\"\n" + - " ]\n" + - " ]\n" + + expected = "{\n" + + " \"columns\" : [\n" + + " {\n" + + " \"name\" : \"test1\",\n" + + " \"type\" : \"text\"\n" + + " }\n" + + " ],\n" + + " \"values\" : [\n" + + " [\n" + + " \"test1\",\n" + + " \"test2\"\n" + + " ]\n" + + " ]\n" + "}\n"; } else { - expected = "{\n" + - " \"columns\" : [\n" + - " {\n" + - " \"name\" : \"test1\",\n" + - " \"type\" : \"text\"\n" + - " }\n" + - " ],\n" + - " \"rows\" : [\n" + - " [\n" + - " \"test1\"\n" + - " ],\n" + - " [\n" + - " \"test2\"\n" + - " ]\n" + - " ]\n" + + expected = "{\n" + + " \"columns\" : [\n" + + " {\n" + + " \"name\" : \"test1\",\n" + + " \"type\" : \"text\"\n" + + " }\n" + + " ],\n" + + " \"rows\" : [\n" + + " [\n" + + " \"test1\"\n" + + " ],\n" + + " [\n" + + " \"test2\"\n" + + " ]\n" + + " ]\n" + "}\n"; } executeAndAssertPrettyPrinting(expected, "true", columnar); @@ -644,14 +644,14 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err Map aggregations2 = (Map) groupby.get("aggregations"); assertEquals(2, aggregations2.size()); - List aggKeys = new ArrayList<>(2); + List aggKeys = new ArrayList<>(2); String aggFilterKey = null; for (Map.Entry entry : aggregations2.entrySet()) { String key = entry.getKey(); if (key.startsWith("having")) { aggFilterKey = key; } else { - aggKeys.add(Integer.valueOf(key)); + aggKeys.add(key); @SuppressWarnings("unchecked") Map aggr = (Map) entry.getValue(); assertEquals(1, aggr.size()); diff --git a/x-pack/plugin/sql/qa/src/main/resources/agg.csv-spec b/x-pack/plugin/sql/qa/src/main/resources/agg.csv-spec index 182b6c2c76f..ed1ae60b14c 100644 --- a/x-pack/plugin/sql/qa/src/main/resources/agg.csv-spec +++ b/x-pack/plugin/sql/qa/src/main/resources/agg.csv-spec @@ -604,9 +604,31 @@ SELECT COUNT(ALL first_name) all_names, COUNT(*) c FROM test_emp; all_names | c ---------------+--------------- -90 |100 +90 |100 ; +countDistinctAndLiteral +schema::ln:l|ccc:l +SELECT COUNT(last_name) ln, COUNT(*) ccc FROM test_emp GROUP BY gender HAVING ln>5 AND ccc>5; + + ln | ccc +---------------+------------- +10 |10 +33 |33 +57 |57 +; + +countSmallCountTypesWithHaving +schema::ln:l|dln:l|fn:l|dfn:l|ccc:l +SELECT COUNT(last_name) ln, COUNT(distinct last_name) dln, COUNT(first_name) fn, COUNT(distinct first_name) dfn, COUNT(*) ccc FROM test_emp GROUP BY gender HAVING dln>5 AND ln>32 AND dfn>1 AND fn>1 AND ccc>5; + + ln | dln | fn | dfn | ccc +---------------+-------------+---------------+------------+------------- +33 |32 |32 |32 |33 +57 |54 |48 |48 |57 +; + + countAllCountTypesWithHaving schema::ln:l|dln:l|fn:l|dfn:l|ccc:l SELECT COUNT(last_name) ln, COUNT(distinct last_name) dln, COUNT(first_name) fn, COUNT(distinct first_name) dfn, COUNT(*) ccc FROM test_emp GROUP BY gender HAVING dln>5 AND ln>32 AND dfn>1 AND fn>1 AND ccc>5; diff --git a/x-pack/plugin/sql/qa/src/main/resources/math.csv-spec b/x-pack/plugin/sql/qa/src/main/resources/math.csv-spec index 372614dcb00..a333de34987 100644 --- a/x-pack/plugin/sql/qa/src/main/resources/math.csv-spec +++ b/x-pack/plugin/sql/qa/src/main/resources/math.csv-spec @@ -64,8 +64,8 @@ SELECT TRUNC(salary, 2) TRUNCATED, salary FROM test_emp GROUP BY TRUNCATED, sala truncateWithAsciiAndOrderBy SELECT TRUNCATE(ASCII(LEFT(first_name,1)), -1) AS initial, first_name, ASCII(LEFT(first_name, 1)) FROM test_emp ORDER BY ASCII(LEFT(first_name, 1)) DESC LIMIT 15; - initial | first_name |ASCII(LEFT(first_name,1)) ----------------+---------------+------------------------- + initial | first_name |ASCII(LEFT(first_name, 1)) +---------------+---------------+-------------------------- 90 |Zvonko |90 90 |Zhongwei |90 80 |Yongqiao |89 diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java index c790626c5fb..030d059ccfb 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.expression.Foldables; import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.Order; +import org.elasticsearch.xpack.sql.expression.ReferenceAttribute; import org.elasticsearch.xpack.sql.expression.SubQueryExpression; import org.elasticsearch.xpack.sql.expression.UnresolvedAlias; import org.elasticsearch.xpack.sql.expression.UnresolvedAttribute; @@ -28,10 +29,8 @@ import org.elasticsearch.xpack.sql.expression.function.FunctionDefinition; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.expression.function.Functions; import org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Count; import org.elasticsearch.xpack.sql.expression.function.scalar.Cast; import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.ArithmeticOperation; -import org.elasticsearch.xpack.sql.expression.predicate.regex.RegexMatch; import org.elasticsearch.xpack.sql.plan.TableIdentifier; import org.elasticsearch.xpack.sql.plan.logical.Aggregate; import org.elasticsearch.xpack.sql.plan.logical.EsRelation; @@ -66,7 +65,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.stream.Collectors; +import java.util.stream.Stream; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; @@ -329,12 +328,13 @@ public class Analyzer extends RuleExecutor { return new Aggregate(a.source(), a.child(), a.groupings(), expandProjections(a.aggregates(), a.child())); } - // if the grouping is unresolved but the aggs are, use the latter to resolve the former + // if the grouping is unresolved but the aggs are, use the former to resolve the latter // solves the case of queries declaring an alias in SELECT and referring to it in GROUP BY + // e.g. SELECT x AS a ... GROUP BY a if (!a.expressionsResolved() && Resolvables.resolved(a.aggregates())) { List groupings = a.groupings(); List newGroupings = new ArrayList<>(); - AttributeMap resolved = Expressions.asAttributeMap(a.aggregates()); + AttributeMap resolved = Expressions.aliases(a.aggregates()); boolean changed = false; for (Expression grouping : groupings) { if (grouping instanceof UnresolvedAttribute) { @@ -363,9 +363,10 @@ public class Analyzer extends RuleExecutor { else if (plan instanceof OrderBy) { OrderBy o = (OrderBy) plan; if (!o.resolved()) { - List resolvedOrder = o.order().stream() - .map(or -> resolveExpression(or, o.child())) - .collect(toList()); + List resolvedOrder = new ArrayList<>(o.order().size()); + for (Order order : o.order()) { + resolvedOrder.add(resolveExpression(order, o.child())); + } return new OrderBy(o.source(), o.child(), resolvedOrder); } } @@ -606,19 +607,53 @@ public class Analyzer extends RuleExecutor { if (plan instanceof OrderBy && !plan.resolved() && plan.childrenResolved()) { OrderBy o = (OrderBy) plan; - List maybeResolved = o.order().stream() - .map(or -> tryResolveExpression(or, o.child())) - .collect(toList()); + LogicalPlan child = o.child(); + List maybeResolved = new ArrayList<>(); + for (Order or : o.order()) { + maybeResolved.add(or.resolved() ? or : tryResolveExpression(or, child)); + } + + Stream referencesStream = maybeResolved.stream() + .filter(Expression::resolved); + // if there are any references in the output + // try and resolve them to the source in order to compare the source expressions + // e.g. ORDER BY a + 1 + // \ SELECT a + 1 + // a + 1 in SELECT is actually Alias("a + 1", a + 1) and translates to ReferenceAttribute + // in the output. However it won't match the unnamed a + 1 despite being the same expression + // so explicitly compare the source + + // if there's a match, remove the item from the reference stream + if (Expressions.hasReferenceAttribute(child.outputSet())) { + final Map collectRefs = new LinkedHashMap<>(); - Set resolvedRefs = maybeResolved.stream() - .filter(Expression::resolved) - .collect(Collectors.toSet()); + // collect aliases + child.forEachUp(p -> p.forEachExpressionsUp(e -> { + if (e instanceof Alias) { + Alias a = (Alias) e; + collectRefs.put(a.toAttribute(), a.child()); + } + })); - AttributeSet missing = Expressions.filterReferences( - resolvedRefs, - o.child().outputSet() - ); + referencesStream = referencesStream.filter(r -> { + for (Attribute attr : child.outputSet()) { + if (attr instanceof ReferenceAttribute) { + Expression source = collectRefs.getOrDefault(attr, attr); + // found a match, no need to resolve it further + // so filter it out + if (source.equals(r.child())) { + return false; + } + } + } + return true; + }); + } + + AttributeSet resolvedRefs = Expressions.references(referencesStream.collect(toList())); + + AttributeSet missing = resolvedRefs.subtract(child.outputSet()); if (!missing.isEmpty()) { // Add missing attributes but project them away afterwards @@ -650,6 +685,7 @@ public class Analyzer extends RuleExecutor { if (plan instanceof Filter && !plan.resolved() && plan.childrenResolved()) { Filter f = (Filter) plan; Expression maybeResolved = tryResolveExpression(f.condition(), f.child()); + AttributeSet resolvedRefs = new AttributeSet(maybeResolved.references().stream() .filter(Expression::resolved) .collect(toList())); @@ -708,9 +744,11 @@ public class Analyzer extends RuleExecutor { if (plan instanceof Aggregate) { Aggregate a = (Aggregate) plan; // missing attributes can only be grouping expressions + // however take into account aliased groups + // SELECT x AS i ... GROUP BY i for (Attribute m : missing) { - // but we don't can't add an agg if the group is missing - if (!Expressions.anyMatch(a.groupings(), m::semanticEquals)) { + // but we can't add an agg if the group is missing + if (!Expressions.match(a.groupings(), m::semanticEquals)) { if (m instanceof Attribute) { // pass failure information to help the verifier m = new UnresolvedAttribute(m.source(), m.name(), m.qualifier(), null, null, @@ -758,7 +796,7 @@ public class Analyzer extends RuleExecutor { // SELECT int AS i FROM t WHERE i > 10 // // As such, identify all project and aggregates that have a Filter child - // and look at any resoled aliases that match and replace them. + // and look at any resolved aliases that match and replace them. private class ResolveFilterRefs extends AnalyzeRule { @Override @@ -815,49 +853,10 @@ public class Analyzer extends RuleExecutor { } } - // to avoid creating duplicate functions - // this rule does two iterations - // 1. collect all functions - // 2. search unresolved functions and first try resolving them from already 'seen' functions private class ResolveFunctions extends AnalyzeRule { @Override protected LogicalPlan rule(LogicalPlan plan) { - Map> seen = new LinkedHashMap<>(); - // collect (and replace duplicates) - LogicalPlan p = plan.transformExpressionsUp(e -> collectResolvedAndReplace(e, seen)); - // resolve based on seen - return resolve(p, seen); - } - - private Expression collectResolvedAndReplace(Expression e, Map> seen) { - if (e instanceof Function && e.resolved()) { - Function f = (Function) e; - String fName = f.functionName(); - // the function is resolved and its name normalized already - List list = getList(seen, fName); - for (Function seenFunction : list) { - if (seenFunction != f && f.arguments().equals(seenFunction.arguments())) { - // TODO: we should move to always compare the functions directly - // Special check for COUNT: an already seen COUNT function will be returned only if its DISTINCT property - // matches the one from the unresolved function to be checked. - // Same for LIKE/RLIKE: the equals function also compares the pattern of LIKE/RLIKE - if (seenFunction instanceof Count || seenFunction instanceof RegexMatch) { - if (seenFunction.equals(f)){ - return seenFunction; - } - } else { - return seenFunction; - } - } - } - list.add(f); - } - - return e; - } - - protected LogicalPlan resolve(LogicalPlan plan, Map> seen) { return plan.transformExpressionsUp(e -> { if (e instanceof UnresolvedFunction) { UnresolvedFunction uf = (UnresolvedFunction) e; @@ -880,48 +879,17 @@ public class Analyzer extends RuleExecutor { } String functionName = functionRegistry.resolveAlias(name); - - List list = getList(seen, functionName); - // first try to resolve from seen functions - if (!list.isEmpty()) { - for (Function seenFunction : list) { - if (uf.arguments().equals(seenFunction.arguments())) { - // Special check for COUNT: an already seen COUNT function will be returned only if its DISTINCT property - // matches the one from the unresolved function to be checked. - if (seenFunction instanceof Count) { - if (uf.sameAs((Count) seenFunction)) { - return seenFunction; - } - } else { - return seenFunction; - } - } - } - } - - // not seen before, use the registry - if (!functionRegistry.functionExists(functionName)) { + if (functionRegistry.functionExists(functionName) == false) { return uf.missing(functionName, functionRegistry.listFunctions()); } // TODO: look into Generator for significant terms, etc.. FunctionDefinition def = functionRegistry.resolveFunction(functionName); Function f = uf.buildResolved(configuration, def); - - list.add(f); return f; } return e; }); } - - private List getList(Map> seen, String name) { - List list = seen.get(name); - if (list == null) { - list = new ArrayList<>(); - seen.put(name, list); - } - return list; - } } private static class ResolveAliases extends AnalyzeRule { @@ -1103,11 +1071,12 @@ public class Analyzer extends RuleExecutor { Set missing = new LinkedHashSet<>(); for (Expression filterAgg : from.collect(Functions::isAggregate)) { - if (!Expressions.anyMatch(target.aggregates(), - a -> { - Attribute attr = Expressions.attribute(a); - return attr != null && attr.semanticEquals(Expressions.attribute(filterAgg)); - })) { + if (Expressions.anyMatch(target.aggregates(), a -> { + if (a instanceof Alias) { + a = ((Alias) a).child(); + } + return a.equals(filterAgg); + }) == false) { missing.add(Expressions.wrapAsNamed(filterAgg)); } } @@ -1135,10 +1104,10 @@ public class Analyzer extends RuleExecutor { List orders = ob.order(); // 1. collect aggs inside an order by - List aggs = new ArrayList<>(); + List aggs = new ArrayList<>(); for (Order order : orders) { if (Functions.isAggregate(order.child())) { - aggs.add(Expressions.wrapAsNamed(order.child())); + aggs.add(order.child()); } } if (aggs.isEmpty()) { @@ -1154,9 +1123,14 @@ public class Analyzer extends RuleExecutor { List missing = new ArrayList<>(); - for (NamedExpression orderedAgg : aggs) { - if (Expressions.anyMatch(a.aggregates(), e -> Expressions.equalsAsAttribute(e, orderedAgg)) == false) { - missing.add(orderedAgg); + for (Expression orderedAgg : aggs) { + if (Expressions.anyMatch(a.aggregates(), e -> { + if (e instanceof Alias) { + e = ((Alias) e).child(); + } + return e.equals(orderedAgg); + }) == false) { + missing.add(Expressions.wrapAsNamed(orderedAgg)); } } // agg already contains all aggs @@ -1176,39 +1150,6 @@ public class Analyzer extends RuleExecutor { } } - private class PruneDuplicateFunctions extends AnalyzeRule { - - @Override - protected boolean skipResolved() { - return false; - } - - @Override - public LogicalPlan rule(LogicalPlan plan) { - List seen = new ArrayList<>(); - LogicalPlan p = plan.transformExpressionsUp(e -> rule(e, seen)); - return p; - } - - private Expression rule(Expression e, List seen) { - if (e instanceof Function) { - Function f = (Function) e; - for (Function seenFunction : seen) { - if (seenFunction != f && functionsEquals(f, seenFunction)) { - return seenFunction; - } - } - seen.add(f); - } - - return e; - } - - private boolean functionsEquals(Function f, Function seenFunction) { - return f.sourceText().equals(seenFunction.sourceText()) && f.arguments().equals(seenFunction.arguments()); - } - } - private class ImplicitCasting extends AnalyzeRule { @Override @@ -1282,7 +1223,7 @@ public class Analyzer extends RuleExecutor { if (plan instanceof Aggregate) { Aggregate a = (Aggregate) plan; - // aliases inside GROUP BY are irellevant so remove all of them + // aliases inside GROUP BY are irrelevant so remove all of them // however aggregations are important (ultimately a projection) return new Aggregate(a.source(), a.child(), cleanAllAliases(a.groupings()), cleanChildrenAliases(a.aggregates())); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java index 3f5caa064a2..34def1238d0 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.sql.analysis.analyzer; import org.elasticsearch.xpack.sql.capabilities.Unresolvable; import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Attribute; +import org.elasticsearch.xpack.sql.expression.AttributeMap; import org.elasticsearch.xpack.sql.expression.AttributeSet; import org.elasticsearch.xpack.sql.expression.Exists; import org.elasticsearch.xpack.sql.expression.Expression; @@ -15,18 +16,16 @@ import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.NamedExpression; +import org.elasticsearch.xpack.sql.expression.ReferenceAttribute; import org.elasticsearch.xpack.sql.expression.UnresolvedAttribute; import org.elasticsearch.xpack.sql.expression.function.Function; -import org.elasticsearch.xpack.sql.expression.function.FunctionAttribute; import org.elasticsearch.xpack.sql.expression.function.Functions; import org.elasticsearch.xpack.sql.expression.function.Score; import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; -import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute; import org.elasticsearch.xpack.sql.expression.function.aggregate.Max; import org.elasticsearch.xpack.sql.expression.function.aggregate.Min; import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits; import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction; -import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.sql.plan.logical.Aggregate; import org.elasticsearch.xpack.sql.plan.logical.Distinct; @@ -215,8 +214,19 @@ public final class Verifier { // if there are no (major) unresolved failures, do more in-depth analysis if (failures.isEmpty()) { - // collect Function to better reason about encountered attributes - Map resolvedFunctions = Functions.collectFunctions(plan); + final Map collectRefs = new LinkedHashMap<>(); + + // collect Attribute sources + // only Aliases are interesting since these are the only ones that hide expressions + // FieldAttribute for example are self replicating. + plan.forEachUp(p -> p.forEachExpressionsUp(e -> { + if (e instanceof Alias) { + Alias a = (Alias) e; + collectRefs.put(a.toAttribute(), a.child()); + } + })); + + AttributeMap attributeRefs = new AttributeMap<>(collectRefs); // for filtering out duplicated errors final Set groupingFailures = new LinkedHashSet<>(); @@ -234,17 +244,17 @@ public final class Verifier { Set localFailures = new LinkedHashSet<>(); checkGroupingFunctionInGroupBy(p, localFailures); - checkFilterOnAggs(p, localFailures); - checkFilterOnGrouping(p, localFailures); + checkFilterOnAggs(p, localFailures, attributeRefs); + checkFilterOnGrouping(p, localFailures, attributeRefs); - if (!groupingFailures.contains(p)) { - checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures); + if (groupingFailures.contains(p) == false) { + checkGroupBy(p, localFailures, attributeRefs, groupingFailures); } checkForScoreInsideFunctions(p, localFailures); checkNestedUsedInGroupByOrHaving(p, localFailures); checkForGeoFunctionsOnDocValues(p, localFailures); - checkPivot(p, localFailures); + checkPivot(p, localFailures, attributeRefs); // everything checks out // mark the plan as analyzed @@ -297,17 +307,18 @@ public final class Verifier { * 2a. HAVING also requires an Aggregate function * 3. composite agg (used for GROUP BY) allows ordering only on the group keys */ - private static boolean checkGroupBy(LogicalPlan p, Set localFailures, - Map resolvedFunctions, Set groupingFailures) { + private static boolean checkGroupBy(LogicalPlan p, Set localFailures, AttributeMap attributeRefs, + Set groupingFailures) { return checkGroupByInexactField(p, localFailures) - && checkGroupByAgg(p, localFailures, resolvedFunctions) - && checkGroupByOrder(p, localFailures, groupingFailures) - && checkGroupByHaving(p, localFailures, groupingFailures, resolvedFunctions) + && checkGroupByAgg(p, localFailures, attributeRefs) + && checkGroupByOrder(p, localFailures, groupingFailures, attributeRefs) + && checkGroupByHaving(p, localFailures, groupingFailures, attributeRefs) && checkGroupByTime(p, localFailures); } // check whether an orderBy failed or if it occurs on a non-key - private static boolean checkGroupByOrder(LogicalPlan p, Set localFailures, Set groupingFailures) { + private static boolean checkGroupByOrder(LogicalPlan p, Set localFailures, Set groupingFailures, + AttributeMap attributeRefs) { if (p instanceof OrderBy) { OrderBy o = (OrderBy) p; LogicalPlan child = o.child(); @@ -328,7 +339,7 @@ public final class Verifier { Expression e = oe.child(); // aggregates are allowed - if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) { + if (Functions.isAggregate(attributeRefs.getOrDefault(e, e))) { return; } @@ -375,7 +386,7 @@ public final class Verifier { } private static boolean checkGroupByHaving(LogicalPlan p, Set localFailures, - Set groupingFailures, Map functions) { + Set groupingFailures, AttributeMap attributeRefs) { if (p instanceof Filter) { Filter f = (Filter) p; if (f.child() instanceof Aggregate) { @@ -385,7 +396,7 @@ public final class Verifier { Set unsupported = new LinkedHashSet<>(); Expression condition = f.condition(); // variation of checkGroupMatch customized for HAVING, which requires just aggregations - condition.collectFirstChildren(c -> checkGroupByHavingHasOnlyAggs(c, missing, unsupported, functions)); + condition.collectFirstChildren(c -> checkGroupByHavingHasOnlyAggs(c, missing, unsupported, attributeRefs)); if (!missing.isEmpty()) { String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY; @@ -411,17 +422,11 @@ public final class Verifier { private static boolean checkGroupByHavingHasOnlyAggs(Expression e, Set missing, - Set unsupported, Map functions) { + Set unsupported, AttributeMap attributeRefs) { // resolve FunctionAttribute to backing functions - if (e instanceof FunctionAttribute) { - FunctionAttribute fa = (FunctionAttribute) e; - Function function = functions.get(fa.functionId()); - // TODO: this should be handled by a different rule - if (function == null) { - return false; - } - e = function; + if (e instanceof ReferenceAttribute) { + e = attributeRefs.get(e); } // scalar functions can be a binary tree @@ -432,7 +437,7 @@ public final class Verifier { // unwrap function to find the base for (Expression arg : sf.arguments()) { - arg.collectFirstChildren(c -> checkGroupByHavingHasOnlyAggs(c, missing, unsupported, functions)); + arg.collectFirstChildren(c -> checkGroupByHavingHasOnlyAggs(c, missing, unsupported, attributeRefs)); } return true; @@ -449,7 +454,7 @@ public final class Verifier { // Min & Max on a Keyword field will be translated to First & Last respectively unsupported.add(e); return true; - } + } } // skip literals / foldable @@ -482,22 +487,26 @@ public final class Verifier { Holder onlyExact = new Holder<>(Boolean.TRUE); expressions.forEach(e -> e.forEachUp(c -> { - EsField.Exact exact = c.getExactInfo(); - if (exact.hasExact() == false) { + EsField.Exact exact = c.getExactInfo(); + if (exact.hasExact() == false) { localFailures.add(fail(c, "Field [{}] of data type [{}] cannot be used for grouping; {}", c.sourceText(), c.dataType().typeName, exact.errorMsg())); onlyExact.set(Boolean.FALSE); - } - }, FieldAttribute.class)); + } + }, FieldAttribute.class)); return onlyExact.get(); } - private static boolean onlyRawFields(Iterable expressions, Set localFailures) { + private static boolean onlyRawFields(Iterable expressions, Set localFailures, + AttributeMap attributeRefs) { Holder onlyExact = new Holder<>(Boolean.TRUE); expressions.forEach(e -> e.forEachDown(c -> { - if (c instanceof Function || c instanceof FunctionAttribute) { + if (c instanceof ReferenceAttribute) { + c = attributeRefs.getOrDefault(c, c); + } + if (c instanceof Function) { localFailures.add(fail(c, "No functions allowed (yet); encountered [{}]", c.sourceText())); onlyExact.set(Boolean.FALSE); } @@ -522,7 +531,7 @@ public final class Verifier { } // check whether plain columns specified in an agg are mentioned in the group-by - private static boolean checkGroupByAgg(LogicalPlan p, Set localFailures, Map functions) { + private static boolean checkGroupByAgg(LogicalPlan p, Set localFailures, AttributeMap attributeRefs) { if (p instanceof Aggregate) { Aggregate a = (Aggregate) p; @@ -566,7 +575,7 @@ public final class Verifier { Map> missing = new LinkedHashMap<>(); a.aggregates().forEach(ne -> - ne.collectFirstChildren(c -> checkGroupMatch(c, ne, a.groupings(), missing, functions))); + ne.collectFirstChildren(c -> checkGroupMatch(c, ne, a.groupings(), missing, attributeRefs))); if (!missing.isEmpty()) { String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY; @@ -581,23 +590,16 @@ public final class Verifier { } private static boolean checkGroupMatch(Expression e, Node source, List groupings, - Map> missing, Map functions) { + Map> missing, AttributeMap attributeRefs) { // 1:1 match if (Expressions.match(groupings, e::semanticEquals)) { return true; } - // resolve FunctionAttribute to backing functions - if (e instanceof FunctionAttribute) { - FunctionAttribute fa = (FunctionAttribute) e; - Function function = functions.get(fa.functionId()); - // TODO: this should be handled by a different rule - if (function == null) { - return false; - } - e = function; + if (e instanceof ReferenceAttribute) { + e = attributeRefs.get(e); } // scalar functions can be a binary tree @@ -613,7 +615,7 @@ public final class Verifier { // unwrap function to find the base for (Expression arg : sf.arguments()) { - arg.collectFirstChildren(c -> checkGroupMatch(c, source, groupings, missing, functions)); + arg.collectFirstChildren(c -> checkGroupMatch(c, source, groupings, missing, attributeRefs)); } return true; @@ -658,7 +660,7 @@ public final class Verifier { Aggregate a = (Aggregate) p; a.aggregates().forEach(agg -> agg.forEachDown(e -> { if (a.groupings().size() == 0 - || Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.functionEquals((Function) g)) == false) { + || Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.equals(g)) == false) { localFailures.add(fail(e, "[{}] needs to be part of the grouping", Expressions.name(e))); } else { @@ -681,12 +683,12 @@ public final class Verifier { }); } - private static void checkFilterOnAggs(LogicalPlan p, Set localFailures) { + private static void checkFilterOnAggs(LogicalPlan p, Set localFailures, AttributeMap attributeRefs) { if (p instanceof Filter) { Filter filter = (Filter) p; if ((filter.child() instanceof Aggregate) == false) { filter.condition().forEachDown(e -> { - if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) { + if (Functions.isAggregate(attributeRefs.getOrDefault(e, e)) == true) { localFailures.add( fail(e, "Cannot use WHERE filtering on aggregate function [{}], use HAVING instead", Expressions.name(e))); } @@ -696,11 +698,11 @@ public final class Verifier { } - private static void checkFilterOnGrouping(LogicalPlan p, Set localFailures) { + private static void checkFilterOnGrouping(LogicalPlan p, Set localFailures, AttributeMap attributeRefs) { if (p instanceof Filter) { Filter filter = (Filter) p; filter.condition().forEachDown(e -> { - if (Functions.isGrouping(e) || e instanceof GroupingFunctionAttribute) { + if (Functions.isGrouping(attributeRefs.getOrDefault(e, e))) { localFailures .add(fail(e, "Cannot filter on grouping function [{}], use its argument instead", Expressions.name(e))); } @@ -787,11 +789,11 @@ public final class Verifier { }, FieldAttribute.class)), OrderBy.class); } - private static void checkPivot(LogicalPlan p, Set localFailures) { + private static void checkPivot(LogicalPlan p, Set localFailures, AttributeMap attributeRefs) { p.forEachDown(pv -> { // check only exact fields are used inside PIVOTing if (onlyExactFields(combine(pv.groupingSet(), pv.column()), localFailures) == false - || onlyRawFields(pv.groupingSet(), localFailures) == false) { + || onlyRawFields(pv.groupingSet(), localFailures, attributeRefs) == false) { // if that is not the case, no need to do further validation since the declaration is fundamentally wrong return; } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/search/Querier.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/search/Querier.java index 333d320e908..a47af9c20b8 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/search/Querier.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/search/Querier.java @@ -39,7 +39,6 @@ import org.elasticsearch.xpack.sql.execution.search.extractor.MetricAggExtractor import org.elasticsearch.xpack.sql.execution.search.extractor.PivotExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.TopHitsAggExtractor; import org.elasticsearch.xpack.sql.expression.Attribute; -import org.elasticsearch.xpack.sql.expression.ExpressionId; import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggExtractorInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggPathInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.HitExtractorInput; @@ -377,11 +376,11 @@ public class Querier { protected List initBucketExtractors(SearchResponse response) { // create response extractors for the first time - List> refs = query.fields(); + List> refs = query.fields(); List exts = new ArrayList<>(refs.size()); ConstantExtractor totalCount = new ConstantExtractor(response.getHits().getTotalHits().value); - for (Tuple ref : refs) { + for (Tuple ref : refs) { exts.add(createExtractor(ref.v1(), totalCount)); } return exts; @@ -447,10 +446,10 @@ public class Querier { @Override protected void handleResponse(SearchResponse response, ActionListener listener) { // create response extractors for the first time - List> refs = query.fields(); + List> refs = query.fields(); List exts = new ArrayList<>(refs.size()); - for (Tuple ref : refs) { + for (Tuple ref : refs) { exts.add(createExtractor(ref.v1())); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Alias.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Alias.java index 4ebc030c281..ef8611b4969 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Alias.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Alias.java @@ -5,14 +5,10 @@ */ package org.elasticsearch.xpack.sql.expression; -import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; -import org.elasticsearch.xpack.sql.type.EsField; -import java.util.Collections; import java.util.List; import static java.util.Collections.singletonList; @@ -44,11 +40,11 @@ public class Alias extends NamedExpression { this(source, name, qualifier, child, null); } - public Alias(Source source, String name, String qualifier, Expression child, ExpressionId id) { + public Alias(Source source, String name, String qualifier, Expression child, NameId id) { this(source, name, qualifier, child, id, false); } - public Alias(Source source, String name, String qualifier, Expression child, ExpressionId id, boolean synthetic) { + public Alias(Source source, String name, String qualifier, Expression child, NameId id, boolean synthetic) { super(source, name, singletonList(child), id, synthetic); this.child = child; this.qualifier = qualifier; @@ -92,35 +88,13 @@ public class Alias extends NamedExpression { @Override public Attribute toAttribute() { if (lazyAttribute == null) { - lazyAttribute = createAttribute(); + lazyAttribute = resolved() == true ? + new ReferenceAttribute(source(), name(), dataType(), qualifier, nullable(), id(), synthetic()) : + new UnresolvedAttribute(source(), name(), qualifier); } return lazyAttribute; } - @Override - public ScriptTemplate asScript() { - throw new SqlIllegalArgumentException("Encountered a bug; an alias should never be scripted"); - } - - private Attribute createAttribute() { - if (resolved()) { - Expression c = child(); - - Attribute attr = Expressions.attribute(c); - if (attr != null) { - return attr.clone(source(), name(), child.dataType(), qualifier, child.nullable(), id(), synthetic()); - } - else { - // TODO: WE need to fix this fake Field - return new FieldAttribute(source(), null, name(), - new EsField(name(), child.dataType(), Collections.emptyMap(), true), - qualifier, child.nullable(), id(), synthetic()); - } - } - - return new UnresolvedAttribute(source(), name(), qualifier); - } - @Override public String toString() { return child + " AS " + name() + "#" + id(); diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Attribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Attribute.java index 9f6b54badaf..bda8287115e 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Attribute.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Attribute.java @@ -5,9 +5,6 @@ */ package org.elasticsearch.xpack.sql.expression; -import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; -import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; @@ -17,24 +14,16 @@ import java.util.Objects; import static java.util.Collections.emptyList; /** - * {@link Expression}s that can be materialized and represent the result columns sent to the client. - * Typically are converted into constants, functions or Elasticsearch order-bys, - * aggregations, or queries. They can also be extracted from the result of a search. - * + * {@link Expression}s that can be materialized and describe properties of the derived table. + * In other words, an attribute represent a column in the results of a query. + * * In the statement {@code SELECT ABS(foo), A, B+C FROM ...} the three named * expressions {@code ABS(foo), A, B+C} get converted to attributes and the user can * only see Attributes. * - * In the statement {@code SELECT foo FROM TABLE WHERE foo > 10 + 1} both {@code foo} and - * {@code 10 + 1} are named expressions, the first due to the SELECT, the second due to being a function. - * However since {@code 10 + 1} is used for filtering it doesn't appear appear in the result set - * (derived table) and as such it is never translated to an attribute. - * "foo" on the other hand is since it's a column in the result set. - * - * Another example {@code SELECT foo FROM ... WHERE bar > 10 +1} {@code foo} gets - * converted into an Attribute, bar does not. That's because {@code bar} is used for - * filtering alone but it's not part of the projection meaning the user doesn't - * need it in the derived table. + * In the statement {@code SELECT foo FROM TABLE WHERE foo > 10 + 1} only {@code foo} inside the SELECT + * is a named expression (an {@code Alias} will be created automatically for it). + * The rest are not as they are not part of the projection and thus are not part of the derived table. */ public abstract class Attribute extends NamedExpression { @@ -45,15 +34,15 @@ public abstract class Attribute extends NamedExpression { // can the attr be null - typically used in JOINs private final Nullability nullability; - public Attribute(Source source, String name, String qualifier, ExpressionId id) { + public Attribute(Source source, String name, String qualifier, NameId id) { this(source, name, qualifier, Nullability.TRUE, id); } - public Attribute(Source source, String name, String qualifier, Nullability nullability, ExpressionId id) { + public Attribute(Source source, String name, String qualifier, Nullability nullability, NameId id) { this(source, name, qualifier, nullability, id, false); } - public Attribute(Source source, String name, String qualifier, Nullability nullability, ExpressionId id, boolean synthetic) { + public Attribute(Source source, String name, String qualifier, Nullability nullability, NameId id, boolean synthetic) { super(source, name, emptyList(), id, synthetic); this.qualifier = qualifier; this.nullability = nullability; @@ -64,11 +53,6 @@ public abstract class Attribute extends NamedExpression { throw new UnsupportedOperationException("this type of node doesn't have any children to replace"); } - @Override - public ScriptTemplate asScript() { - throw new SqlIllegalArgumentException("Encountered a bug - an attribute should never be scripted"); - } - public String qualifier() { return qualifier; } @@ -105,16 +89,16 @@ public abstract class Attribute extends NamedExpression { synthetic()); } + public Attribute withId(NameId id) { + return clone(source(), name(), dataType(), qualifier(), nullable(), id, synthetic()); + } + public Attribute withDataType(DataType type) { return Objects.equals(dataType(), type) ? this : clone(source(), name(), type, qualifier(), nullable(), id(), synthetic()); } - public Attribute withId(ExpressionId id) { - return clone(source(), name(), dataType(), qualifier(), nullable(), id, synthetic()); - } - protected abstract Attribute clone(Source source, String name, DataType type, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic); + NameId id, boolean synthetic); @Override public Attribute toAttribute() { @@ -126,11 +110,6 @@ public abstract class Attribute extends NamedExpression { return id().hashCode(); } - @Override - protected NodeInfo info() { - return null; - } - @Override public boolean semanticEquals(Expression other) { return other instanceof Attribute ? id().equals(((Attribute) other).id()) : false; @@ -154,7 +133,12 @@ public abstract class Attribute extends NamedExpression { @Override public String toString() { - return name() + "{" + label() + "}" + "#" + id(); + return qualifiedName() + "{" + label() + "}" + "#" + id(); + } + + @Override + public String nodeString() { + return toString(); } protected abstract String label(); diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/AttributeMap.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/AttributeMap.java index bb8d373f98b..c4c26729c6b 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/AttributeMap.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/AttributeMap.java @@ -32,14 +32,14 @@ public class AttributeMap implements Map { @Override public int hashCode() { - return attr.hashCode(); + return attr.semanticHash(); } @Override public boolean equals(Object obj) { if (obj instanceof AttributeWrapper) { AttributeWrapper aw = (AttributeWrapper) obj; - return attr.equals(aw.attr); + return attr.semanticEquals(aw.attr); } return false; diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Exists.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Exists.java index 2363b52316c..d481d8e115f 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Exists.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Exists.java @@ -16,7 +16,7 @@ public class Exists extends SubQueryExpression { this(source, query, null); } - public Exists(Source source, LogicalPlan query, ExpressionId id) { + public Exists(Source source, LogicalPlan query, NameId id) { super(source, query, id); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expression.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expression.java index 2dde7e5f97d..e2e3f99ca87 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expression.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expression.java @@ -128,6 +128,11 @@ public abstract class Expression extends Node implements Resolvable @Override public String toString() { - return nodeName() + "[" + propertiesToString(false) + "]"; + return sourceText(); } -} + + @Override + public String propertiesToString(boolean skipIfChild) { + return super.propertiesToString(false); + } +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java index 3e5450f01ac..92703f4768f 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java @@ -6,17 +6,21 @@ package org.elasticsearch.xpack.sql.expression; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; +import org.elasticsearch.xpack.sql.expression.function.Function; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.AttributeInput; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.ConstantInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.type.DataTypes; import java.util.ArrayList; import java.util.Collection; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.function.Predicate; -import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; @@ -103,33 +107,8 @@ public final class Expressions { return set; } - public static AttributeSet filterReferences(Set exps, AttributeSet excluded) { - AttributeSet ret = new AttributeSet(); - while (exps.size() > 0) { - - Set filteredExps = new LinkedHashSet<>(); - for (Expression exp : exps) { - Expression attr = Expressions.attribute(exp); - if (attr == null || (excluded.contains(attr) == false)) { - filteredExps.add(exp); - } - } - - ret.addAll(new AttributeSet( - filteredExps.stream().filter(c->c.children().isEmpty()) - .flatMap(exp->exp.references().stream()) - .collect(Collectors.toSet()) - )); - - exps = filteredExps.stream() - .flatMap((Expression exp)->exp.children().stream()) - .collect(Collectors.toSet()); - } - return ret; - } - public static String name(Expression e) { - return e instanceof NamedExpression ? ((NamedExpression) e).name() : e.nodeName(); + return e instanceof NamedExpression ? ((NamedExpression) e).name() : e.sourceText(); } public static boolean isNull(Expression e) { @@ -149,9 +128,6 @@ public final class Expressions { if (e instanceof NamedExpression) { return ((NamedExpression) e).toAttribute(); } - if (e != null && e.foldable()) { - return Literal.of(e).toAttribute(); - } return null; } @@ -163,6 +139,25 @@ public final class Expressions { return true; } + public static AttributeMap aliases(List named) { + Map aliasMap = new LinkedHashMap<>(); + for (NamedExpression ne : named) { + if (ne instanceof Alias) { + aliasMap.put(ne.toAttribute(), ((Alias) ne).child()); + } + } + return new AttributeMap<>(aliasMap); + } + + public static boolean hasReferenceAttribute(Collection output) { + for (Attribute attribute : output) { + if (attribute instanceof ReferenceAttribute) { + return true; + } + } + return false; + } + public static List onlyPrimitiveFieldAttributes(Collection attributes) { List filtered = new ArrayList<>(); // add only primitives @@ -188,8 +183,14 @@ public final class Expressions { } public static Pipe pipe(Expression e) { + if (e.foldable()) { + return new ConstantInput(e.source(), e, e.fold()); + } if (e instanceof NamedExpression) { - return ((NamedExpression) e).asPipe(); + return new AttributeInput(e.source(), e, ((NamedExpression) e).toAttribute()); + } + if (e instanceof Function) { + return ((Function) e).asPipe(); } throw new SqlIllegalArgumentException("Cannot create pipe for {}", e); } @@ -201,4 +202,8 @@ public final class Expressions { } return pipes; } -} + + public static String id(Expression e) { + return Integer.toHexString(e.hashCode()); + } +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/FieldAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/FieldAttribute.java index 625a679898a..f802c9a940d 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/FieldAttribute.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/FieldAttribute.java @@ -36,14 +36,14 @@ public class FieldAttribute extends TypedAttribute { public FieldAttribute(Source source, FieldAttribute parent, String name, EsField field) { this(source, parent, name, field, null, Nullability.TRUE, null, false); } - + public FieldAttribute(Source source, FieldAttribute parent, String name, EsField field, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic) { + NameId id, boolean synthetic) { this(source, parent, name, field.getDataType(), field, qualifier, nullability, id, synthetic); } public FieldAttribute(Source source, FieldAttribute parent, String name, DataType type, EsField field, String qualifier, - Nullability nullability, ExpressionId id, boolean synthetic) { + Nullability nullability, NameId id, boolean synthetic) { super(source, name, type, qualifier, nullability, id, synthetic); this.path = parent != null ? parent.name() : StringUtils.EMPTY; this.parent = parent; @@ -103,8 +103,8 @@ public class FieldAttribute extends TypedAttribute { } @Override - protected Attribute clone(Source source, String name, DataType type, String qualifier, - Nullability nullability, ExpressionId id, boolean synthetic) { + protected Attribute clone(Source source, String name, DataType type, String qualifier, Nullability nullability, NameId id, + boolean synthetic) { FieldAttribute qualifiedParent = parent != null ? (FieldAttribute) parent.withQualifier(qualifier) : null; return new FieldAttribute(source, qualifiedParent, name, field, qualifier, nullability, id, synthetic); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Literal.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Literal.java index b22483bda36..315b1bb308e 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Literal.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Literal.java @@ -6,23 +6,18 @@ package org.elasticsearch.xpack.sql.expression; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.gen.script.Params; -import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.type.DataTypeConversion; import org.elasticsearch.xpack.sql.type.DataTypes; -import java.util.List; import java.util.Objects; -import static java.util.Collections.emptyList; - /** * SQL Literal or constant. */ -public class Literal extends NamedExpression { +public class Literal extends LeafExpression { public static final Literal TRUE = Literal.of(Source.EMPTY, Boolean.TRUE); public static final Literal FALSE = Literal.of(Source.EMPTY, Boolean.FALSE); @@ -32,11 +27,7 @@ public class Literal extends NamedExpression { private final DataType dataType; public Literal(Source source, Object value, DataType dataType) { - this(source, null, value, dataType); - } - - public Literal(Source source, String name, Object value, DataType dataType) { - super(source, name == null ? source.text() : name, emptyList(), null); + super(source); this.dataType = dataType; this.value = DataTypeConversion.convert(value, dataType); } @@ -75,32 +66,6 @@ public class Literal extends NamedExpression { return value; } - @Override - public Attribute toAttribute() { - return new LiteralAttribute(source(), name(), dataType, null, nullable(), id(), false, this); - } - - @Override - public ScriptTemplate asScript() { - return new ScriptTemplate(String.valueOf(value), Params.EMPTY, dataType); - } - - @Override - public Expression replaceChildren(List newChildren) { - throw new UnsupportedOperationException("this type of node doesn't have any children to replace"); - } - - @Override - public AttributeSet references() { - return AttributeSet.EMPTY; - } - - @Override - protected Expression canonicalize() { - String s = String.valueOf(value); - return name().equals(s) ? this : Literal.of(source(), value); - } - @Override public int hashCode() { return Objects.hash(value, dataType); @@ -116,14 +81,17 @@ public class Literal extends NamedExpression { } Literal other = (Literal) obj; - return Objects.equals(value, other.value) - && Objects.equals(dataType, other.dataType); + return Objects.equals(value, other.value) && Objects.equals(dataType, other.dataType); } @Override public String toString() { - String s = String.valueOf(value); - return name().equals(s) ? s : name() + "=" + value; + return String.valueOf(value); + } + + @Override + public String nodeString() { + return toString() + "[" + dataType + "]"; } /** @@ -141,31 +109,18 @@ public class Literal extends NamedExpression { * Throws an exception if the expression is not foldable. */ public static Literal of(Expression foldable) { - return of((String) null, foldable); - } - - public static Literal of(String name, Expression foldable) { if (!foldable.foldable()) { throw new SqlIllegalArgumentException("Foldable expression required for Literal creation; received unfoldable " + foldable); } if (foldable instanceof Literal) { - Literal l = (Literal) foldable; - if (name == null || l.name().equals(name)) { - return l; - } + return (Literal) foldable; } - Object fold = foldable.fold(); - - if (name == null) { - name = foldable instanceof NamedExpression ? ((NamedExpression) foldable).name() : String.valueOf(fold); - } - return new Literal(foldable.source(), name, fold, foldable.dataType()); + return new Literal(foldable.source(), foldable.fold(), foldable.dataType()); } public static Literal of(Expression source, Object value) { - String name = source instanceof NamedExpression ? ((NamedExpression) source).name() : String.valueOf(value); - return new Literal(source.source(), name, value, source.dataType()); + return new Literal(source.source(), value, source.dataType()); } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/LiteralAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/LiteralAttribute.java deleted file mode 100644 index 506f3f8a073..00000000000 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/LiteralAttribute.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression; - -import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; -import org.elasticsearch.xpack.sql.tree.NodeInfo; -import org.elasticsearch.xpack.sql.tree.Source; -import org.elasticsearch.xpack.sql.type.DataType; - -public class LiteralAttribute extends TypedAttribute { - - private final Literal literal; - - public LiteralAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability, ExpressionId id, - boolean synthetic, Literal literal) { - super(source, name, dataType, qualifier, nullability, id, synthetic); - this.literal = literal; - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, LiteralAttribute::new, - name(), dataType(), qualifier(), nullable(), id(), synthetic(), literal); - } - - @Override - protected LiteralAttribute clone(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic) { - return new LiteralAttribute(source, name, dataType, qualifier, nullability, id, synthetic, literal); - } - - @Override - protected String label() { - return "c"; - } - - @Override - public Pipe asPipe() { - return literal.asPipe(); - } - - @Override - public Object fold() { - return literal.fold(); - } -} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ExpressionId.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/NameId.java similarity index 77% rename from x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ExpressionId.java rename to x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/NameId.java index cbc622a615c..bc74a506d77 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ExpressionId.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/NameId.java @@ -9,23 +9,23 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; /** - * Unique identifier for an expression. + * Unique identifier for a named expression. *

* We use an {@link AtomicLong} to guarantee that they are unique - * and that they produce reproduceable values when run in subsequent - * tests. They don't produce reproduceable values in production, but + * and that create reproducible values when run in subsequent + * tests. They don't produce reproducible values in production, but * you rarely debug with them in production and commonly do so in * tests. */ -public class ExpressionId { +public class NameId { private static final AtomicLong COUNTER = new AtomicLong(); private final long id; - public ExpressionId() { + public NameId() { this.id = COUNTER.incrementAndGet(); } - public ExpressionId(long id) { + public NameId(long id) { this.id = id; } @@ -42,7 +42,7 @@ public class ExpressionId { if (obj == null || obj.getClass() != getClass()) { return false; } - ExpressionId other = (ExpressionId) obj; + NameId other = (NameId) obj; return id == other.id; } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/NamedExpression.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/NamedExpression.java index e586621a7dd..633e2303930 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/NamedExpression.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/NamedExpression.java @@ -5,10 +5,6 @@ */ package org.elasticsearch.xpack.sql.expression; -import org.elasticsearch.xpack.sql.expression.gen.pipeline.AttributeInput; -import org.elasticsearch.xpack.sql.expression.gen.pipeline.ConstantInput; -import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; -import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.tree.Source; import java.util.List; @@ -21,19 +17,18 @@ import java.util.Objects; public abstract class NamedExpression extends Expression { private final String name; - private final ExpressionId id; + private final NameId id; private final boolean synthetic; - private Pipe lazyPipe = null; - public NamedExpression(Source source, String name, List children, ExpressionId id) { + public NamedExpression(Source source, String name, List children, NameId id) { this(source, name, children, id, false); } - public NamedExpression(Source source, String name, List children, ExpressionId id, boolean synthetic) { + public NamedExpression(Source source, String name, List children, NameId id, boolean synthetic) { super(source, children); this.name = name; - this.id = id == null ? new ExpressionId() : id; + this.id = id == null ? new NameId() : id; this.synthetic = synthetic; } @@ -41,7 +36,7 @@ public abstract class NamedExpression extends Expression { return name; } - public ExpressionId id() { + public NameId id() { return id; } @@ -51,20 +46,6 @@ public abstract class NamedExpression extends Expression { public abstract Attribute toAttribute(); - public Pipe asPipe() { - if (lazyPipe == null) { - lazyPipe = foldable() ? new ConstantInput(source(), this, fold()) : makePipe(); - } - - return lazyPipe; - } - - protected Pipe makePipe() { - return new AttributeInput(source(), this, toAttribute()); - } - - public abstract ScriptTemplate asScript(); - @Override public int hashCode() { return Objects.hash(super.hashCode(), name, synthetic); @@ -95,4 +76,9 @@ public abstract class NamedExpression extends Expression { public String toString() { return super.toString() + "#" + id(); } -} + + @Override + public String nodeString() { + return name(); + } +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Order.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Order.java index 267a8827d8c..3642ac94d8e 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Order.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Order.java @@ -101,4 +101,4 @@ public class Order extends Expression { && Objects.equals(nulls, other.nulls) && Objects.equals(child, other.child); } -} +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ReferenceAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ReferenceAttribute.java new file mode 100644 index 00000000000..03330bc1148 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ReferenceAttribute.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.sql.expression; + +import org.elasticsearch.xpack.sql.tree.NodeInfo; +import org.elasticsearch.xpack.sql.tree.Source; +import org.elasticsearch.xpack.sql.type.DataType; + +/** + * Attribute based on a reference to an expression. + */ +public class ReferenceAttribute extends TypedAttribute { + + public ReferenceAttribute(Source source, String name, DataType dataType) { + this(source, name, dataType, null, Nullability.FALSE, null, false); + } + + public ReferenceAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability, + NameId id, boolean synthetic) { + super(source, name, dataType, qualifier, nullability, id, synthetic); + } + + @Override + protected Attribute clone(Source source, String name, DataType dataType, String qualifier, Nullability nullability, NameId id, + boolean synthetic) { + return new ReferenceAttribute(source, name, dataType, qualifier, nullability, id, synthetic); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, ReferenceAttribute::new, name(), dataType(), qualifier(), nullable(), id(), synthetic()); + } + + @Override + protected String label() { + return "r"; + } +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ScalarSubquery.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ScalarSubquery.java index 84693cdc79d..cba61814e8f 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ScalarSubquery.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/ScalarSubquery.java @@ -16,7 +16,7 @@ public class ScalarSubquery extends SubQueryExpression { this(source, query, null); } - public ScalarSubquery(Source source, LogicalPlan query, ExpressionId id) { + public ScalarSubquery(Source source, LogicalPlan query, NameId id) { super(source, query, id); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/SubQueryExpression.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/SubQueryExpression.java index 17ec60b6e69..250e5de7218 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/SubQueryExpression.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/SubQueryExpression.java @@ -15,16 +15,16 @@ import java.util.Objects; public abstract class SubQueryExpression extends Expression { private final LogicalPlan query; - private final ExpressionId id; + private final NameId id; public SubQueryExpression(Source source, LogicalPlan query) { this(source, query, null); } - public SubQueryExpression(Source source, LogicalPlan query, ExpressionId id) { + public SubQueryExpression(Source source, LogicalPlan query, NameId id) { super(source, Collections.emptyList()); this.query = query; - this.id = id == null ? new ExpressionId() : id; + this.id = id == null ? new NameId() : id; } @Override @@ -36,7 +36,7 @@ public abstract class SubQueryExpression extends Expression { return query; } - public ExpressionId id() { + public NameId id() { return id; } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/TypedAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/TypedAttribute.java index 414ff330bda..98f91d4dca1 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/TypedAttribute.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/TypedAttribute.java @@ -15,7 +15,7 @@ public abstract class TypedAttribute extends Attribute { private final DataType dataType; protected TypedAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic) { + NameId id, boolean synthetic) { super(source, name, qualifier, nullability, id, synthetic); this.dataType = dataType; } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAlias.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAlias.java index 178c4d89695..67bbee18b39 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAlias.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAlias.java @@ -6,14 +6,14 @@ package org.elasticsearch.xpack.sql.expression; import org.elasticsearch.xpack.sql.capabilities.UnresolvedException; -import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.tree.NodeInfo; +import org.elasticsearch.xpack.sql.tree.Source; + +import java.util.List; import java.util.Objects; import static java.util.Collections.singletonList; -import java.util.List; - public class UnresolvedAlias extends UnresolvedNamedExpression { private final Expression child; @@ -72,4 +72,9 @@ public class UnresolvedAlias extends UnresolvedNamedExpression { public String toString() { return child + " AS ?"; } -} + + @Override + public String nodeString() { + return toString(); + } +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttribute.java index add7f702e04..34b8eca1c35 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttribute.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttribute.java @@ -13,11 +13,8 @@ import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.util.CollectionUtils; import java.util.List; -import java.util.Locale; import java.util.Objects; -import static java.lang.String.format; - // unfortunately we can't use UnresolvedNamedExpression public class UnresolvedAttribute extends Attribute implements Unresolvable { @@ -37,7 +34,7 @@ public class UnresolvedAttribute extends Attribute implements Unresolvable { this(source, name, qualifier, null, unresolvedMessage, null); } - public UnresolvedAttribute(Source source, String name, String qualifier, ExpressionId id, String unresolvedMessage, + public UnresolvedAttribute(Source source, String name, String qualifier, NameId id, String unresolvedMessage, Object resolutionMetadata) { super(source, name, qualifier, id); this.customMessage = unresolvedMessage != null; @@ -66,7 +63,7 @@ public class UnresolvedAttribute extends Attribute implements Unresolvable { @Override protected Attribute clone(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic) { + NameId id, boolean synthetic) { return this; } @@ -79,11 +76,6 @@ public class UnresolvedAttribute extends Attribute implements Unresolvable { throw new UnresolvedException("dataType", this); } - @Override - public String nodeString() { - return format(Locale.ROOT, "unknown column '%s'", name()); - } - @Override public String toString() { return UNRESOLVED_PREFIX + qualifiedName(); @@ -94,6 +86,11 @@ public class UnresolvedAttribute extends Attribute implements Unresolvable { return UNRESOLVED_PREFIX; } + @Override + public String nodeString() { + return toString(); + } + @Override public String unresolvedMessage() { return unresolvedMsg; diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedNamedExpression.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedNamedExpression.java index 4d68d32f374..5e27180541d 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedNamedExpression.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedNamedExpression.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.sql.expression; import org.elasticsearch.xpack.sql.capabilities.Unresolvable; import org.elasticsearch.xpack.sql.capabilities.UnresolvedException; -import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; @@ -16,7 +15,7 @@ import java.util.List; abstract class UnresolvedNamedExpression extends NamedExpression implements Unresolvable { UnresolvedNamedExpression(Source source, List children) { - super(source, "", children, new ExpressionId()); + super(source, "", children, new NameId()); } @Override @@ -30,7 +29,7 @@ abstract class UnresolvedNamedExpression extends NamedExpression implements Unre } @Override - public ExpressionId id() { + public NameId id() { throw new UnresolvedException("id", this); } @@ -43,9 +42,4 @@ abstract class UnresolvedNamedExpression extends NamedExpression implements Unre public Attribute toAttribute() { throw new UnresolvedException("attribute", this); } - - @Override - public ScriptTemplate asScript() { - throw new UnresolvedException("script", this); - } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedStar.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedStar.java index 4b5a6dfa537..0f38a12d796 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedStar.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/UnresolvedStar.java @@ -6,14 +6,14 @@ package org.elasticsearch.xpack.sql.expression; import org.elasticsearch.xpack.sql.capabilities.UnresolvedException; -import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.tree.NodeInfo; +import org.elasticsearch.xpack.sql.tree.Source; + +import java.util.List; import java.util.Objects; import static java.util.Collections.emptyList; -import java.util.List; - public class UnresolvedStar extends UnresolvedNamedExpression { // typically used for nested fields or inner/dotted fields @@ -74,6 +74,11 @@ public class UnresolvedStar extends UnresolvedNamedExpression { return "Cannot determine columns for [" + message() + "]"; } + @Override + public String nodeString() { + return toString(); + } + @Override public String toString() { return UNRESOLVED_PREFIX + message(); diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Function.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Function.java index 7724e81525b..47e160df578 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Function.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Function.java @@ -6,42 +6,39 @@ package org.elasticsearch.xpack.sql.expression.function; import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.ExpressionId; import org.elasticsearch.xpack.sql.expression.Expressions; -import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.Nullability; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.ConstantInput; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; +import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.tree.Source; -import org.elasticsearch.xpack.sql.util.StringUtils; import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.StringJoiner; /** * Any SQL expression with parentheses, like {@code MAX()}, or {@code ABS()}. A * function is always a {@code NamedExpression}. */ -public abstract class Function extends NamedExpression { +public abstract class Function extends Expression { - private final String functionName, name; + private final String functionName = getClass().getSimpleName().toUpperCase(Locale.ROOT); - protected Function(Source source, List children) { - this(source, children, null, false); - } + private Pipe lazyPipe = null; // TODO: Functions supporting distinct should add a dedicated constructor Location, List, boolean - protected Function(Source source, List children, ExpressionId id, boolean synthetic) { - // cannot detect name yet so override the name - super(source, null, children, id, synthetic); - functionName = StringUtils.camelCaseToUnderscore(getClass().getSimpleName()); - name = source.text(); + protected Function(Source source, List children) { + super(source, children); } public final List arguments() { return children(); } - @Override - public String name() { - return name; + public String functionName() { + return functionName; } @Override @@ -49,16 +46,44 @@ public abstract class Function extends NamedExpression { return Expressions.nullable(children()); } - public String functionName() { - return functionName; + @Override + public int hashCode() { + return Objects.hash(getClass(), children()); } - // TODO: ExpressionId might be converted into an Int which could make the String an int as well - public String functionId() { - return id().toString(); + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + Function other = (Function) obj; + return Objects.equals(children(), other.children()); } - public boolean functionEquals(Function f) { - return f != null && getClass() == f.getClass() && arguments().equals(f.arguments()); + public Pipe asPipe() { + if (lazyPipe == null) { + lazyPipe = foldable() ? new ConstantInput(source(), this, fold()) : makePipe(); + } + return lazyPipe; } + + protected Pipe makePipe() { + throw new UnsupportedOperationException(); + } + + @Override + public String nodeString() { + StringJoiner sj = new StringJoiner(",", functionName() + "(", ")"); + for (Expression ex : arguments()) { + sj.add(ex.nodeString()); + } + return sj.toString(); + } + + public abstract ScriptTemplate asScript(); } \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionAttribute.java deleted file mode 100644 index 962fb010c48..00000000000 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionAttribute.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression.function; - -import org.elasticsearch.xpack.sql.expression.ExpressionId; -import org.elasticsearch.xpack.sql.expression.Nullability; -import org.elasticsearch.xpack.sql.expression.TypedAttribute; -import org.elasticsearch.xpack.sql.tree.Source; -import org.elasticsearch.xpack.sql.type.DataType; - -import java.util.Objects; - -public abstract class FunctionAttribute extends TypedAttribute { - - private final String functionId; - - protected FunctionAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic, String functionId) { - super(source, name, dataType, qualifier, nullability, id, synthetic); - this.functionId = functionId; - } - - public String functionId() { - return functionId; - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode()); - } - - @Override - public boolean equals(Object obj) { - return super.equals(obj); - } -} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Functions.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Functions.java index 46ca0ea91b4..47ca821f4b5 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Functions.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Functions.java @@ -8,10 +8,6 @@ package org.elasticsearch.xpack.sql.expression.function; import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction; -import org.elasticsearch.xpack.sql.plan.QueryPlan; - -import java.util.LinkedHashMap; -import java.util.Map; public abstract class Functions { @@ -22,15 +18,4 @@ public abstract class Functions { public static boolean isGrouping(Expression e) { return e instanceof GroupingFunction; } - - public static Map collectFunctions(QueryPlan plan) { - Map resolvedFunctions = new LinkedHashMap<>(); - plan.forEachExpressionsDown(e -> { - if (e.resolved() && e instanceof Function) { - Function f = (Function) e; - resolvedFunctions.put(f.functionId(), f); - } - }); - return resolvedFunctions; - } } \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Score.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Score.java index 23363ff6cbd..d5cee644980 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Score.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/Score.java @@ -6,12 +6,11 @@ package org.elasticsearch.xpack.sql.expression.function; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; -import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.tree.NodeInfo; +import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; import java.util.List; @@ -43,11 +42,6 @@ public class Score extends Function { return DataType.FLOAT; } - @Override - public Attribute toAttribute() { - return new ScoreAttribute(source()); - } - @Override public boolean equals(Object obj) { if (obj == null || obj.getClass() != getClass()) { diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/ScoreAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/ScoreAttribute.java deleted file mode 100644 index 7d93db3d862..00000000000 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/ScoreAttribute.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression.function; - -import org.elasticsearch.xpack.sql.expression.Attribute; -import org.elasticsearch.xpack.sql.expression.ExpressionId; -import org.elasticsearch.xpack.sql.expression.Nullability; -import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; -import org.elasticsearch.xpack.sql.expression.gen.pipeline.ScorePipe; -import org.elasticsearch.xpack.sql.tree.NodeInfo; -import org.elasticsearch.xpack.sql.tree.Source; -import org.elasticsearch.xpack.sql.type.DataType; - -import static org.elasticsearch.xpack.sql.expression.Nullability.FALSE; - -/** - * {@link Attribute} that represents Elasticsearch's {@code _score}. - */ -public class ScoreAttribute extends FunctionAttribute { - /** - * Constructor for normal use. - */ - public ScoreAttribute(Source source) { - this(source, "SCORE()", DataType.FLOAT, null, FALSE, null, false); - } - - /** - * Constructor for {@link #clone()} - */ - private ScoreAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability, ExpressionId id, - boolean synthetic) { - super(source, name, dataType, qualifier, nullability, id, synthetic, "SCORE"); - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this); - } - - @Override - protected Attribute clone(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic) { - return new ScoreAttribute(source, name, dataType, qualifier, nullability, id, synthetic); - } - - @Override - protected Pipe makePipe() { - return new ScorePipe(source(), this); - } - - @Override - protected String label() { - return "SCORE"; - } -} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/UnresolvedFunction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/UnresolvedFunction.java index 85afc25c5c6..920d030ddfd 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/UnresolvedFunction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/UnresolvedFunction.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.sql.expression.function; import org.elasticsearch.xpack.sql.capabilities.Unresolvable; import org.elasticsearch.xpack.sql.capabilities.UnresolvedException; -import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.Nullability; @@ -113,16 +112,10 @@ public class UnresolvedFunction extends Function implements Unresolvable { return false; } - @Override public String name() { return name; } - @Override - public String functionName() { - return name; - } - ResolutionType resolutionType() { return resolutionType; } @@ -149,11 +142,6 @@ public class UnresolvedFunction extends Function implements Unresolvable { throw new UnresolvedException("nullable", this); } - @Override - public Attribute toAttribute() { - throw new UnresolvedException("attribute", this); - } - @Override public ScriptTemplate asScript() { throw new UnresolvedException("script", this); @@ -169,6 +157,11 @@ public class UnresolvedFunction extends Function implements Unresolvable { return UNRESOLVED_PREFIX + name + children(); } + @Override + public String nodeString() { + return toString(); + } + @Override public boolean equals(Object obj) { if (obj == null || obj.getClass() != getClass()) { diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/AggregateFunction.java index 59b4f345a4a..91ac02dc837 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/AggregateFunction.java @@ -30,8 +30,6 @@ public abstract class AggregateFunction extends Function { private final Expression field; private final List parameters; - private AggregateFunctionAttribute lazyAttribute; - protected AggregateFunction(Source source, Expression field) { this(source, field, emptyList()); } @@ -51,18 +49,14 @@ public abstract class AggregateFunction extends Function { } @Override - public AggregateFunctionAttribute toAttribute() { - if (lazyAttribute == null) { - // this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl) - lazyAttribute = new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId()); - } - return lazyAttribute; + protected TypeResolution resolveType() { + return TypeResolutions.isExact(field, sourceText(), Expressions.ParamOrdinal.DEFAULT); } @Override protected Pipe makePipe() { // unresolved AggNameInput (should always get replaced by the folder) - return new AggNameInput(source(), this, name()); + return new AggNameInput(source(), this, sourceText()); } @Override @@ -70,23 +64,20 @@ public abstract class AggregateFunction extends Function { throw new SqlIllegalArgumentException("Aggregate functions cannot be scripted"); } - @Override - public boolean equals(Object obj) { - if (false == super.equals(obj)) { - return false; - } - AggregateFunction other = (AggregateFunction) obj; - return Objects.equals(other.field(), field()) - && Objects.equals(other.parameters(), parameters()); - } - - @Override - protected TypeResolution resolveType() { - return TypeResolutions.isExact(field, sourceText(), Expressions.ParamOrdinal.DEFAULT); - } - @Override public int hashCode() { - return Objects.hash(field(), parameters()); + // NB: the hashcode is currently used for key generation so + // to avoid clashes between aggs with the same arguments, add the class name as variation + return Objects.hash(getClass(), children()); + } + + @Override + public boolean equals(Object obj) { + if (super.equals(obj) == true) { + AggregateFunction other = (AggregateFunction) obj; + return Objects.equals(other.field(), field()) + && Objects.equals(other.parameters(), parameters()); + } + return false; } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/AggregateFunctionAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/AggregateFunctionAttribute.java deleted file mode 100644 index 463a277a8fa..00000000000 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/AggregateFunctionAttribute.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression.function.aggregate; - -import org.elasticsearch.xpack.sql.expression.Attribute; -import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.ExpressionId; -import org.elasticsearch.xpack.sql.expression.Nullability; -import org.elasticsearch.xpack.sql.expression.function.FunctionAttribute; -import org.elasticsearch.xpack.sql.tree.NodeInfo; -import org.elasticsearch.xpack.sql.tree.Source; -import org.elasticsearch.xpack.sql.type.DataType; - -import java.util.Objects; - -public class AggregateFunctionAttribute extends FunctionAttribute { - - // used when dealing with a inner agg (avg -> stats) to keep track of - // packed id - // used since the functionId points to the compoundAgg - private final ExpressionId innerId; - private final String propertyPath; - - AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, String functionId) { - this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, null, null); - } - - AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, String functionId, ExpressionId innerId, - String propertyPath) { - this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, innerId, propertyPath); - } - - public AggregateFunctionAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic, String functionId, ExpressionId innerId, String propertyPath) { - super(source, name, dataType, qualifier, nullability, id, synthetic, functionId); - this.innerId = innerId; - this.propertyPath = propertyPath; - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, AggregateFunctionAttribute::new, name(), dataType(), qualifier(), nullable(), id(), synthetic(), - functionId(), innerId, propertyPath); - } - - public ExpressionId innerId() { - return innerId != null ? innerId : id(); - } - - public String propertyPath() { - return propertyPath; - } - - @Override - protected Expression canonicalize() { - return new AggregateFunctionAttribute(source(), "", dataType(), null, Nullability.TRUE, id(), false, "", null, null); - } - - @Override - protected Attribute clone(Source source, String name, DataType dataType, String qualifier, Nullability nullability, ExpressionId id, - boolean synthetic) { - // this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl) - // that is the functionId is actually derived from the expression id to easily track it across contexts - return new AggregateFunctionAttribute(source, name, dataType, qualifier, nullability, id, synthetic, functionId(), innerId, - propertyPath); - } - - public AggregateFunctionAttribute withFunctionId(String functionId, String propertyPath) { - return new AggregateFunctionAttribute(source(), name(), dataType(), qualifier(), nullable(), id(), synthetic(), functionId, innerId, - propertyPath); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), propertyPath); - } - - @Override - public boolean equals(Object obj) { - if (super.equals(obj)) { - AggregateFunctionAttribute other = (AggregateFunctionAttribute) obj; - return Objects.equals(propertyPath, other.propertyPath); - } - return false; - } - - @Override - protected String label() { - return "a->" + innerId(); - } -} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/Count.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/Count.java index 1da2eeb0277..951144f5b2e 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/Count.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/Count.java @@ -6,8 +6,6 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate; import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.Literal; -import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; @@ -51,46 +49,17 @@ public class Count extends AggregateFunction { return DataType.LONG; } - @Override - public String functionId() { - String functionId = id().toString(); - // if count works against a given expression, use its id (to identify the group) - // in case of COUNT DISTINCT don't use the expression id to avoid possible duplicate IDs when COUNT and COUNT DISTINCT is used - // in the same query - if (!distinct() && field() instanceof NamedExpression) { - functionId = ((NamedExpression) field()).id().toString(); - } - return functionId; - } - - @Override - public AggregateFunctionAttribute toAttribute() { - // COUNT(*) gets its value from the parent aggregation on which _count is called - if (field() instanceof Literal) { - return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), id(), "_count"); - } - // COUNT(column) gets its value from a sibling aggregation (an exists filter agg) by calling its id and then _count on it - if (!distinct()) { - return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), id(), functionId() + "._count"); - } - return super.toAttribute(); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || obj.getClass() != getClass()) { - return false; - } - Count other = (Count) obj; - return Objects.equals(other.distinct(), distinct()) - && Objects.equals(field(), other.field()); - } - @Override public int hashCode() { return Objects.hash(super.hashCode(), distinct()); } + + @Override + public boolean equals(Object obj) { + if (super.equals(obj) == true) { + Count other = (Count) obj; + return Objects.equals(other.distinct(), distinct()); + } + return false; + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/InnerAggregate.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/InnerAggregate.java index 6e35fa5a7ac..c9d18b83c15 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/InnerAggregate.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/aggregate/InnerAggregate.java @@ -6,12 +6,12 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate; import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.function.Function; import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; import java.util.List; +import java.util.Objects; public class InnerAggregate extends AggregateFunction { @@ -69,38 +69,28 @@ public class InnerAggregate extends AggregateFunction { } @Override - public String functionId() { - return outer.id().toString(); + public String functionName() { + return inner.functionName(); } @Override - public AggregateFunctionAttribute toAttribute() { - // this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl) - return new AggregateFunctionAttribute(source(), name(), dataType(), outer.id(), functionId(), - inner.id(), aggMetricValue(functionId(), innerName)); - } - - private static String aggMetricValue(String aggPath, String valueName) { - // handle aggPath inconsistency (for percentiles and percentileRanks) percentile[99.9] (valid) vs percentile.99.9 (invalid) - return aggPath + "[" + valueName + "]"; + public int hashCode() { + return Objects.hash(inner, outer, innerKey); } @Override - public boolean functionEquals(Function f) { - if (super.equals(f)) { - InnerAggregate other = (InnerAggregate) f; - return inner.equals(other.inner) && outer.equals(other.outer); + public boolean equals(Object obj) { + if (super.equals(obj) == true) { + InnerAggregate other = (InnerAggregate) obj; + return Objects.equals(inner, other.inner) + && Objects.equals(outer, other.outer) + && Objects.equals(innerKey, other.innerKey); } return false; } - @Override - public String name() { - return inner.name(); - } - @Override public String toString() { - return nodeName() + "[" + outer + ">" + inner.nodeName() + "#" + inner.id() + "]"; + return nodeName() + "[" + outer + ">" + inner.nodeName() + "]"; } } \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/grouping/GroupingFunction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/grouping/GroupingFunction.java index b8a6bb40540..327c4ef382d 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/grouping/GroupingFunction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/grouping/GroupingFunction.java @@ -28,8 +28,6 @@ public abstract class GroupingFunction extends Function { private final Expression field; private final List parameters; - private GroupingFunctionAttribute lazyAttribute; - protected GroupingFunction(Source source, Expression field) { this(source, field, emptyList()); } @@ -48,19 +46,10 @@ public abstract class GroupingFunction extends Function { return parameters; } - @Override - public GroupingFunctionAttribute toAttribute() { - if (lazyAttribute == null) { - // this is highly correlated with QueryFolder$FoldAggregate#addAggFunction (regarding the function name within the querydsl) - lazyAttribute = new GroupingFunctionAttribute(source(), name(), dataType(), id(), functionId()); - } - return lazyAttribute; - } - @Override protected Pipe makePipe() { // unresolved AggNameInput (should always get replaced by the folder) - return new AggNameInput(source(), this, name()); + return new AggNameInput(source(), this, sourceText()); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/grouping/GroupingFunctionAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/grouping/GroupingFunctionAttribute.java deleted file mode 100644 index 2fed4cf3060..00000000000 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/grouping/GroupingFunctionAttribute.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression.function.grouping; - -import org.elasticsearch.xpack.sql.expression.Attribute; -import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.ExpressionId; -import org.elasticsearch.xpack.sql.expression.Nullability; -import org.elasticsearch.xpack.sql.expression.function.FunctionAttribute; -import org.elasticsearch.xpack.sql.tree.NodeInfo; -import org.elasticsearch.xpack.sql.tree.Source; -import org.elasticsearch.xpack.sql.type.DataType; - -public class GroupingFunctionAttribute extends FunctionAttribute { - - GroupingFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, String functionId) { - this(source, name, dataType, null, Nullability.FALSE, id, false, functionId); - } - - public GroupingFunctionAttribute(Source source, String name, DataType dataType, String qualifier, - Nullability nullability, ExpressionId id, boolean synthetic, String functionId) { - super(source, name, dataType, qualifier, nullability, id, synthetic, functionId); - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, GroupingFunctionAttribute::new, - name(), dataType(), qualifier(), nullable(), id(), synthetic(), functionId()); - } - - @Override - protected Expression canonicalize() { - return new GroupingFunctionAttribute(source(), "", dataType(), null, Nullability.TRUE, id(), false, ""); - } - - @Override - protected Attribute clone(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic) { - // this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl) - // that is the functionId is actually derived from the expression id to easily track it across contexts - return new GroupingFunctionAttribute(source, name, dataType, qualifier, nullability, id, synthetic, functionId()); - } - - public GroupingFunctionAttribute withFunctionId(String functionId, String propertyPath) { - return new GroupingFunctionAttribute(source(), name(), dataType(), qualifier(), nullable(), - id(), synthetic(), functionId); - } - - @Override - protected String label() { - return "g->" + functionId(); - } -} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/ScalarFunction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/ScalarFunction.java index d836030a3ae..b764b4b0a6a 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/ScalarFunction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/ScalarFunction.java @@ -22,8 +22,6 @@ import static java.util.Collections.emptyList; */ public abstract class ScalarFunction extends Function implements ScriptWeaver { - private ScalarFunctionAttribute lazyAttribute = null; - protected ScalarFunction(Source source) { super(source, emptyList()); } @@ -32,15 +30,6 @@ public abstract class ScalarFunction extends Function implements ScriptWeaver { super(source, fields); } - @Override - public final ScalarFunctionAttribute toAttribute() { - if (lazyAttribute == null) { - lazyAttribute = new ScalarFunctionAttribute(source(), name(), dataType(), id(), functionId(), asScript(), orderBy(), - asPipe()); - } - return lazyAttribute; - } - // used if the function is monotonic and thus does not have to be computed for ordering purposes // null means the script needs to be used; expression means the field/expression to be used instead public Expression orderBy() { diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/ScalarFunctionAttribute.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/ScalarFunctionAttribute.java deleted file mode 100644 index 67324ba466c..00000000000 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/ScalarFunctionAttribute.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.sql.expression.function.scalar; - -import org.elasticsearch.xpack.sql.expression.Attribute; -import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.ExpressionId; -import org.elasticsearch.xpack.sql.expression.Nullability; -import org.elasticsearch.xpack.sql.expression.function.FunctionAttribute; -import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; -import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; -import org.elasticsearch.xpack.sql.tree.NodeInfo; -import org.elasticsearch.xpack.sql.tree.Source; -import org.elasticsearch.xpack.sql.type.DataType; - -import java.util.Objects; - -public class ScalarFunctionAttribute extends FunctionAttribute { - - private final ScriptTemplate script; - private final Expression orderBy; - private final Pipe pipe; - - ScalarFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, - String functionId, ScriptTemplate script, Expression orderBy, Pipe processorDef) { - this(source, name, dataType, null, Nullability.TRUE, id, false, functionId, script, orderBy, processorDef); - } - - public ScalarFunctionAttribute(Source source, String name, DataType dataType, String qualifier, - Nullability nullability, ExpressionId id, boolean synthetic, String functionId, ScriptTemplate script, - Expression orderBy, Pipe pipe) { - super(source, name, dataType, qualifier, nullability, id, synthetic, functionId); - - this.script = script; - this.orderBy = orderBy; - this.pipe = pipe; - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, ScalarFunctionAttribute::new, - name(), dataType(), qualifier(), nullable(), id(), synthetic(), - functionId(), script, orderBy, pipe); - } - - public ScriptTemplate script() { - return script; - } - - public Expression orderBy() { - return orderBy; - } - - @Override - public Pipe asPipe() { - return pipe; - } - - @Override - protected Expression canonicalize() { - return new ScalarFunctionAttribute(source(), "", dataType(), null, Nullability.TRUE, id(), false, - functionId(), script, orderBy, pipe); - } - - @Override - protected Attribute clone(Source source, String name, DataType dataType, String qualifier, Nullability nullability, - ExpressionId id, boolean synthetic) { - return new ScalarFunctionAttribute(source, name, dataType, qualifier, nullability, - id, synthetic, functionId(), script, orderBy, pipe); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), script(), pipe, orderBy); - } - - @Override - public boolean equals(Object obj) { - if (super.equals(obj)) { - ScalarFunctionAttribute other = (ScalarFunctionAttribute) obj; - return Objects.equals(script, other.script()) - && Objects.equals(pipe, other.asPipe()) - && Objects.equals(orderBy, other.orderBy()); - } - return false; - } - - @Override - protected String label() { - return "s->" + functionId(); - } -} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/UnaryScalarFunction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/UnaryScalarFunction.java index 9a5f85e9431..d10d18b83a1 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/UnaryScalarFunction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/UnaryScalarFunction.java @@ -57,8 +57,13 @@ public abstract class UnaryScalarFunction extends ScalarFunction { return field.foldable(); } + @Override + public Object fold() { + return makeProcessor().process(field().fold()); + } + @Override public ScriptTemplate asScript() { return asScript(field); } -} +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/datetime/BaseDateTimeFunction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/datetime/BaseDateTimeFunction.java index bda86183fff..0cce7521a29 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/datetime/BaseDateTimeFunction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/datetime/BaseDateTimeFunction.java @@ -52,6 +52,11 @@ abstract class BaseDateTimeFunction extends UnaryScalarFunction { return makeProcessor().process(field().fold()); } + @Override + public int hashCode() { + return Objects.hash(getClass(), field(), zoneId()); + } + @Override public boolean equals(Object obj) { if (obj == null || obj.getClass() != getClass()) { @@ -61,9 +66,4 @@ abstract class BaseDateTimeFunction extends UnaryScalarFunction { return Objects.equals(other.field(), field()) && Objects.equals(other.zoneId(), zoneId()); } - - @Override - public int hashCode() { - return Objects.hash(field(), zoneId()); - } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StWkttosql.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StWkttosql.java index 3ebae55dec4..04006d4a28b 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StWkttosql.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StWkttosql.java @@ -36,7 +36,7 @@ public class StWkttosql extends UnaryScalarFunction { if (field().dataType().isString()) { return TypeResolution.TYPE_RESOLVED; } - return isString(field(), functionName(), Expressions.ParamOrdinal.DEFAULT); + return isString(field(), sourceText(), Expressions.ParamOrdinal.DEFAULT); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/E.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/E.java index 265b96984b5..b1b731fe91b 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/E.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/E.java @@ -20,7 +20,7 @@ public class E extends MathFunction { private static final ScriptTemplate TEMPLATE = new ScriptTemplate("Math.E", Params.EMPTY, DataType.DOUBLE); public E(Source source) { - super(source, new Literal(source, "E", Math.E, DataType.DOUBLE)); + super(source, new Literal(source, Math.E, DataType.DOUBLE)); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/Pi.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/Pi.java index 7fb966c3201..79492bac3c1 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/Pi.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/math/Pi.java @@ -20,7 +20,7 @@ public class Pi extends MathFunction { private static final ScriptTemplate TEMPLATE = new ScriptTemplate("Math.PI", Params.EMPTY, DataType.DOUBLE); public Pi(Source source) { - super(source, new Literal(source, "PI", Math.PI, DataType.DOUBLE)); + super(source, new Literal(source, Math.PI, DataType.DOUBLE)); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/string/Concat.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/string/Concat.java index 4e461d919a9..15602bc53c8 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/string/Concat.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/string/Concat.java @@ -38,12 +38,12 @@ public class Concat extends BinaryScalarFunction { return new TypeResolution("Unresolved children"); } - TypeResolution resolution = isStringAndExact(left(), functionName(), ParamOrdinal.FIRST); + TypeResolution resolution = isStringAndExact(left(), sourceText(), ParamOrdinal.FIRST); if (resolution.unresolved()) { return resolution; } - return isStringAndExact(right(), functionName(), ParamOrdinal.SECOND); + return isStringAndExact(right(), sourceText(), ParamOrdinal.SECOND); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Agg.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Agg.java index 55bba713062..ad4ff617cce 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Agg.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Agg.java @@ -5,20 +5,48 @@ */ package org.elasticsearch.xpack.sql.expression.gen.script; -import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute; +import org.elasticsearch.xpack.sql.expression.Expressions; +import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.sql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate; -class Agg extends Param { +class Agg extends Param { - Agg(AggregateFunctionAttribute aggRef) { + private static final String COUNT_PATH = "_count"; + + Agg(AggregateFunction aggRef) { super(aggRef); } String aggName() { - return value().functionId(); + return Expressions.id(value()); } public String aggProperty() { - return value().propertyPath(); + AggregateFunction agg = value(); + + if (agg instanceof InnerAggregate) { + InnerAggregate inner = (InnerAggregate) agg; + return Expressions.id(inner.outer()) + "." + inner.innerName(); + } + // Count needs special handling since in most cases it is not a dedicated aggregation + else if (agg instanceof Count) { + Count c = (Count) agg; + // for literals get the last count + if (c.field().foldable() == true) { + return COUNT_PATH; + } + // when dealing with fields, check whether there's a single-metric (distinct -> cardinality) + // or a bucket (non-distinct - filter agg) + else { + if (c.distinct() == true) { + return Expressions.id(c); + } else { + return Expressions.id(c) + "." + COUNT_PATH; + } + } + } + return null; } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Grouping.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Grouping.java index e11f82a842e..f34e1c8798f 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Grouping.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Grouping.java @@ -5,16 +5,16 @@ */ package org.elasticsearch.xpack.sql.expression.gen.script; -import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute; +import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction; -class Grouping extends Param { +class Grouping extends Param { - Grouping(GroupingFunctionAttribute groupRef) { + Grouping(GroupingFunction groupRef) { super(groupRef); } String groupName() { - return value().functionId(); + return Integer.toHexString(value().hashCode()); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ParamsBuilder.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ParamsBuilder.java index 25e92103ccc..2e13682b70e 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ParamsBuilder.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ParamsBuilder.java @@ -5,8 +5,8 @@ */ package org.elasticsearch.xpack.sql.expression.gen.script; -import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute; -import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute; +import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction; import java.util.ArrayList; import java.util.List; @@ -24,12 +24,12 @@ public class ParamsBuilder { return this; } - public ParamsBuilder agg(AggregateFunctionAttribute agg) { + public ParamsBuilder agg(AggregateFunction agg) { params.add(new Agg(agg)); return this; } - public ParamsBuilder grouping(GroupingFunctionAttribute grouping) { + public ParamsBuilder grouping(GroupingFunction grouping) { params.add(new Grouping(grouping)); return this; } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ScriptWeaver.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ScriptWeaver.java index 223e22b2a33..e468a2801ce 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ScriptWeaver.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/ScriptWeaver.java @@ -7,14 +7,12 @@ package org.elasticsearch.xpack.sql.expression.gen.script; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.expression.Literal; -import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute; -import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute; -import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; +import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction; +import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.sql.expression.function.scalar.geo.GeoShape; import org.elasticsearch.xpack.sql.expression.literal.IntervalDayTime; import org.elasticsearch.xpack.sql.expression.literal.IntervalYearMonth; @@ -36,20 +34,20 @@ public interface ScriptWeaver { return scriptWithFoldable(exp); } - Attribute attr = Expressions.attribute(exp); - if (attr != null) { - if (attr instanceof ScalarFunctionAttribute) { - return scriptWithScalar((ScalarFunctionAttribute) attr); - } - if (attr instanceof AggregateFunctionAttribute) { - return scriptWithAggregate((AggregateFunctionAttribute) attr); - } - if (attr instanceof GroupingFunctionAttribute) { - return scriptWithGrouping((GroupingFunctionAttribute) attr); - } - if (attr instanceof FieldAttribute) { - return scriptWithField((FieldAttribute) attr); - } + if (exp instanceof ScalarFunction) { + return scriptWithScalar((ScalarFunction) exp); + } + + if (exp instanceof AggregateFunction) { + return scriptWithAggregate((AggregateFunction) exp); + } + + if (exp instanceof GroupingFunction) { + return scriptWithGrouping((GroupingFunction) exp); + } + + if (exp instanceof FieldAttribute) { + return scriptWithField((FieldAttribute) exp); } throw new SqlIllegalArgumentException("Cannot evaluate script for expression {}", exp); } @@ -108,14 +106,14 @@ public interface ScriptWeaver { dataType()); } - default ScriptTemplate scriptWithScalar(ScalarFunctionAttribute scalar) { - ScriptTemplate nested = scalar.script(); + default ScriptTemplate scriptWithScalar(ScalarFunction scalar) { + ScriptTemplate nested = scalar.asScript(); return new ScriptTemplate(processScript(nested.template()), paramsBuilder().script(nested.params()).build(), dataType()); } - default ScriptTemplate scriptWithAggregate(AggregateFunctionAttribute aggregate) { + default ScriptTemplate scriptWithAggregate(AggregateFunction aggregate) { String template = "{}"; if (aggregate.dataType().isDateBased()) { template = "{sql}.asDateTime({})"; @@ -125,7 +123,7 @@ public interface ScriptWeaver { dataType()); } - default ScriptTemplate scriptWithGrouping(GroupingFunctionAttribute grouping) { + default ScriptTemplate scriptWithGrouping(GroupingFunction grouping) { String template = "{}"; if (grouping.dataType().isDateBased()) { template = "{sql}.asDateTime({})"; diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/BinaryPredicate.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/BinaryPredicate.java index eb7265dc29b..8705f9c58e5 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/BinaryPredicate.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/BinaryPredicate.java @@ -68,4 +68,9 @@ public abstract class BinaryPredicate { Batch pivot = new Batch("Pivot Rewrite", Limiter.ONCE, new RewritePivot()); + Batch refs = new Batch("Replace References", Limiter.ONCE, + new ReplaceReferenceAttributeWithSource()); + Batch operators = new Batch("Operator Optimization", - new PruneDuplicatesInGroupBy(), // combining new CombineProjections(), // folding @@ -145,20 +141,18 @@ public class Optimizer extends RuleExecutor { new PropagateEquals(), new CombineBinaryComparisons(), // prune/elimination + new PruneLiteralsInGroupBy(), + new PruneDuplicatesInGroupBy(), new PruneFilters(), - new PruneOrderBy(), + new PruneOrderByForImplicitGrouping(), + new PruneLiteralsInOrderBy(), new PruneOrderByNestedFields(), new PruneCast(), // order by alignment of the aggs new SortAggregateOnOrderBy() - // requires changes in the folding - // since the exact same function, with the same ID can appear in multiple places - // see https://github.com/elastic/x-pack-elasticsearch/issues/3527 - //new PruneDuplicateFunctions() ); Batch aggregate = new Batch("Aggregation Rewrite", - //new ReplaceDuplicateAggsWithReferences(), new ReplaceMinMaxWithTopHits(), new ReplaceAggsWithMatrixStats(), new ReplaceAggsWithExtendedStats(), @@ -172,12 +166,11 @@ public class Optimizer extends RuleExecutor { new SkipQueryOnLimitZero(), new SkipQueryIfFoldingProjection() ); - //new BalanceBooleanTrees()); Batch label = new Batch("Set as Optimized", Limiter.ONCE, CleanAliases.INSTANCE, new SetAsOptimized()); - return Arrays.asList(pivot, operators, aggregate, local, label); + return Arrays.asList(pivot, refs, operators, aggregate, local, label); } static class RewritePivot extends OptimizerRule { @@ -189,17 +182,9 @@ public class Optimizer extends RuleExecutor { for (NamedExpression namedExpression : plan.values()) { // everything should have resolved to an alias if (namedExpression instanceof Alias) { - rawValues.add(((Alias) namedExpression).child()); + rawValues.add(Literal.of(((Alias) namedExpression).child())); } - // TODO: this should be removed when refactoring NamedExpression - else if (namedExpression instanceof Literal) { - rawValues.add(namedExpression); - } - // TODO: NamedExpression refactoring should remove this - else if (namedExpression.foldable()) { - rawValues.add(Literal.of(namedExpression.name(), namedExpression)); - } - // TODO: same as above + // fallback - should not happen else { UnresolvedAttribute attr = new UnresolvedAttribute(namedExpression.source(), namedExpression.name(), null, "Unexpected alias"); @@ -208,7 +193,67 @@ public class Optimizer extends RuleExecutor { } Filter filter = new Filter(plan.source(), plan.child(), new In(plan.source(), plan.column(), rawValues)); // 2. preserve the PIVOT - return new Pivot(plan.source(), filter, plan.column(), plan.values(), plan.aggregates()); + return new Pivot(plan.source(), filter, plan.column(), plan.values(), plan.aggregates(), plan.groupings()); + } + } + + // + // Replace any reference attribute with its source, if it does not affect the result. + // This avoid ulterior look-ups between attributes and its source across nodes, which is + // problematic when doing script translation. + // + static class ReplaceReferenceAttributeWithSource extends OptimizerBasicRule { + + @Override + public LogicalPlan apply(LogicalPlan plan) { + final Map collectRefs = new LinkedHashMap<>(); + + // collect aliases + plan.forEachUp(p -> p.forEachExpressionsUp(e -> { + if (e instanceof Alias) { + Alias a = (Alias) e; + collectRefs.put(a.toAttribute(), a.child()); + } + })); + + plan = plan.transformUp(p -> { + // non attribute defining plans get their references removed + if ((p instanceof Pivot || p instanceof Aggregate || p instanceof Project) == false || p.children().isEmpty()) { + p = p.transformExpressionsOnly(e -> { + if (e instanceof ReferenceAttribute) { + e = collectRefs.getOrDefault(e, e); + } + return e; + }); + } + return p; + }); + + return plan; + } + } + + static class PruneLiteralsInGroupBy extends OptimizerRule { + + @Override + protected LogicalPlan rule(Aggregate agg) { + List groupings = agg.groupings(); + List prunedGroupings = new ArrayList<>(); + + for (Expression g : groupings) { + if (g.foldable()) { + prunedGroupings.add(g); + } + } + + // everything was eliminated, the grouping + if (prunedGroupings.size() > 0) { + List newGroupings = new ArrayList<>(groupings); + newGroupings.removeAll(prunedGroupings); + return new Aggregate(agg.source(), agg.child(), newGroupings, agg.aggregates()); + } + + return agg; } } @@ -228,569 +273,12 @@ public class Optimizer extends RuleExecutor { } } - static class ReplaceDuplicateAggsWithReferences extends OptimizerRule { - - @Override - protected LogicalPlan rule(Aggregate agg) { - List aggs = agg.aggregates(); - - Map unique = new HashMap<>(); - Map reverse = new HashMap<>(); - - // find duplicates by looking at the function and canonical form - for (NamedExpression ne : aggs) { - if (ne instanceof Alias) { - Alias a = (Alias) ne; - unique.putIfAbsent(a.child(), a); - reverse.putIfAbsent(ne, a.child()); - } - else { - unique.putIfAbsent(ne.canonical(), ne); - reverse.putIfAbsent(ne, ne.canonical()); - } - } - - if (unique.size() != aggs.size()) { - List newAggs = new ArrayList<>(aggs.size()); - for (NamedExpression ne : aggs) { - newAggs.add(unique.get(reverse.get(ne))); - } - return new Aggregate(agg.source(), agg.child(), agg.groupings(), newAggs); - } - - return agg; - } - } - - static class ReplaceAggsWithMatrixStats extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan p) { - Map seen = new LinkedHashMap<>(); - Map promotedFunctionIds = new LinkedHashMap<>(); - - p = p.transformExpressionsUp(e -> rule(e, seen, promotedFunctionIds)); - - // nothing found - if (seen.isEmpty()) { - return p; - } - - return ReplaceAggsWithStats.updateAggAttributes(p, promotedFunctionIds); - } - - @Override - protected LogicalPlan rule(LogicalPlan e) { - return e; - } - - protected Expression rule(Expression e, Map seen, Map promotedIds) { - if (e instanceof MatrixStatsEnclosed) { - AggregateFunction f = (AggregateFunction) e; - - Expression argument = f.field(); - MatrixStats matrixStats = seen.get(argument); - - if (matrixStats == null) { - matrixStats = new MatrixStats(f.source(), argument); - seen.put(argument, matrixStats); - } - - InnerAggregate ia = new InnerAggregate(f.source(), f, matrixStats, argument); - promotedIds.putIfAbsent(f.functionId(), ia.toAttribute()); - return ia; - } - - return e; - } - } - - static class ReplaceAggsWithExtendedStats extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan p) { - Map promotedFunctionIds = new LinkedHashMap<>(); - Map seen = new LinkedHashMap<>(); - p = p.transformExpressionsUp(e -> rule(e, seen, promotedFunctionIds)); - - // nothing found - if (seen.isEmpty()) { - return p; - } - - // update old agg attributes - return ReplaceAggsWithStats.updateAggAttributes(p, promotedFunctionIds); - } - - @Override - protected LogicalPlan rule(LogicalPlan e) { - return e; - } - - protected Expression rule(Expression e, Map seen, - Map promotedIds) { - if (e instanceof ExtendedStatsEnclosed) { - AggregateFunction f = (AggregateFunction) e; - - Expression argument = f.field(); - ExtendedStats extendedStats = seen.get(argument); - - if (extendedStats == null) { - extendedStats = new ExtendedStats(f.source(), argument); - seen.put(argument, extendedStats); - } - - InnerAggregate ia = new InnerAggregate(f, extendedStats); - promotedIds.putIfAbsent(f.functionId(), ia.toAttribute()); - return ia; - } - - return e; - } - } - - static class ReplaceAggsWithStats extends Rule { - - private static class Match { - final Stats stats; - private final Set> functionTypes = new LinkedHashSet<>(); - private Map, InnerAggregate> innerAggs = null; - - Match(Stats stats) { - this.stats = stats; - } - - @Override - public String toString() { - return stats.toString(); - } - - public void add(Class aggType) { - functionTypes.add(aggType); - } - - // if the stat has at least two different functions for it, promote it as stat - // also keep the promoted function around for reuse - public AggregateFunction maybePromote(AggregateFunction agg) { - if (functionTypes.size() > 1) { - if (innerAggs == null) { - innerAggs = new LinkedHashMap<>(); - } - return innerAggs.computeIfAbsent(agg.getClass(), k -> new InnerAggregate(agg, stats)); - } - return agg; - } - } - - @Override - public LogicalPlan apply(LogicalPlan p) { - Map potentialPromotions = new LinkedHashMap<>(); - - p.forEachExpressionsUp(e -> collect(e, potentialPromotions)); - - // no promotions found - skip - if (potentialPromotions.isEmpty()) { - return p; - } - - // start promotion - - // old functionId to new function attribute - Map promotedFunctionIds = new LinkedHashMap<>(); - - // 1. promote aggs to InnerAggs - p = p.transformExpressionsUp(e -> promote(e, potentialPromotions, promotedFunctionIds)); - - // 2. update the old agg attrs to the promoted agg functions - return updateAggAttributes(p, promotedFunctionIds); - } - - @Override - protected LogicalPlan rule(LogicalPlan e) { - return e; - } - - private Expression collect(Expression e, Map seen) { - if (Stats.isTypeCompatible(e)) { - AggregateFunction f = (AggregateFunction) e; - - Expression argument = f.field(); - Match match = seen.get(argument); - - if (match == null) { - match = new Match(new Stats(new Source(f.sourceLocation(), "STATS(" + Expressions.name(argument) + ")"), argument)); - seen.put(argument, match); - } - match.add(f.getClass()); - } - - return e; - } - - private static Expression promote(Expression e, Map seen, Map attrs) { - if (Stats.isTypeCompatible(e)) { - AggregateFunction f = (AggregateFunction) e; - - Expression argument = f.field(); - Match match = seen.get(argument); - - if (match != null) { - AggregateFunction inner = match.maybePromote(f); - if (inner != f) { - attrs.putIfAbsent(f.functionId(), inner.toAttribute()); - } - return inner; - } - } - return e; - } - - static LogicalPlan updateAggAttributes(LogicalPlan p, Map promotedFunctionIds) { - // 1. update old agg function attributes - p = p.transformExpressionsUp(e -> updateAggFunctionAttrs(e, promotedFunctionIds)); - - // 2. update all scalar function consumers of the promoted aggs - // since they contain the old ids in scrips and processorDefinitions that need regenerating - - // 2a. collect ScalarFunctions that unwrapped refer to any of the updated aggregates - // 2b. replace any of the old ScalarFunction attributes - - final Set newAggIds = new LinkedHashSet<>(promotedFunctionIds.size()); - - for (AggregateFunctionAttribute afa : promotedFunctionIds.values()) { - newAggIds.add(afa.functionId()); - } - - final Map updatedScalarAttrs = new LinkedHashMap<>(); - final Map updatedScalarAliases = new LinkedHashMap<>(); - - p = p.transformExpressionsUp(e -> { - - // replace scalar attributes of the old replaced functions - if (e instanceof ScalarFunctionAttribute) { - ScalarFunctionAttribute sfa = (ScalarFunctionAttribute) e; - // check aliases - sfa = updatedScalarAttrs.getOrDefault(sfa.functionId(), sfa); - // check scalars - sfa = updatedScalarAliases.getOrDefault(sfa.id(), sfa); - return sfa; - } - - // unwrap aliases as they 'hide' functions under their own attributes - if (e instanceof Alias) { - Attribute att = Expressions.attribute(e); - if (att instanceof ScalarFunctionAttribute) { - ScalarFunctionAttribute sfa = (ScalarFunctionAttribute) att; - // the underlying function has been updated - // thus record the alias as well - if (updatedScalarAttrs.containsKey(sfa.functionId())) { - updatedScalarAliases.put(sfa.id(), sfa); - } - } - } - - else if (e instanceof ScalarFunction && false == Expressions.anyMatch(e.children(), c -> c instanceof FullTextPredicate)) { - ScalarFunction sf = (ScalarFunction) e; - - // if it's a unseen function check if the function children/arguments refers to any of the promoted aggs - if (newAggIds.isEmpty() == false && !updatedScalarAttrs.containsKey(sf.functionId()) && e.anyMatch(c -> { - Attribute a = Expressions.attribute(c); - if (a instanceof FunctionAttribute) { - return newAggIds.contains(((FunctionAttribute) a).functionId()); - } - return false; - })) { - // if so, record its attribute - updatedScalarAttrs.put(sf.functionId(), sf.toAttribute()); - } - } - - return e; - }); - - return p; - } - - - private static Expression updateAggFunctionAttrs(Expression e, Map promotedIds) { - if (e instanceof AggregateFunctionAttribute) { - AggregateFunctionAttribute ae = (AggregateFunctionAttribute) e; - AggregateFunctionAttribute promoted = promotedIds.get(ae.functionId()); - if (promoted != null) { - return ae.withFunctionId(promoted.functionId(), promoted.propertyPath()); - } - } - return e; - } - } - - static class PromoteStatsToExtendedStats extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan p) { - Map seen = new LinkedHashMap<>(); - - // count the extended stats - p.forEachExpressionsUp(e -> count(e, seen)); - // then if there's a match, replace the stat inside the InnerAgg - return p.transformExpressionsUp(e -> promote(e, seen)); - } - - @Override - protected LogicalPlan rule(LogicalPlan e) { - return e; - } - - private void count(Expression e, Map seen) { - if (e instanceof InnerAggregate) { - InnerAggregate ia = (InnerAggregate) e; - if (ia.outer() instanceof ExtendedStats) { - ExtendedStats extStats = (ExtendedStats) ia.outer(); - seen.putIfAbsent(extStats.field(), extStats); - } - } - } - - protected Expression promote(Expression e, Map seen) { - if (e instanceof InnerAggregate) { - InnerAggregate ia = (InnerAggregate) e; - if (ia.outer() instanceof Stats) { - Stats stats = (Stats) ia.outer(); - ExtendedStats ext = seen.get(stats.field()); - if (ext != null && stats.field().equals(ext.field())) { - return new InnerAggregate(ia.inner(), ext); - } - } - } - - return e; - } - } - - static class ReplaceAggsWithPercentiles extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan p) { - // percentile per field/expression - Map> percentsPerField = new LinkedHashMap<>(); - - // count gather the percents for each field - p.forEachExpressionsUp(e -> count(e, percentsPerField)); - - Map percentilesPerField = new LinkedHashMap<>(); - // create a Percentile agg for each field (and its associated percents) - percentsPerField.forEach((k, v) -> { - percentilesPerField.put(k, new Percentiles(v.iterator().next().source(), k, new ArrayList<>(v))); - }); - - // now replace the agg with pointer to the main ones - Map promotedFunctionIds = new LinkedHashMap<>(); - p = p.transformExpressionsUp(e -> rule(e, percentilesPerField, promotedFunctionIds)); - // finally update all the function references as well - return p.transformExpressionsDown(e -> ReplaceAggsWithStats.updateAggFunctionAttrs(e, promotedFunctionIds)); - } - - private void count(Expression e, Map> percentsPerField) { - if (e instanceof Percentile) { - Percentile p = (Percentile) e; - Expression field = p.field(); - Set percentiles = percentsPerField.get(field); - - if (percentiles == null) { - percentiles = new LinkedHashSet<>(); - percentsPerField.put(field, percentiles); - } - - percentiles.add(p.percent()); - } - } - - protected Expression rule(Expression e, Map percentilesPerField, - Map promotedIds) { - if (e instanceof Percentile) { - Percentile p = (Percentile) e; - Percentiles percentiles = percentilesPerField.get(p.field()); - - InnerAggregate ia = new InnerAggregate(p, percentiles); - promotedIds.putIfAbsent(p.functionId(), ia.toAttribute()); - return ia; - } - - return e; - } - - @Override - protected LogicalPlan rule(LogicalPlan e) { - return e; - } - } - - static class ReplaceAggsWithPercentileRanks extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan p) { - // percentile per field/expression - Map> valuesPerField = new LinkedHashMap<>(); - - // count gather the percents for each field - p.forEachExpressionsUp(e -> count(e, valuesPerField)); - - Map ranksPerField = new LinkedHashMap<>(); - // create a PercentileRanks agg for each field (and its associated values) - valuesPerField.forEach((k, v) -> { - ranksPerField.put(k, new PercentileRanks(v.iterator().next().source(), k, new ArrayList<>(v))); - }); - - // now replace the agg with pointer to the main ones - Map promotedFunctionIds = new LinkedHashMap<>(); - p = p.transformExpressionsUp(e -> rule(e, ranksPerField, promotedFunctionIds)); - // finally update all the function references as well - return p.transformExpressionsDown(e -> ReplaceAggsWithStats.updateAggFunctionAttrs(e, promotedFunctionIds)); - } - - private void count(Expression e, Map> ranksPerField) { - if (e instanceof PercentileRank) { - PercentileRank p = (PercentileRank) e; - Expression field = p.field(); - Set percentiles = ranksPerField.get(field); - - if (percentiles == null) { - percentiles = new LinkedHashSet<>(); - ranksPerField.put(field, percentiles); - } - - percentiles.add(p.value()); - } - } - - protected Expression rule(Expression e, Map ranksPerField, - Map promotedIds) { - if (e instanceof PercentileRank) { - PercentileRank p = (PercentileRank) e; - PercentileRanks ranks = ranksPerField.get(p.field()); - - InnerAggregate ia = new InnerAggregate(p, ranks); - promotedIds.putIfAbsent(p.functionId(), ia.toAttribute()); - return ia; - } - - return e; - } - - @Override - protected LogicalPlan rule(LogicalPlan e) { - return e; - } - } - - static class ReplaceMinMaxWithTopHits extends OptimizerRule { - - @Override - protected LogicalPlan rule(LogicalPlan plan) { - Map seen = new HashMap<>(); - return plan.transformExpressionsDown(e -> { - if (e instanceof Min) { - Min min = (Min) e; - if (min.field().dataType().isString()) { - TopHits topHits = seen.get(min.id()); - if (topHits != null) { - return topHits; - } - topHits = new First(min.source(), min.field(), null); - seen.put(min.id(), topHits); - return topHits; - } - } - if (e instanceof Max) { - Max max = (Max) e; - if (max.field().dataType().isString()) { - TopHits topHits = seen.get(max.id()); - if (topHits != null) { - return topHits; - } - topHits = new Last(max.source(), max.field(), null); - seen.put(max.id(), topHits); - return topHits; - } - } - return e; - }); - } - } - - static class PruneFilters extends OptimizerRule { - - @Override - protected LogicalPlan rule(Filter filter) { - Expression condition = filter.condition().transformUp(PruneFilters::foldBinaryLogic); - - if (condition instanceof Literal) { - if (TRUE.equals(condition)) { - return filter.child(); - } - if (FALSE.equals(condition) || Expressions.isNull(condition)) { - return new LocalRelation(filter.source(), new EmptyExecutable(filter.output())); - } - } - - if (!condition.equals(filter.condition())) { - return new Filter(filter.source(), filter.child(), condition); - } - return filter; - } - - private static Expression foldBinaryLogic(Expression expression) { - if (expression instanceof Or) { - Or or = (Or) expression; - boolean nullLeft = Expressions.isNull(or.left()); - boolean nullRight = Expressions.isNull(or.right()); - if (nullLeft && nullRight) { - return Literal.NULL; - } - if (nullLeft) { - return or.right(); - } - if (nullRight) { - return or.left(); - } - } - if (expression instanceof And) { - And and = (And) expression; - if (Expressions.isNull(and.left()) || Expressions.isNull(and.right())) { - return Literal.NULL; - } - } - return expression; - } - } - - static class ReplaceAliasesInHaving extends OptimizerRule { - - @Override - protected LogicalPlan rule(Filter filter) { - if (filter.child() instanceof Aggregate) { - Expression cond = filter.condition(); - // resolve attributes to their actual - Expression newCondition = cond.transformDown(a -> { - - return a; - }, AggregateFunctionAttribute.class); - - if (newCondition != cond) { - return new Filter(filter.source(), filter.child(), newCondition); - } - } - return filter; - } - } - static class PruneOrderByNestedFields extends OptimizerRule { - private void findNested(Expression exp, Map functions, Consumer onFind) { + private void findNested(Expression exp, AttributeMap functions, Consumer onFind) { exp.forEachUp(e -> { - if (e instanceof FunctionAttribute) { - FunctionAttribute sfa = (FunctionAttribute) e; - Function f = functions.get(sfa.functionId()); + if (e instanceof ReferenceAttribute) { + Function f = functions.get(e); if (f != null) { findNested(f, functions, onFind); } @@ -810,8 +298,22 @@ public class Optimizer extends RuleExecutor { if (project.child() instanceof OrderBy) { OrderBy ob = (OrderBy) project.child(); - // resolve function aliases (that are hiding the target) - Map functions = Functions.collectFunctions(project); + // resolve function references (that maybe hiding the target) + final Map collectRefs = new LinkedHashMap<>(); + + // collect Attribute sources + // only Aliases are interesting since these are the only ones that hide expressions + // FieldAttribute for example are self replicating. + project.forEachUp(p -> p.forEachExpressionsUp(e -> { + if (e instanceof Alias) { + Alias a = (Alias) e; + if (a.child() instanceof Function) { + collectRefs.put(a.toAttribute(), (Function) a.child()); + } + } + })); + + AttributeMap functions = new AttributeMap<>(collectRefs); // track the direct parents Map nestedOrders = new LinkedHashMap<>(); @@ -870,7 +372,33 @@ public class Optimizer extends RuleExecutor { } } - static class PruneOrderBy extends OptimizerRule { + static class PruneLiteralsInOrderBy extends OptimizerRule { + + @Override + protected LogicalPlan rule(OrderBy ob) { + List prunedOrders = new ArrayList<>(); + + for (Order o : ob.order()) { + if (o.child().foldable()) { + prunedOrders.add(o); + } + } + + // everything was eliminated, the order isn't needed anymore + if (prunedOrders.size() == ob.order().size()) { + return ob.child(); + } + if (prunedOrders.size() > 0) { + List newOrders = new ArrayList<>(ob.order()); + newOrders.removeAll(prunedOrders); + return new OrderBy(ob.source(), ob.child(), newOrders); + } + + return ob; + } + } + + static class PruneOrderByForImplicitGrouping extends OptimizerRule { @Override protected LogicalPlan rule(OrderBy ob) { @@ -905,12 +433,10 @@ public class Optimizer extends RuleExecutor { protected LogicalPlan rule(OrderBy ob) { List order = ob.order(); - // remove constants and put the items in reverse order so the iteration happens back to front + // put the items in reverse order so the iteration happens back to front List nonConstant = new LinkedList<>(); - for (Order o : order) { - if (o.child().foldable() == false) { - nonConstant.add(0, o); - } + for (int i = order.size() - 1; i >= 0; i--) { + nonConstant.add(order.get(i)); } Holder foundAggregate = new Holder<>(Boolean.FALSE); @@ -940,13 +466,13 @@ public class Optimizer extends RuleExecutor { if ((equalsAsAttribute(child, group) && (equalsAsAttribute(alias, fieldToOrder) || equalsAsAttribute(child, fieldToOrder))) || (equalsAsAttribute(alias, group) - && (equalsAsAttribute(alias, fieldToOrder) || equalsAsAttribute(child, fieldToOrder)))) { + && (equalsAsAttribute(alias, fieldToOrder) || equalsAsAttribute(child, fieldToOrder)))) { isMatching.set(Boolean.TRUE); } } }); } - + if (isMatching.get() == true) { // move grouping in front groupings.remove(group); @@ -985,78 +511,21 @@ public class Optimizer extends RuleExecutor { @Override protected LogicalPlan rule(LogicalPlan plan) { - final Map replacedCast = new LinkedHashMap<>(); - // eliminate redundant casts LogicalPlan transformed = plan.transformExpressionsUp(e -> { if (e instanceof Cast) { Cast c = (Cast) e; - if (c.from() == c.to()) { - Expression argument = c.field(); - Alias as = new Alias(c.source(), c.sourceText(), argument); - replacedCast.put(c.toAttribute(), as.toAttribute()); - - return as; + return c.field(); } } return e; }); - // replace attributes from previous removed Casts - if (!replacedCast.isEmpty()) { - return transformed.transformUp(p -> { - List newProjections = new ArrayList<>(); - - boolean changed = false; - for (NamedExpression ne : p.projections()) { - Attribute found = replacedCast.get(ne.toAttribute()); - if (found != null) { - changed = true; - newProjections.add(found); - } - else { - newProjections.add(ne.toAttribute()); - } - } - - return changed ? new Project(p.source(), p.child(), newProjections) : p; - - }, Project.class); - } return transformed; } } - static class PruneDuplicateFunctions extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan p) { - List seen = new ArrayList<>(); - return p.transformExpressionsUp(e -> rule(e, seen)); - } - - @Override - protected LogicalPlan rule(LogicalPlan e) { - return e; - } - - protected Expression rule(Expression exp, List seen) { - Expression e = exp; - if (e instanceof Function) { - Function f = (Function) e; - for (Function seenFunction : seen) { - if (seenFunction != f && f.functionEquals(seenFunction)) { - return seenFunction; - } - } - seen.add(f); - } - - return exp; - } - } - static class CombineProjections extends OptimizerRule { CombineProjections() { @@ -1203,12 +672,12 @@ public class Optimizer extends RuleExecutor { protected Expression rule(Expression e) { if (e instanceof IsNotNull) { if (((IsNotNull) e).field().nullable() == Nullability.FALSE) { - return new Literal(e.source(), Expressions.name(e), Boolean.TRUE, DataType.BOOLEAN); + return new Literal(e.source(), Boolean.TRUE, DataType.BOOLEAN); } } else if (e instanceof IsNull) { if (((IsNull) e).field().nullable() == Nullability.FALSE) { - return new Literal(e.source(), Expressions.name(e), Boolean.FALSE, DataType.BOOLEAN); + return new Literal(e.source(), Boolean.FALSE, DataType.BOOLEAN); } } else if (e instanceof In) { @@ -1220,8 +689,8 @@ public class Optimizer extends RuleExecutor { } else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE && Expressions.anyMatch(e.children(), Expressions::isNull)) { - return Literal.of(e, null); - } + return Literal.of(e, null); + } return e; } @@ -1235,7 +704,7 @@ public class Optimizer extends RuleExecutor { @Override protected Expression rule(Expression e) { - return e.foldable() ? Literal.of(e) : e; + return e.foldable() && (e instanceof Literal == false) ? Literal.of(e) : e; } } @@ -1408,6 +877,7 @@ public class Optimizer extends RuleExecutor { return bc; } + @SuppressWarnings("rawtypes") private Expression simplifyNot(Not n) { Expression c = n.field(); @@ -1666,14 +1136,14 @@ public class Optimizer extends RuleExecutor { // />= else if ((other instanceof GreaterThan || other instanceof GreaterThanOrEqual) && (main instanceof LessThan || main instanceof LessThanOrEqual)) { - bcs.remove(j); - bcs.remove(i); + bcs.remove(j); + bcs.remove(i); ranges.add(new Range(and.source(), main.left(), other.right(), other instanceof GreaterThanOrEqual, main.right(), main instanceof LessThanOrEqual)); - changed = true; + changed = true; } } } @@ -1745,16 +1215,16 @@ public class Optimizer extends RuleExecutor { lowerEq = comp == 0 && main.includeLower() == other.includeLower(); // AND if (conjunctive) { - // (2 < a < 3) AND (1 < a < 3) -> (1 < a < 3) + // (2 < a < 3) AND (1 < a < 3) -> (1 < a < 3) lower = comp > 0 || - // (2 < a < 3) AND (2 < a <= 3) -> (2 < a < 3) + // (2 < a < 3) AND (2 < a <= 3) -> (2 < a < 3) (comp == 0 && !main.includeLower() && other.includeLower()); } // OR else { - // (1 < a < 3) OR (2 < a < 3) -> (1 < a < 3) + // (1 < a < 3) OR (2 < a < 3) -> (1 < a < 3) lower = comp < 0 || - // (2 <= a < 3) OR (2 < a < 3) -> (2 <= a < 3) + // (2 <= a < 3) OR (2 < a < 3) -> (2 <= a < 3) (comp == 0 && main.includeLower() && !other.includeLower()) || lowerEq; } } @@ -1771,16 +1241,16 @@ public class Optimizer extends RuleExecutor { // AND if (conjunctive) { - // (1 < a < 2) AND (1 < a < 3) -> (1 < a < 2) + // (1 < a < 2) AND (1 < a < 3) -> (1 < a < 2) upper = comp < 0 || - // (1 < a < 2) AND (1 < a <= 2) -> (1 < a < 2) + // (1 < a < 2) AND (1 < a <= 2) -> (1 < a < 2) (comp == 0 && !main.includeUpper() && other.includeUpper()); } // OR else { - // (1 < a < 3) OR (1 < a < 2) -> (1 < a < 3) + // (1 < a < 3) OR (1 < a < 2) -> (1 < a < 3) upper = comp > 0 || - // (1 < a <= 3) OR (1 < a < 3) -> (2 < a < 3) + // (1 < a <= 3) OR (1 < a < 3) -> (2 < a < 3) (comp == 0 && main.includeUpper() && !other.includeUpper()) || upperEq; } } @@ -1839,7 +1309,7 @@ public class Optimizer extends RuleExecutor { if (comp != null) { // 2 < a AND (2 <= a < 3) -> 2 < a < 3 boolean lowerEq = comp == 0 && other.includeLower() && main instanceof GreaterThan; - // 2 < a AND (1 < a < 3) -> 2 < a < 3 + // 2 < a AND (1 < a < 3) -> 2 < a < 3 boolean lower = comp > 0 || lowerEq; if (lower) { @@ -1904,18 +1374,18 @@ public class Optimizer extends RuleExecutor { Integer compare = BinaryComparison.compare(value, other.right().fold()); if (compare != null) { - // AND + // AND if ((conjunctive && - // a > 3 AND a > 2 -> a > 3 - (compare > 0 || - // a > 2 AND a >= 2 -> a > 2 - (compare == 0 && main instanceof GreaterThan && other instanceof GreaterThanOrEqual))) - || - // OR - (!conjunctive && - // a > 2 OR a > 3 -> a > 2 - (compare < 0 || - // a >= 2 OR a > 2 -> a >= 2 + // a > 3 AND a > 2 -> a > 3 + (compare > 0 || + // a > 2 AND a >= 2 -> a > 2 + (compare == 0 && main instanceof GreaterThan && other instanceof GreaterThanOrEqual))) + || + // OR + (!conjunctive && + // a > 2 OR a > 3 -> a > 2 + (compare < 0 || + // a >= 2 OR a > 2 -> a >= 2 (compare == 0 && main instanceof GreaterThanOrEqual && other instanceof GreaterThan)))) { bcs.remove(i); bcs.add(i, main); @@ -1931,40 +1401,365 @@ public class Optimizer extends RuleExecutor { else if ((other instanceof LessThan || other instanceof LessThanOrEqual) && (main instanceof LessThan || main instanceof LessThanOrEqual)) { - if (main.left().semanticEquals(other.left())) { - Integer compare = BinaryComparison.compare(value, other.right().fold()); + if (main.left().semanticEquals(other.left())) { + Integer compare = BinaryComparison.compare(value, other.right().fold()); - if (compare != null) { - // AND - if ((conjunctive && - // a < 2 AND a < 3 -> a < 2 - (compare < 0 || - // a < 2 AND a <= 2 -> a < 2 + if (compare != null) { + // AND + if ((conjunctive && + // a < 2 AND a < 3 -> a < 2 + (compare < 0 || + // a < 2 AND a <= 2 -> a < 2 (compare == 0 && main instanceof LessThan && other instanceof LessThanOrEqual))) || - // OR - (!conjunctive && - // a < 2 OR a < 3 -> a < 3 - (compare > 0 || - // a <= 2 OR a < 2 -> a <= 2 - (compare == 0 && main instanceof LessThanOrEqual && other instanceof LessThan)))) { - bcs.remove(i); - bcs.add(i, main); + // OR + (!conjunctive && + // a < 2 OR a < 3 -> a < 3 + (compare > 0 || + // a <= 2 OR a < 2 -> a <= 2 + (compare == 0 && main instanceof LessThanOrEqual && other instanceof LessThan)))) { + bcs.remove(i); + bcs.add(i, main); + } + // found a match + return true; + } + + return false; } - // found a match - return true; } - - return false; - } - } } return false; } } + + static class ReplaceAggsWithMatrixStats extends OptimizerBasicRule { + + @Override + public LogicalPlan apply(LogicalPlan p) { + // minimal reuse of the same matrix stat object + final Map seen = new LinkedHashMap<>(); + + return p.transformExpressionsUp(e -> { + if (e instanceof MatrixStatsEnclosed) { + AggregateFunction f = (AggregateFunction) e; + + Expression argument = f.field(); + MatrixStats matrixStats = seen.get(argument); + + if (matrixStats == null) { + Source source = new Source(f.sourceLocation(), "MATRIX(" + argument.sourceText() + ")"); + matrixStats = new MatrixStats(source, argument); + seen.put(argument, matrixStats); + } + + InnerAggregate ia = new InnerAggregate(f.source(), f, matrixStats, argument); + return ia; + } + + return e; + }); + } + } + + static class ReplaceAggsWithExtendedStats extends OptimizerBasicRule { + + @Override + public LogicalPlan apply(LogicalPlan p) { + // minimal reuse of the same matrix stat object + final Map seen = new LinkedHashMap<>(); + + return p.transformExpressionsUp(e -> { + if (e instanceof ExtendedStatsEnclosed) { + AggregateFunction f = (AggregateFunction) e; + + Expression argument = f.field(); + ExtendedStats extendedStats = seen.get(argument); + + if (extendedStats == null) { + Source source = new Source(f.sourceLocation(), "EXT_STATS(" + argument.sourceText() + ")"); + extendedStats = new ExtendedStats(source, argument); + seen.put(argument, extendedStats); + } + + InnerAggregate ia = new InnerAggregate(f, extendedStats); + return ia; + } + + return e; + }); + } + } + + static class ReplaceAggsWithStats extends OptimizerBasicRule { + + private static class Match { + final Stats stats; + private final Set> functionTypes = new LinkedHashSet<>(); + private Map, InnerAggregate> innerAggs = null; + + Match(Stats stats) { + this.stats = stats; + } + + @Override + public String toString() { + return stats.toString(); + } + + public void add(Class aggType) { + functionTypes.add(aggType); + } + + // if the stat has at least two different functions for it, promote it as stat + // also keep the promoted function around for reuse + public AggregateFunction maybePromote(AggregateFunction agg) { + if (functionTypes.size() > 1) { + if (innerAggs == null) { + innerAggs = new LinkedHashMap<>(); + } + return innerAggs.computeIfAbsent(agg.getClass(), k -> new InnerAggregate(agg, stats)); + } + return agg; + } + } + + @Override + public LogicalPlan apply(LogicalPlan p) { + // 1. first check whether there are at least 2 aggs for the same fields so that there can be a promotion + final Map potentialPromotions = new LinkedHashMap<>(); + + p.forEachExpressionsUp(e -> { + if (Stats.isTypeCompatible(e)) { + AggregateFunction f = (AggregateFunction) e; + + Expression argument = f.field(); + Match match = potentialPromotions.get(argument); + + if (match == null) { + Source source = new Source(f.sourceLocation(), "STATS(" + argument.sourceText() + ")"); + match = new Match(new Stats(source, argument)); + potentialPromotions.put(argument, match); + } + match.add(f.getClass()); + } + }); + + // no promotions found - skip + if (potentialPromotions.isEmpty()) { + return p; + } + + // start promotion + + // 2. promote aggs to InnerAggs + return p.transformExpressionsUp(e -> { + if (Stats.isTypeCompatible(e)) { + AggregateFunction f = (AggregateFunction) e; + + Expression argument = f.field(); + Match match = potentialPromotions.get(argument); + + if (match != null) { + return match.maybePromote(f); + } + } + return e; + }); + } + } + + static class PromoteStatsToExtendedStats extends OptimizerBasicRule { + + @Override + public LogicalPlan apply(LogicalPlan p) { + final Map seen = new LinkedHashMap<>(); + + // count the extended stats + p.forEachExpressionsUp(e -> { + if (e instanceof InnerAggregate) { + InnerAggregate ia = (InnerAggregate) e; + if (ia.outer() instanceof ExtendedStats) { + ExtendedStats extStats = (ExtendedStats) ia.outer(); + seen.putIfAbsent(extStats.field(), extStats); + } + } + }); + + // then if there's a match, replace the stat inside the InnerAgg + return p.transformExpressionsUp(e -> { + if (e instanceof InnerAggregate) { + InnerAggregate ia = (InnerAggregate) e; + if (ia.outer() instanceof Stats) { + Stats stats = (Stats) ia.outer(); + ExtendedStats ext = seen.get(stats.field()); + if (ext != null && stats.field().equals(ext.field())) { + return new InnerAggregate(ia.inner(), ext); + } + } + } + + return e; + }); + } + } + + static class ReplaceAggsWithPercentiles extends OptimizerBasicRule { + + @Override + public LogicalPlan apply(LogicalPlan p) { + // percentile per field/expression + Map> percentsPerField = new LinkedHashMap<>(); + + // count gather the percents for each field + p.forEachExpressionsUp(e -> { + if (e instanceof Percentile) { + Percentile per = (Percentile) e; + Expression field = per.field(); + Set percentiles = percentsPerField.get(field); + + if (percentiles == null) { + percentiles = new LinkedHashSet<>(); + percentsPerField.put(field, percentiles); + } + + percentiles.add(per.percent()); + } + }); + + Map percentilesPerField = new LinkedHashMap<>(); + // create a Percentile agg for each field (and its associated percents) + percentsPerField.forEach((k, v) -> { + percentilesPerField.put(k, new Percentiles(v.iterator().next().source(), k, new ArrayList<>(v))); + }); + + return p.transformExpressionsUp(e -> { + if (e instanceof Percentile) { + Percentile per = (Percentile) e; + Percentiles percentiles = percentilesPerField.get(per.field()); + return new InnerAggregate(per, percentiles); + } + + return e; + }); + } + } + + static class ReplaceAggsWithPercentileRanks extends OptimizerBasicRule { + + @Override + public LogicalPlan apply(LogicalPlan p) { + // percentile per field/expression + final Map> percentPerField = new LinkedHashMap<>(); + + // count gather the percents for each field + p.forEachExpressionsUp(e -> { + if (e instanceof PercentileRank) { + PercentileRank per = (PercentileRank) e; + Expression field = per.field(); + Set percentiles = percentPerField.get(field); + + if (percentiles == null) { + percentiles = new LinkedHashSet<>(); + percentPerField.put(field, percentiles); + } + + percentiles.add(per.value()); + } + }); + + Map ranksPerField = new LinkedHashMap<>(); + // create a PercentileRanks agg for each field (and its associated values) + percentPerField.forEach((k, v) -> { + ranksPerField.put(k, new PercentileRanks(v.iterator().next().source(), k, new ArrayList<>(v))); + }); + + return p.transformExpressionsUp(e -> { + if (e instanceof PercentileRank) { + PercentileRank per = (PercentileRank) e; + PercentileRanks ranks = ranksPerField.get(per.field()); + return new InnerAggregate(per, ranks); + } + + return e; + }); + } + } + + static class ReplaceMinMaxWithTopHits extends OptimizerRule { + + @Override + protected LogicalPlan rule(LogicalPlan plan) { + Map mins = new HashMap<>(); + Map maxs = new HashMap<>(); + return plan.transformExpressionsDown(e -> { + if (e instanceof Min) { + Min min = (Min) e; + if (min.field().dataType().isString()) { + return mins.computeIfAbsent(min.field(), k -> new First(min.source(), k, null)); + } + } + if (e instanceof Max) { + Max max = (Max) e; + if (max.field().dataType().isString()) { + return maxs.computeIfAbsent(max.field(), k -> new Last(max.source(), k, null)); + } + } + return e; + }); + } + } + + static class PruneFilters extends OptimizerRule { + + @Override + protected LogicalPlan rule(Filter filter) { + Expression condition = filter.condition().transformUp(PruneFilters::foldBinaryLogic); + + if (condition instanceof Literal) { + if (TRUE.equals(condition)) { + return filter.child(); + } + if (FALSE.equals(condition) || Expressions.isNull(condition)) { + return new LocalRelation(filter.source(), new EmptyExecutable(filter.output())); + } + } + + if (!condition.equals(filter.condition())) { + return new Filter(filter.source(), filter.child(), condition); + } + return filter; + } + + private static Expression foldBinaryLogic(Expression expression) { + if (expression instanceof Or) { + Or or = (Or) expression; + boolean nullLeft = Expressions.isNull(or.left()); + boolean nullRight = Expressions.isNull(or.right()); + if (nullLeft && nullRight) { + return Literal.NULL; + } + if (nullLeft) { + return or.right(); + } + if (nullRight) { + return or.left(); + } + } + if (expression instanceof And) { + And and = (And) expression; + if (Expressions.isNull(and.left()) || Expressions.isNull(and.right())) { + return Literal.NULL; + } + } + return expression; + } + } + + static class SkipQueryOnLimitZero extends OptimizerRule { @Override protected LogicalPlan rule(Limit limit) { @@ -2104,6 +1899,17 @@ public class Optimizer extends RuleExecutor { protected abstract Expression rule(Expression e); } + abstract static class OptimizerBasicRule extends Rule { + + @Override + public abstract LogicalPlan apply(LogicalPlan plan); + + @Override + protected LogicalPlan rule(LogicalPlan plan) { + return plan; + } + } + enum TransformDirection { UP, DOWN } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java index 5a1e09f602a..524d4e8b75a 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.Order; import org.elasticsearch.xpack.sql.expression.Order.NullsPosition; import org.elasticsearch.xpack.sql.expression.ScalarSubquery; +import org.elasticsearch.xpack.sql.expression.UnresolvedAlias; import org.elasticsearch.xpack.sql.expression.UnresolvedAttribute; import org.elasticsearch.xpack.sql.expression.UnresolvedStar; import org.elasticsearch.xpack.sql.expression.function.Function; @@ -157,10 +158,8 @@ abstract class ExpressionBuilder extends IdentifierBuilder { public Expression visitSelectExpression(SelectExpressionContext ctx) { Expression exp = expression(ctx.expression()); String alias = visitIdentifier(ctx.identifier()); - if (alias != null) { - exp = new Alias(source(ctx), alias, exp); - } - return exp; + Source source = source(ctx); + return alias != null ? new Alias(source, alias, exp) : new UnresolvedAlias(source, exp); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/Pivot.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/Pivot.java index c8067b2f2f3..35447ecb405 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/Pivot.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/Pivot.java @@ -6,21 +6,25 @@ package org.elasticsearch.xpack.sql.plan.logical; +import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.capabilities.Resolvables; +import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Attribute; +import org.elasticsearch.xpack.sql.expression.AttributeMap; import org.elasticsearch.xpack.sql.expression.AttributeSet; import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.ExpressionId; import org.elasticsearch.xpack.sql.expression.Expressions; +import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.function.Function; import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.Source; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; import static java.util.Collections.singletonList; @@ -29,49 +33,48 @@ public class Pivot extends UnaryPlan { private final Expression column; private final List values; private final List aggregates; + private final List grouping; // derived properties private AttributeSet groupingSet; private AttributeSet valueOutput; private List output; + private AttributeMap aliases; public Pivot(Source source, LogicalPlan child, Expression column, List values, List aggregates) { + this(source, child, column, values, aggregates, null); + } + + public Pivot(Source source, LogicalPlan child, Expression column, List values, List aggregates, + List grouping) { super(source, child); this.column = column; this.values = values; this.aggregates = aggregates; - } - - private static Expression withQualifierNull(Expression e) { - if (e instanceof Attribute) { - Attribute fa = (Attribute) e; - return fa.withQualifier(null); + + // resolve the grouping set ASAP so it doesn't get re-resolved after analysis (since the aliasing information has been removed) + if (grouping == null && expressionsResolved()) { + AttributeSet columnSet = Expressions.references(singletonList(column)); + // grouping can happen only on "primitive" fields, thus exclude multi-fields or nested docs + // the verifier enforces this rule so it does not catch folks by surprise + grouping = new ArrayList<>(new AttributeSet(Expressions.onlyPrimitiveFieldAttributes(child().output())) + // make sure to have the column as the last entry (helps with translation) so substract it + .subtract(columnSet) + .subtract(Expressions.references(aggregates)) + .combine(columnSet)); } - return e; + + this.grouping = grouping; + this.groupingSet = grouping != null ? new AttributeSet(grouping) : null; } @Override protected NodeInfo info() { - return NodeInfo.create(this, Pivot::new, child(), column, values, aggregates); + return NodeInfo.create(this, Pivot::new, child(), column, values, aggregates, grouping); } @Override protected Pivot replaceChild(LogicalPlan newChild) { - Expression newColumn = column; - List newAggregates = aggregates; - - if (newChild instanceof EsRelation) { - // when changing from a SubQueryAlias to EsRelation - // the qualifier of the column and aggregates needs - // to be changed to null like the attributes of EsRelation - // otherwise they don't equal and aren't removed - // when calculating the groupingSet - newColumn = column.transformUp(Pivot::withQualifierNull); - newAggregates = aggregates.stream().map((NamedExpression aggregate) -> - (NamedExpression) aggregate.transformUp(Pivot::withQualifierNull) - ).collect(Collectors.toList()); - } - - return new Pivot(source(), newChild, newColumn, values, newAggregates); + return new Pivot(source(), newChild, column, values, aggregates, grouping); } public Expression column() { @@ -85,39 +88,40 @@ public class Pivot extends UnaryPlan { public List aggregates() { return aggregates; } - + + public List groupings() { + return grouping; + } + public AttributeSet groupingSet() { if (groupingSet == null) { - AttributeSet columnSet = Expressions.references(singletonList(column)); - // grouping can happen only on "primitive" fields, thus exclude multi-fields or nested docs - // the verifier enforces this rule so it does not catch folks by surprise - groupingSet = new AttributeSet(Expressions.onlyPrimitiveFieldAttributes(child().output())) - // make sure to have the column as the last entry (helps with translation) - .subtract(columnSet) - .subtract(Expressions.references(aggregates)) - .combine(columnSet); + throw new SqlIllegalArgumentException("Cannot determine grouping in unresolved PIVOT"); } return groupingSet; } - public AttributeSet valuesOutput() { - // TODO: the generated id is a hack since it can clash with other potentially generated ids + private AttributeSet valuesOutput() { if (valueOutput == null) { List out = new ArrayList<>(aggregates.size() * values.size()); if (aggregates.size() == 1) { NamedExpression agg = aggregates.get(0); for (NamedExpression value : values) { - ExpressionId id = value.id(); - out.add(value.toAttribute().withDataType(agg.dataType()).withId(id)); + out.add(value.toAttribute().withDataType(agg.dataType())); } } // for multiple args, concat the function and the value else { for (NamedExpression agg : aggregates) { - String name = agg instanceof Function ? ((Function) agg).functionName() : agg.name(); + String name = agg.name(); + if (agg instanceof Alias) { + Alias a = (Alias) agg; + if (a.child() instanceof Function) { + name = ((Function) a.child()).functionName(); + } + } + //FIXME: the value attributes are reused and thus will clash - new ids need to be created for (NamedExpression value : values) { - ExpressionId id = value.id(); - out.add(value.toAttribute().withName(value.name() + "_" + name).withDataType(agg.dataType()).withId(id)); + out.add(value.toAttribute().withName(value.name() + "_" + name).withDataType(agg.dataType())); } } } @@ -125,6 +129,29 @@ public class Pivot extends UnaryPlan { } return valueOutput; } + + public AttributeMap valuesToLiterals() { + AttributeSet outValues = valuesOutput(); + Map valuesMap = new LinkedHashMap<>(); + + int index = 0; + // for each attribute, associate its value + // take into account while iterating that attributes are a multiplication of actual values + for (Attribute attribute : outValues) { + NamedExpression namedExpression = values.get(index % values.size()); + index++; + // everything should have resolved to an alias + if (namedExpression instanceof Alias) { + valuesMap.put(attribute, Literal.of(((Alias) namedExpression).child())); + } + // fallback - verifier should prevent this + else { + throw new SqlIllegalArgumentException("Unexpected alias", namedExpression); + } + } + + return new AttributeMap<>(valuesMap); + } @Override public List output() { @@ -137,6 +164,14 @@ public class Pivot extends UnaryPlan { return output; } + // Since pivot creates its own columns (and thus aliases) + // remember the backing expressions inside a dedicated aliases map + public AttributeMap aliases() { + // make sure to initialize all expressions + output(); + return aliases; + } + @Override public boolean expressionsResolved() { return column.resolved() && Resolvables.resolved(values) && Resolvables.resolved(aggregates); @@ -146,21 +181,21 @@ public class Pivot extends UnaryPlan { public int hashCode() { return Objects.hash(column, values, aggregates, child()); } - + @Override public boolean equals(Object obj) { if (this == obj) { return true; } - + if (obj == null || getClass() != obj.getClass()) { return false; } - + Pivot other = (Pivot) obj; return Objects.equals(column, other.column) && Objects.equals(values, other.values) && Objects.equals(aggregates, other.aggregates) && Objects.equals(child(), other.child()); } -} +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/SubQueryAlias.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/SubQueryAlias.java index 980cd0a849a..dd8fa5bec43 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/SubQueryAlias.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plan/logical/SubQueryAlias.java @@ -17,6 +17,7 @@ import static java.util.stream.Collectors.toList; public class SubQueryAlias extends UnaryPlan { private final String alias; + private List output; public SubQueryAlias(Source source, LogicalPlan child, String alias) { super(source, child); @@ -39,11 +40,13 @@ public class SubQueryAlias extends UnaryPlan { @Override public List output() { - return (alias == null ? child().output() : + if (output == null) { + output = alias == null ? child().output() : child().output().stream() .map(e -> e.withQualifier(alias)) - .collect(toList()) - ); + .collect(toList()); + } + return output; } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java index 8dc9b5b595a..72e4ca380fd 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java @@ -12,29 +12,34 @@ import org.elasticsearch.xpack.sql.execution.search.FieldExtraction; import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.AttributeMap; -import org.elasticsearch.xpack.sql.expression.AttributeSet; import org.elasticsearch.xpack.sql.expression.Expression; -import org.elasticsearch.xpack.sql.expression.ExpressionId; import org.elasticsearch.xpack.sql.expression.Expressions; +import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.expression.Foldables; +import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.Order; +import org.elasticsearch.xpack.sql.expression.ReferenceAttribute; import org.elasticsearch.xpack.sql.expression.function.Function; import org.elasticsearch.xpack.sql.expression.function.Functions; -import org.elasticsearch.xpack.sql.expression.function.ScoreAttribute; +import org.elasticsearch.xpack.sql.expression.function.Score; import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.sql.expression.function.aggregate.CompoundNumericAggregate; import org.elasticsearch.xpack.sql.expression.function.aggregate.Count; import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate; import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits; import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction; +import org.elasticsearch.xpack.sql.expression.function.grouping.Histogram; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; -import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.Year; import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggPathInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.pipeline.UnaryPipe; import org.elasticsearch.xpack.sql.expression.gen.processor.Processor; +import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; +import org.elasticsearch.xpack.sql.expression.literal.IntervalYearMonth; +import org.elasticsearch.xpack.sql.expression.literal.Intervals; import org.elasticsearch.xpack.sql.plan.logical.Pivot; import org.elasticsearch.xpack.sql.plan.physical.AggregateExec; import org.elasticsearch.xpack.sql.plan.physical.EsQueryExec; @@ -45,12 +50,15 @@ import org.elasticsearch.xpack.sql.plan.physical.OrderExec; import org.elasticsearch.xpack.sql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.sql.plan.physical.PivotExec; import org.elasticsearch.xpack.sql.plan.physical.ProjectExec; -import org.elasticsearch.xpack.sql.planner.QueryTranslator.GroupingContext; import org.elasticsearch.xpack.sql.planner.QueryTranslator.QueryTranslation; import org.elasticsearch.xpack.sql.querydsl.agg.AggFilter; import org.elasticsearch.xpack.sql.querydsl.agg.Aggs; +import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram; import org.elasticsearch.xpack.sql.querydsl.agg.GroupByKey; +import org.elasticsearch.xpack.sql.querydsl.agg.GroupByNumericHistogram; +import org.elasticsearch.xpack.sql.querydsl.agg.GroupByValue; import org.elasticsearch.xpack.sql.querydsl.agg.LeafAgg; +import org.elasticsearch.xpack.sql.querydsl.container.AggregateSort; import org.elasticsearch.xpack.sql.querydsl.container.AttributeSort; import org.elasticsearch.xpack.sql.querydsl.container.ComputedRef; import org.elasticsearch.xpack.sql.querydsl.container.GlobalCountRef; @@ -69,18 +77,21 @@ import org.elasticsearch.xpack.sql.rule.Rule; import org.elasticsearch.xpack.sql.rule.RuleExecutor; import org.elasticsearch.xpack.sql.session.EmptyExecutable; import org.elasticsearch.xpack.sql.util.Check; +import org.elasticsearch.xpack.sql.util.DateUtils; +import java.time.Period; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.sql.planner.QueryTranslator.and; import static org.elasticsearch.xpack.sql.planner.QueryTranslator.toAgg; import static org.elasticsearch.xpack.sql.planner.QueryTranslator.toQuery; +import static org.elasticsearch.xpack.sql.type.DataType.DATE; import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine; /** @@ -123,38 +134,26 @@ class QueryFolder extends RuleExecutor { EsQueryExec exec = (EsQueryExec) project.child(); QueryContainer queryC = exec.queryContainer(); - Map aliases = new LinkedHashMap<>(queryC.aliases()); + Map aliases = new LinkedHashMap<>(queryC.aliases()); Map processors = new LinkedHashMap<>(queryC.scalarFunctions()); for (NamedExpression pj : project.projections()) { if (pj instanceof Alias) { - Attribute aliasAttr = pj.toAttribute(); + Attribute attr = pj.toAttribute(); Expression e = ((Alias) pj).child(); - if (e instanceof NamedExpression) { - Attribute attr = ((NamedExpression) e).toAttribute(); - aliases.put(aliasAttr.id(), attr); - // add placeholder for each scalar function - if (e instanceof ScalarFunction) { - processors.put(attr, Expressions.pipe(e)); - } - } else { - processors.put(aliasAttr, Expressions.pipe(e)); - } - } - else { - // for named expressions nothing is recorded as these are resolved last - // otherwise 'intermediate' projects might pollute the - // output - if (pj instanceof ScalarFunction) { - ScalarFunction f = (ScalarFunction) pj; - processors.put(f.toAttribute(), Expressions.pipe(f)); + // track all aliases (to determine their reference later on) + aliases.put(attr, e); + + // track scalar pipelines + if (e instanceof ScalarFunction) { + processors.put(attr, ((ScalarFunction) e).asPipe()); } } } QueryContainer clone = new QueryContainer(queryC.query(), queryC.aggs(), queryC.fields(), - new HashMap<>(aliases), + new AttributeMap<>(aliases), queryC.pseudoFunctions(), new AttributeMap<>(processors), queryC.sort(), @@ -214,7 +213,176 @@ class QueryFolder extends RuleExecutor { } } - private static class FoldAggregate extends FoldingRule { + // TODO: remove exceptions from the Folder + static class FoldAggregate extends FoldingRule { + + static class GroupingContext { + final Map groupMap; + final GroupByKey tail; + + GroupingContext(Map groupMap) { + this.groupMap = groupMap; + + GroupByKey lastAgg = null; + for (Entry entry : groupMap.entrySet()) { + lastAgg = entry.getValue(); + } + + tail = lastAgg; + } + + GroupByKey groupFor(Expression exp) { + Integer hash = null; + if (Functions.isAggregate(exp)) { + AggregateFunction f = (AggregateFunction) exp; + // if there's at least one agg in the tree + if (groupMap.isEmpty() == false) { + GroupByKey matchingGroup = null; + // group found - finding the dedicated agg + // TODO: when dealing with expressions inside Aggregation, make sure to extract the field + hash = Integer.valueOf(f.field().hashCode()); + matchingGroup = groupMap.get(hash); + // return matching group or the tail (last group) + return matchingGroup != null ? matchingGroup : tail; + } else { + return null; + } + } + + hash = Integer.valueOf(exp.hashCode()); + return groupMap.get(hash); + } + + @Override + public String toString() { + return groupMap.toString(); + } + } + + /** + * Creates the list of GroupBy keys + */ + static GroupingContext groupBy(List groupings) { + if (groupings.isEmpty() == true) { + return null; + } + + Map aggMap = new LinkedHashMap<>(); + + for (Expression exp : groupings) { + GroupByKey key = null; + + Integer hash = Integer.valueOf(exp.hashCode()); + String aggId = Expressions.id(exp); + + // change analyzed to non non-analyzed attributes + if (exp instanceof FieldAttribute) { + FieldAttribute field = (FieldAttribute) exp; + field = field.exactAttribute(); + key = new GroupByValue(aggId, field.name()); + } + + // handle functions + else if (exp instanceof Function) { + // dates are handled differently because of date histograms + if (exp instanceof DateTimeHistogramFunction) { + DateTimeHistogramFunction dthf = (DateTimeHistogramFunction) exp; + + Expression field = dthf.field(); + if (field instanceof FieldAttribute) { + if (dthf.calendarInterval() != null) { + key = new GroupByDateHistogram(aggId, QueryTranslator.nameOf(exp), dthf.calendarInterval(), dthf.zoneId()); + } else { + key = new GroupByDateHistogram(aggId, QueryTranslator.nameOf(exp), dthf.fixedInterval(), dthf.zoneId()); + } + } + // use scripting for functions + else if (field instanceof Function) { + ScriptTemplate script = ((Function) field).asScript(); + if (dthf.calendarInterval() != null) { + key = new GroupByDateHistogram(aggId, script, dthf.calendarInterval(), dthf.zoneId()); + } else { + key = new GroupByDateHistogram(aggId, script, dthf.fixedInterval(), dthf.zoneId()); + } + } + } + // all other scalar functions become a script + else if (exp instanceof ScalarFunction) { + ScalarFunction sf = (ScalarFunction) exp; + key = new GroupByValue(aggId, sf.asScript()); + } + // histogram + else if (exp instanceof GroupingFunction) { + if (exp instanceof Histogram) { + Histogram h = (Histogram) exp; + Expression field = h.field(); + + // date histogram + if (h.dataType().isDateBased()) { + Object value = h.interval().value(); + // interval of exactly 1 year + if (value instanceof IntervalYearMonth + && ((IntervalYearMonth) value).interval().equals(Period.ofYears(1))) { + String calendarInterval = Year.YEAR_INTERVAL; + + // When the histogram is `INTERVAL '1' YEAR`, the interval used in the ES date_histogram will be + // a calendar_interval with value "1y". All other intervals will be fixed_intervals expressed in ms. + if (field instanceof FieldAttribute) { + key = new GroupByDateHistogram(aggId, QueryTranslator.nameOf(field), calendarInterval, h.zoneId()); + } else if (field instanceof Function) { + key = new GroupByDateHistogram(aggId, ((Function) field).asScript(), calendarInterval, h.zoneId()); + } + } + // typical interval + else { + long intervalAsMillis = Intervals.inMillis(h.interval()); + + // When the histogram in SQL is applied on DATE type instead of DATETIME, the interval + // specified is truncated to the multiple of a day. If the interval specified is less + // than 1 day, then the interval used will be `INTERVAL '1' DAY`. + if (h.dataType() == DATE) { + intervalAsMillis = DateUtils.minDayInterval(intervalAsMillis); + } + + if (field instanceof FieldAttribute) { + key = new GroupByDateHistogram(aggId, QueryTranslator.nameOf(field), intervalAsMillis, h.zoneId()); + } else if (field instanceof Function) { + key = new GroupByDateHistogram(aggId, ((Function) field).asScript(), intervalAsMillis, h.zoneId()); + } + } + } + // numeric histogram + else { + if (field instanceof FieldAttribute) { + key = new GroupByNumericHistogram(aggId, QueryTranslator.nameOf(field), + Foldables.doubleValueOf(h.interval())); + } else if (field instanceof Function) { + key = new GroupByNumericHistogram(aggId, ((Function) field).asScript(), + Foldables.doubleValueOf(h.interval())); + } + } + if (key == null) { + throw new SqlIllegalArgumentException("Unsupported histogram field {}", field); + } + } else { + throw new SqlIllegalArgumentException("Unsupproted grouping function {}", exp); + } + } + // bumped into into an invalid function (which should be caught by the verifier) + else { + throw new SqlIllegalArgumentException("Cannot GROUP BY function {}", exp); + } + } + // catch corner-case + else { + throw new SqlIllegalArgumentException("Cannot GROUP BY {}", exp); + } + + aggMap.put(hash, key); + } + return new GroupingContext(aggMap); + } + @Override protected PhysicalPlan rule(AggregateExec a) { if (a.child() instanceof EsQueryExec) { @@ -225,46 +393,71 @@ class QueryFolder extends RuleExecutor { } static EsQueryExec fold(AggregateExec a, EsQueryExec exec) { - // build the group aggregation - // and also collect info about it (since the group columns might be used inside the select) - - GroupingContext groupingContext = QueryTranslator.groupBy(a.groupings()); - + QueryContainer queryC = exec.queryContainer(); + + // track aliases defined in the SELECT and used inside GROUP BY + // SELECT x AS a ... GROUP BY a + Map aliasMap = new LinkedHashMap<>(); + for (NamedExpression ne : a.aggregates()) { + if (ne instanceof Alias) { + aliasMap.put(ne.toAttribute(), ((Alias) ne).child()); + } + } + + if (aliasMap.isEmpty() == false) { + Map newAliases = new LinkedHashMap<>(queryC.aliases()); + newAliases.putAll(aliasMap); + queryC = queryC.withAliases(new AttributeMap<>(newAliases)); + } + + // build the group aggregation + // NB: any reference in grouping is already "optimized" by its source so there's no need to look for aliases + GroupingContext groupingContext = groupBy(a.groupings()); + if (groupingContext != null) { queryC = queryC.addGroups(groupingContext.groupMap.values()); } - Map aliases = new LinkedHashMap<>(); // tracker for compound aggs seen in a group Map compoundAggMap = new LinkedHashMap<>(); // followed by actual aggregates for (NamedExpression ne : a.aggregates()) { - // unwrap alias - it can be - // - an attribute (since we support aliases inside group-by) - // SELECT emp_no ... GROUP BY emp_no + // unwrap alias (since we support aliases declared inside SELECTs to be used by the GROUP BY) + // An alias can point to : + // - field + // SELECT emp_no AS e ... GROUP BY e + // - a function // SELECT YEAR(hire_date) ... GROUP BY YEAR(hire_date) - // - an agg function (typically) + // - an agg function over the grouped field // SELECT COUNT(*), AVG(salary) ... GROUP BY salary; - // - a scalar function, which can be applied on an attribute or aggregate and can require one or multiple inputs + // - a scalar function, which can be applied on a column or aggregate and can require one or multiple inputs // SELECT SIN(emp_no) ... GROUP BY emp_no // SELECT CAST(YEAR(hire_date)) ... GROUP BY YEAR(hire_date) // SELECT CAST(AVG(salary)) ... GROUP BY salary // SELECT AVG(salary) + SIN(MIN(salary)) ... GROUP BY salary - if (ne instanceof Alias || ne instanceof Function) { - Alias as = ne instanceof Alias ? (Alias) ne : null; - Expression child = as != null ? as.child() : ne; + Expression target = ne; - // record aliases in case they are later referred in the tree - if (as != null && as.child() instanceof NamedExpression) { - aliases.put(as.toAttribute().id(), ((NamedExpression) as.child()).toAttribute()); - } + // unwrap aliases since it's the children we are interested in + if (ne instanceof Alias) { + target = ((Alias) ne).child(); + } + + String id = Expressions.id(target); + + // literal + if (target.foldable()) { + queryC = queryC.addColumn(ne.toAttribute()); + } + + // look at functions + else if (target instanceof Function) { // // look first for scalar functions which might wrap the actual grouped target @@ -273,12 +466,14 @@ class QueryFolder extends RuleExecutor { // ABS(YEAR(field)) GROUP BY YEAR(field) or // ABS(AVG(salary)) ... GROUP BY salary // ) - if (child instanceof ScalarFunction) { - ScalarFunction f = (ScalarFunction) child; + + if (target instanceof ScalarFunction) { + ScalarFunction f = (ScalarFunction) target; Pipe proc = f.asPipe(); final AtomicReference qC = new AtomicReference<>(queryC); + // traverse the pipe to find the mandatory grouping expression proc = proc.transformUp(p -> { // bail out if the def is resolved if (p.resolved()) { @@ -295,6 +490,7 @@ class QueryFolder extends RuleExecutor { } else { // a scalar function can be used only if has already been mentioned for grouping // (otherwise it is the opposite of grouping) + // normally this case should be caught by the Verifier if (exp instanceof ScalarFunction) { throw new FoldingException(exp, "Scalar function " + exp.toString() + " can be used only if included already in grouping"); @@ -332,77 +528,71 @@ class QueryFolder extends RuleExecutor { return p; }); - if (!proc.resolved()) { - throw new FoldingException(child, "Cannot find grouping for '{}'", Expressions.name(child)); + if (proc.resolved() == false) { + throw new FoldingException(target, "Cannot find grouping for '{}'", Expressions.name(target)); } // add the computed column - queryC = qC.get().addColumn(new ComputedRef(proc), f.toAttribute()); - - // TODO: is this needed? - // redirect the alias to the scalar group id (changing the id altogether doesn't work it is - // already used in the aggpath) - //aliases.put(as.toAttribute(), sf.toAttribute()); + queryC = qC.get().addColumn(new ComputedRef(proc), id); } + // apply the same logic above (for function inputs) to non-scalar functions with small variations: // instead of adding things as input, add them as full blown column else { GroupByKey matchingGroup = null; if (groupingContext != null) { // is there a group (aggregation) for this expression ? - matchingGroup = groupingContext.groupFor(child); + matchingGroup = groupingContext.groupFor(target); } // attributes can only refer to declared groups - if (child instanceof Attribute) { - Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(child)); - queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, child.dataType().isDateBased()), - ((Attribute) child)); + if (target instanceof Attribute) { + Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(target)); + queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, target.dataType().isDateBased()), id); } // handle histogram - else if (child instanceof GroupingFunction) { - queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, child.dataType().isDateBased()), - ((GroupingFunction) child).toAttribute()); + else if (target instanceof GroupingFunction) { + queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, target.dataType().isDateBased()), id); + } + // handle literal + else if (target.foldable()) { + queryC = queryC.addColumn(ne.toAttribute()); } - else if (child.foldable()) { - queryC = queryC.addColumn(ne.toAttribute()); - } // fallback to regular agg functions else { // the only thing left is agg function - Check.isTrue(Functions.isAggregate(child), "Expected aggregate function inside alias; got [{}]", - child.nodeString()); - AggregateFunction af = (AggregateFunction) child; + Check.isTrue(Functions.isAggregate(target), "Expected aggregate function inside alias; got [{}]", + target.nodeString()); + AggregateFunction af = (AggregateFunction) target; Tuple withAgg = addAggFunction(matchingGroup, af, compoundAggMap, queryC); // make sure to add the inner id (to handle compound aggs) - queryC = withAgg.v1().addColumn(withAgg.v2().context(), af.toAttribute()); + queryC = withAgg.v1().addColumn(withAgg.v2().context(), id); } } - // not an Alias or Function means it's an Attribute so apply the same logic as above - } else { + + } + // not a Function or literal, means its has to be a field or field expression + else { GroupByKey matchingGroup = null; if (groupingContext != null) { - matchingGroup = groupingContext.groupFor(ne); + matchingGroup = groupingContext.groupFor(target); Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(ne)); - queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, ne.dataType().isDateBased()), ne.toAttribute()); + queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, ne.dataType().isDateBased()), id); } - else if (ne.foldable()) { - queryC = queryC.addColumn(ne.toAttribute()); - } + // fallback + else { + throw new SqlIllegalArgumentException("Cannot fold aggregate {}", ne); } } - - if (!aliases.isEmpty()) { - Map newAliases = new LinkedHashMap<>(queryC.aliases()); - newAliases.putAll(aliases); - queryC = queryC.withAliases(new HashMap<>(newAliases)); } + return new EsQueryExec(exec.source(), exec.index(), a.output(), queryC); } private static Tuple addAggFunction(GroupByKey groupingAgg, AggregateFunction f, Map compoundAggMap, QueryContainer queryC) { - String functionId = f.functionId(); + + String functionId = Expressions.id(f); // handle count as a special case agg if (f instanceof Count) { Count c = (Count) f; @@ -422,7 +612,7 @@ class QueryFolder extends RuleExecutor { pseudoFunctions.put(functionId, groupingAgg); return new Tuple<>(queryC.withPseudoFunctions(pseudoFunctions), new AggPathInput(f, ref)); // COUNT() - } else if (!c.distinct()) { + } else if (c.distinct() == false) { LeafAgg leafAgg = toAgg(functionId, f); AggPathInput a = new AggPathInput(f, new MetricAggRef(leafAgg.id(), "doc_count", "_count", false)); queryC = queryC.with(queryC.aggs().addAgg(leafAgg)); @@ -440,7 +630,7 @@ class QueryFolder extends RuleExecutor { // the compound agg hasn't been seen before so initialize it if (cAggPath == null) { - LeafAgg leafAgg = toAgg(outer.functionId(), outer); + LeafAgg leafAgg = toAgg(Expressions.id(outer), outer); cAggPath = leafAgg.id(); compoundAggMap.put(outer, cAggPath); // add the agg (without any reference) @@ -480,37 +670,60 @@ class QueryFolder extends RuleExecutor { Missing missing = Missing.from(order.nullsPosition()); // check whether sorting is on an group (and thus nested agg) or field - Attribute attr = ((NamedExpression) order.child()).toAttribute(); - // check whether there's an alias (occurs with scalar functions which are not named) - attr = qContainer.aliases().getOrDefault(attr.id(), attr); - GroupByKey group = qContainer.findGroupForAgg(attr); + Expression orderExpression = order.child(); + + // if it's a reference, get the target expression + if (orderExpression instanceof ReferenceAttribute) { + orderExpression = qContainer.aliases().get(orderExpression); + } + String lookup = Expressions.id(orderExpression); + GroupByKey group = qContainer.findGroupForAgg(lookup); // TODO: might need to validate whether the target field or group actually exist if (group != null && group != Aggs.IMPLICIT_GROUP_KEY) { - qContainer = qContainer.updateGroup(group.with(direction)); + // check whether the lookup matches a group + if (group.id().equals(lookup)) { + qContainer = qContainer.updateGroup(group.with(direction)); + } + // else it's a leafAgg + else { + qContainer = qContainer.updateGroup(group.with(direction)); + } } else { // scalar functions typically require script ordering - if (attr instanceof ScalarFunctionAttribute) { - ScalarFunctionAttribute sfa = (ScalarFunctionAttribute) attr; + if (orderExpression instanceof ScalarFunction) { + ScalarFunction sf = (ScalarFunction) orderExpression; // is there an expression to order by? - if (sfa.orderBy() != null) { - if (sfa.orderBy() instanceof NamedExpression) { - Attribute at = ((NamedExpression) sfa.orderBy()).toAttribute(); - at = qContainer.aliases().getOrDefault(at.id(), at); - qContainer = qContainer.addSort(new AttributeSort(at, direction, missing)); - } else if (!sfa.orderBy().foldable()) { + if (sf.orderBy() != null) { + Expression orderBy = sf.orderBy(); + if (orderBy instanceof NamedExpression) { + orderBy = qContainer.aliases().getOrDefault(orderBy, orderBy); + qContainer = qContainer + .addSort(new AttributeSort(((NamedExpression) orderBy).toAttribute(), direction, missing)); + } else if (orderBy.foldable() == false) { // ignore constant - throw new PlanningException("does not know how to order by expression {}", sfa.orderBy()); + throw new PlanningException("does not know how to order by expression {}", orderBy); } } else { // nope, use scripted sorting - qContainer = qContainer.addSort(new ScriptSort(sfa.script(), direction, missing)); + qContainer = qContainer.addSort(new ScriptSort(sf.asScript(), direction, missing)); } - } else if (attr instanceof ScoreAttribute) { + } + // score + else if (orderExpression instanceof Score) { qContainer = qContainer.addSort(new ScoreSort(direction, missing)); + } + // field + else if (orderExpression instanceof FieldAttribute) { + qContainer = qContainer.addSort(new AttributeSort((FieldAttribute) orderExpression, direction, missing)); + } + // agg function + else if (orderExpression instanceof AggregateFunction) { + qContainer = qContainer.addSort(new AggregateSort((AggregateFunction) orderExpression, direction, missing)); } else { - qContainer = qContainer.addSort(new AttributeSort(attr, direction, missing)); + // unknown + throw new SqlIllegalArgumentException("unsupported sorting expression {}", orderExpression); } } } @@ -573,21 +786,22 @@ class QueryFolder extends RuleExecutor { // due to the Pivot structure - the column is the last entry in the grouping set QueryContainer query = fold.queryContainer(); - List> fields = new ArrayList<>(query.fields()); + List> fields = new ArrayList<>(query.fields()); int startingIndex = fields.size() - p.aggregates().size() - 1; // pivot grouping - Tuple groupTuple = fields.remove(startingIndex); - AttributeSet valuesOutput = plan.pivot().valuesOutput(); + Tuple groupTuple = fields.remove(startingIndex); + AttributeMap values = p.valuesToLiterals(); for (int i = startingIndex; i < fields.size(); i++) { - Tuple tuple = fields.remove(i); - for (Attribute attribute : valuesOutput) { - fields.add(new Tuple<>(new PivotColumnRef(groupTuple.v1(), tuple.v1(), attribute.fold()), attribute.id())); + Tuple tuple = fields.remove(i); + for (Map.Entry entry : values.entrySet()) { + fields.add(new Tuple<>( + new PivotColumnRef(groupTuple.v1(), tuple.v1(), entry.getValue().value()), Expressions.id(entry.getKey()))); } - i += valuesOutput.size(); + i += values.size(); } - return fold.with(new QueryContainer(query.query(), query.aggs(), + return fold.with(new QueryContainer(query.query(), query.aggs(), fields, query.aliases(), query.pseudoFunctions(), @@ -596,7 +810,7 @@ class QueryFolder extends RuleExecutor { query.limit(), query.shouldTrackHits(), query.shouldIncludeFrozen(), - valuesOutput.size())); + values.size())); } return plan; } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java index 4fbcc76ff82..149999a8802 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java @@ -10,15 +10,12 @@ import org.elasticsearch.geometry.Geometry; import org.elasticsearch.geometry.Point; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.FieldAttribute; -import org.elasticsearch.xpack.sql.expression.Foldables; import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.function.Function; -import org.elasticsearch.xpack.sql.expression.function.Functions; import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg; import org.elasticsearch.xpack.sql.expression.function.aggregate.CompoundNumericAggregate; @@ -35,17 +32,11 @@ import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentiles; import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats; import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits; -import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction; -import org.elasticsearch.xpack.sql.expression.function.grouping.Histogram; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFunction; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.Year; import org.elasticsearch.xpack.sql.expression.function.scalar.geo.GeoShape; import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; -import org.elasticsearch.xpack.sql.expression.literal.IntervalYearMonth; -import org.elasticsearch.xpack.sql.expression.literal.Intervals; import org.elasticsearch.xpack.sql.expression.predicate.Range; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MultiMatchQueryPredicate; @@ -74,10 +65,6 @@ import org.elasticsearch.xpack.sql.querydsl.agg.AvgAgg; import org.elasticsearch.xpack.sql.querydsl.agg.CardinalityAgg; import org.elasticsearch.xpack.sql.querydsl.agg.ExtendedStatsAgg; import org.elasticsearch.xpack.sql.querydsl.agg.FilterExistsAgg; -import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram; -import org.elasticsearch.xpack.sql.querydsl.agg.GroupByKey; -import org.elasticsearch.xpack.sql.querydsl.agg.GroupByNumericHistogram; -import org.elasticsearch.xpack.sql.querydsl.agg.GroupByValue; import org.elasticsearch.xpack.sql.querydsl.agg.LeafAgg; import org.elasticsearch.xpack.sql.querydsl.agg.MatrixStatsAgg; import org.elasticsearch.xpack.sql.querydsl.agg.MaxAgg; @@ -106,25 +93,20 @@ import org.elasticsearch.xpack.sql.querydsl.query.TermsQuery; import org.elasticsearch.xpack.sql.querydsl.query.WildcardQuery; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.util.Check; -import org.elasticsearch.xpack.sql.util.DateUtils; import org.elasticsearch.xpack.sql.util.Holder; import org.elasticsearch.xpack.sql.util.ReflectionUtils; import java.time.OffsetTime; -import java.time.Period; import java.time.ZonedDateTime; import java.time.temporal.TemporalAccessor; import java.util.Arrays; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; -import java.util.Map.Entry; import java.util.function.Supplier; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.sql.expression.Expressions.id; import static org.elasticsearch.xpack.sql.expression.Foldables.doubleValuesOf; import static org.elasticsearch.xpack.sql.expression.Foldables.valueOf; -import static org.elasticsearch.xpack.sql.type.DataType.DATE; final class QueryTranslator { @@ -208,174 +190,6 @@ final class QueryTranslator { throw new SqlIllegalArgumentException("Don't know how to translate {} {}", f.nodeName(), f); } - static class GroupingContext { - final Map groupMap; - final GroupByKey tail; - - GroupingContext(Map groupMap) { - this.groupMap = groupMap; - - GroupByKey lastAgg = null; - for (Entry entry : groupMap.entrySet()) { - lastAgg = entry.getValue(); - } - - tail = lastAgg; - } - - GroupByKey groupFor(Expression exp) { - if (Functions.isAggregate(exp)) { - AggregateFunction f = (AggregateFunction) exp; - // if there's at least one agg in the tree - if (!groupMap.isEmpty()) { - GroupByKey matchingGroup = null; - // group found - finding the dedicated agg - if (f.field() instanceof NamedExpression) { - matchingGroup = groupMap.get(f.field()); - } - // return matching group or the tail (last group) - return matchingGroup != null ? matchingGroup : tail; - } - else { - return null; - } - } - if (exp instanceof NamedExpression) { - return groupMap.get(exp); - } - throw new SqlIllegalArgumentException("Don't know how to find group for expression {}", exp); - } - - @Override - public String toString() { - return groupMap.toString(); - } - } - - /** - * Creates the list of GroupBy keys - */ - static GroupingContext groupBy(List groupings) { - if (groupings.isEmpty()) { - return null; - } - - Map aggMap = new LinkedHashMap<>(); - - for (Expression exp : groupings) { - GroupByKey key = null; - NamedExpression id; - String aggId; - - if (exp instanceof NamedExpression) { - NamedExpression ne = (NamedExpression) exp; - - id = ne; - aggId = ne.id().toString(); - - // change analyzed to non non-analyzed attributes - if (exp instanceof FieldAttribute) { - ne = ((FieldAttribute) exp).exactAttribute(); - } - - // handle functions differently - if (exp instanceof Function) { - // dates are handled differently because of date histograms - if (exp instanceof DateTimeHistogramFunction) { - DateTimeHistogramFunction dthf = (DateTimeHistogramFunction) exp; - Expression field = dthf.field(); - if (field instanceof FieldAttribute) { - if (dthf.calendarInterval() != null) { - key = new GroupByDateHistogram(aggId, nameOf(field), dthf.calendarInterval(), dthf.zoneId()); - } else { - key = new GroupByDateHistogram(aggId, nameOf(field), dthf.fixedInterval(), dthf.zoneId()); - } - } else if (field instanceof Function) { - ScriptTemplate script = ((Function) field).asScript(); - if (dthf.calendarInterval() != null) { - key = new GroupByDateHistogram(aggId, script, dthf.calendarInterval(), dthf.zoneId()); - } else { - key = new GroupByDateHistogram(aggId, script, dthf.fixedInterval(), dthf.zoneId()); - } - } - } - // all other scalar functions become a script - else if (exp instanceof ScalarFunction) { - ScalarFunction sf = (ScalarFunction) exp; - key = new GroupByValue(aggId, sf.asScript()); - } - // histogram - else if (exp instanceof GroupingFunction) { - if (exp instanceof Histogram) { - Histogram h = (Histogram) exp; - Expression field = h.field(); - - // date histogram - if (h.dataType().isDateBased()) { - Object value = h.interval().value(); - if (value instanceof IntervalYearMonth - && ((IntervalYearMonth) value).interval().equals(Period.of(1, 0, 0))) { - String calendarInterval = Year.YEAR_INTERVAL; - - // When the histogram is `INTERVAL '1' YEAR`, the interval used in the ES date_histogram will be - // a calendar_interval with value "1y". All other intervals will be fixed_intervals expressed in ms. - if (field instanceof FieldAttribute) { - key = new GroupByDateHistogram(aggId, nameOf(field), calendarInterval, h.zoneId()); - } else if (field instanceof Function) { - key = new GroupByDateHistogram(aggId, ((Function) field).asScript(), calendarInterval, h.zoneId()); - } - } else { - long intervalAsMillis = Intervals.inMillis(h.interval()); - - // When the histogram in SQL is applied on DATE type instead of DATETIME, the interval - // specified is truncated to the multiple of a day. If the interval specified is less - // than 1 day, then the interval used will be `INTERVAL '1' DAY`. - if (h.dataType() == DATE) { - intervalAsMillis = DateUtils.minDayInterval(intervalAsMillis); - } - - if (field instanceof FieldAttribute) { - key = new GroupByDateHistogram(aggId, nameOf(field), intervalAsMillis, h.zoneId()); - } else if (field instanceof Function) { - key = new GroupByDateHistogram(aggId, ((Function) field).asScript(), intervalAsMillis, h.zoneId()); - } - } - } - // numeric histogram - else { - if (field instanceof FieldAttribute) { - key = new GroupByNumericHistogram(aggId, nameOf(field), Foldables.doubleValueOf(h.interval())); - } else if (field instanceof Function) { - key = new GroupByNumericHistogram(aggId, ((Function) field).asScript(), - Foldables.doubleValueOf(h.interval())); - } - } - if (key == null) { - throw new SqlIllegalArgumentException("Unsupported histogram field {}", field); - } - } - else { - throw new SqlIllegalArgumentException("Unsupproted grouping function {}", exp); - } - } - // bumped into into an invalid function (which should be caught by the verifier) - else { - throw new SqlIllegalArgumentException("Cannot GROUP BY function {}", exp); - } - } - else { - key = new GroupByValue(aggId, ne.name()); - } - } - else { - throw new SqlIllegalArgumentException("Don't know how to group on {}", exp.nodeString()); - } - - aggMap.put(id, key); - } - return new GroupingContext(aggMap); - } - static QueryTranslation and(Source source, QueryTranslation left, QueryTranslation right) { Check.isTrue(left != null || right != null, "Both expressions are null"); if (left == null) { @@ -464,17 +278,9 @@ final class QueryTranslator { if (e instanceof NamedExpression) { return ((NamedExpression) e).name(); } - if (e instanceof Literal) { - return String.valueOf(e.fold()); + else { + return e.sourceText(); } - throw new SqlIllegalArgumentException("Cannot determine name for {}", e); - } - - static String idOf(Expression e) { - if (e instanceof NamedExpression) { - return ((NamedExpression) e).id().toString(); - } - throw new SqlIllegalArgumentException("Cannot determine id for {}", e); } static String dateFormat(Expression e) { @@ -525,7 +331,7 @@ final class QueryTranslator { if (e.field() instanceof FieldAttribute) { targetFieldName = nameOf(((FieldAttribute) e.field()).exactAttribute()); } else { - throw new SqlIllegalArgumentException("Scalar function [{}] not allowed (yet) as argument for " + e.functionName(), + throw new SqlIllegalArgumentException("Scalar function [{}] not allowed (yet) as argument for " + e.sourceText(), Expressions.name(e.field())); } @@ -590,7 +396,7 @@ final class QueryTranslator { AggFilter aggFilter = null; if (onAggs) { - aggFilter = new AggFilter(not.id().toString(), not.asScript()); + aggFilter = new AggFilter(id(not), not.asScript()); } else { Expression e = not.field(); Query wrappedQuery = toQuery(not.field(), false).query; @@ -616,7 +422,7 @@ final class QueryTranslator { AggFilter aggFilter = null; if (onAggs) { - aggFilter = new AggFilter(isNotNull.id().toString(), isNotNull.asScript()); + aggFilter = new AggFilter(id(isNotNull), isNotNull.asScript()); } else { Query q = null; if (isNotNull.field() instanceof FieldAttribute) { @@ -640,7 +446,7 @@ final class QueryTranslator { AggFilter aggFilter = null; if (onAggs) { - aggFilter = new AggFilter(isNull.id().toString(), isNull.asScript()); + aggFilter = new AggFilter(id(isNull), isNull.asScript()); } else { Query q = null; if (isNull.field() instanceof FieldAttribute) { @@ -667,30 +473,18 @@ final class QueryTranslator { bc.right().sourceLocation().getLineNumber(), bc.right().sourceLocation().getColumnNumber(), Expressions.name(bc.right()), bc.symbol()); - if (bc.left() instanceof NamedExpression) { - NamedExpression ne = (NamedExpression) bc.left(); + Query query = null; + AggFilter aggFilter = null; - Query query = null; - AggFilter aggFilter = null; - - Attribute at = ne.toAttribute(); - // - // Agg context means HAVING -> PipelineAggs - // - if (onAggs) { - aggFilter = new AggFilter(at.id().toString(), bc.asScript()); - } - else { - query = handleQuery(bc, ne, () -> translateQuery(bc)); - } - return new QueryTranslation(query, aggFilter); - } // - // if the code gets here it's a bug + // Agg context means HAVING -> PipelineAggs // - else { - throw new SqlIllegalArgumentException("No idea how to translate " + bc.left()); + if (onAggs) { + aggFilter = new AggFilter(id(bc.left()), bc.asScript()); + } else { + query = handleQuery(bc, bc.left(), () -> translateQuery(bc)); } + return new QueryTranslation(query, aggFilter); } private static Query translateQuery(BinaryComparison bc) { @@ -778,39 +572,28 @@ final class QueryTranslator { @Override protected QueryTranslation asQuery(In in, boolean onAggs) { - if (in.value() instanceof NamedExpression) { - NamedExpression ne = (NamedExpression) in.value(); + Query query = null; + AggFilter aggFilter = null; - Query query = null; - AggFilter aggFilter = null; - - Attribute at = ne.toAttribute(); - // - // Agg context means HAVING -> PipelineAggs - // - if (onAggs) { - aggFilter = new AggFilter(at.id().toString(), in.asScript()); - } - else { - Query q = null; - if (in.value() instanceof FieldAttribute) { - FieldAttribute fa = (FieldAttribute) in.value(); - // equality should always be against an exact match (which is important for strings) - q = new TermsQuery(in.source(), fa.exactAttribute().name(), in.list()); - } else { - q = new ScriptQuery(in.source(), in.asScript()); - } - Query qu = q; - query = handleQuery(in, ne, () -> qu); - } - return new QueryTranslation(query, aggFilter); + // + // Agg context means HAVING -> PipelineAggs + // + if (onAggs) { + aggFilter = new AggFilter(id(in.value()), in.asScript()); } - // - // if the code gets here it's a bug - // else { - throw new SqlIllegalArgumentException("No idea how to translate " + in.value()); + Query q = null; + if (in.value() instanceof FieldAttribute) { + FieldAttribute fa = (FieldAttribute) in.value(); + // equality should always be against an exact match (which is important for strings) + q = new TermsQuery(in.source(), fa.exactAttribute().name(), in.list()); + } else { + q = new ScriptQuery(in.source(), in.asScript()); + } + Query qu = q; + query = handleQuery(in, in.value(), () -> qu); } + return new QueryTranslation(query, aggFilter); } } @@ -820,53 +603,48 @@ final class QueryTranslator { protected QueryTranslation asQuery(Range r, boolean onAggs) { Expression e = r.value(); - if (e instanceof NamedExpression) { - Query query = null; - AggFilter aggFilter = null; + Query query = null; + AggFilter aggFilter = null; - // - // Agg context means HAVING -> PipelineAggs - // - Attribute at = ((NamedExpression) e).toAttribute(); - - if (onAggs) { - aggFilter = new AggFilter(at.id().toString(), r.asScript()); - } else { - Holder lower = new Holder<>(valueOf(r.lower())); - Holder upper = new Holder<>(valueOf(r.upper())); - Holder format = new Holder<>(dateFormat(r.value())); - - // for a date constant comparison, we need to use a format for the date, to make sure that the format is the same - // no matter the timezone provided by the user - if (format.get() == null) { - DateFormatter formatter = null; - if (lower.get() instanceof ZonedDateTime || upper.get() instanceof ZonedDateTime) { - formatter = DateFormatter.forPattern(DATE_FORMAT); - } else if (lower.get() instanceof OffsetTime || upper.get() instanceof OffsetTime) { - formatter = DateFormatter.forPattern(TIME_FORMAT); - } - if (formatter != null) { - // RangeQueryBuilder accepts an Object as its parameter, but it will call .toString() on the ZonedDateTime - // instance which can have a slightly different format depending on the ZoneId used to create the ZonedDateTime - // Since RangeQueryBuilder can handle date as String as well, we'll format it as String and provide the format. - if (lower.get() instanceof ZonedDateTime || lower.get() instanceof OffsetTime) { - lower.set(formatter.format((TemporalAccessor) lower.get())); - } - if (upper.get() instanceof ZonedDateTime || upper.get() instanceof OffsetTime) { - upper.set(formatter.format((TemporalAccessor) upper.get())); - } - format.set(formatter.pattern()); - } - } - - query = handleQuery(r, r.value(), - () -> new RangeQuery(r.source(), nameOf(r.value()), lower.get(), r.includeLower(), - upper.get(), r.includeUpper(), format.get())); - } - return new QueryTranslation(query, aggFilter); + // + // Agg context means HAVING -> PipelineAggs + // + if (onAggs) { + aggFilter = new AggFilter(id(e), r.asScript()); } else { - throw new SqlIllegalArgumentException("No idea how to translate " + e); + + Holder lower = new Holder<>(valueOf(r.lower())); + Holder upper = new Holder<>(valueOf(r.upper())); + Holder format = new Holder<>(dateFormat(r.value())); + + // for a date constant comparison, we need to use a format for the date, to make sure that the format is the same + // no matter the timezone provided by the user + if (format.get() == null) { + DateFormatter formatter = null; + if (lower.get() instanceof ZonedDateTime || upper.get() instanceof ZonedDateTime) { + formatter = DateFormatter.forPattern(DATE_FORMAT); + } else if (lower.get() instanceof OffsetTime || upper.get() instanceof OffsetTime) { + formatter = DateFormatter.forPattern(TIME_FORMAT); + } + if (formatter != null) { + // RangeQueryBuilder accepts an Object as its parameter, but it will call .toString() on the ZonedDateTime + // instance which can have a slightly different format depending on the ZoneId used to create the ZonedDateTime + // Since RangeQueryBuilder can handle date as String as well, we'll format it as String and provide the format. + if (lower.get() instanceof ZonedDateTime || lower.get() instanceof OffsetTime) { + lower.set(formatter.format((TemporalAccessor) lower.get())); + } + if (upper.get() instanceof ZonedDateTime || upper.get() instanceof OffsetTime) { + upper.set(formatter.format((TemporalAccessor) upper.get())); + } + format.set(formatter.pattern()); + } + } + + query = handleQuery(r, r.value(), + () -> new RangeQuery(r.source(), nameOf(r.value()), lower.get(), r.includeLower(), upper.get(), r.includeUpper(), + format.get())); } + return new QueryTranslation(query, aggFilter); } } @@ -880,7 +658,7 @@ final class QueryTranslator { AggFilter aggFilter = null; if (onAggs) { - aggFilter = new AggFilter(f.id().toString(), script); + aggFilter = new AggFilter(id(f), script); } else { query = handleQuery(f, f, () -> new ScriptQuery(f.source(), script)); } @@ -1086,4 +864,4 @@ final class QueryTranslator { return query; } } -} +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Aggs.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Aggs.java index 632eb729936..94f854c29f0 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Aggs.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Aggs.java @@ -10,8 +10,6 @@ import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregati import org.elasticsearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; -import org.elasticsearch.xpack.sql.expression.Attribute; -import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.querydsl.container.Sort.Direction; import org.elasticsearch.xpack.sql.util.StringUtils; @@ -123,23 +121,16 @@ public class Aggs { return new Aggs(groups, simpleAggs, combine(pipelineAggs, pipelineAgg)); } - public GroupByKey findGroupForAgg(Attribute attr) { - String id = attr.id().toString(); + public GroupByKey findGroupForAgg(String groupOrAggId) { for (GroupByKey group : this.groups) { - if (id.equals(group.id())) { + if (groupOrAggId.equals(group.id())) { return group; } - if (attr instanceof ScalarFunctionAttribute) { - ScalarFunctionAttribute sfa = (ScalarFunctionAttribute) attr; - if (group.script() != null && group.script().equals(sfa.script())) { - return group; - } - } } // maybe it's the default group agg ? for (Agg agg : simpleAggs) { - if (id.equals(agg.id())) { + if (groupOrAggId.equals(agg.id())) { return IMPLICIT_GROUP_KEY; } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/container/AggregateSort.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/container/AggregateSort.java new file mode 100644 index 00000000000..966f5c50796 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/container/AggregateSort.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.sql.querydsl.container; + +import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; + +import java.util.Objects; + +public class AggregateSort extends Sort { + + private final AggregateFunction agg; + + public AggregateSort(AggregateFunction agg, Direction direction, Missing missing) { + super(direction, missing); + this.agg = agg; + } + + public AggregateFunction agg() { + return agg; + } + + @Override + public int hashCode() { + return Objects.hash(agg, direction(), missing()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + AggregateSort other = (AggregateSort) obj; + return Objects.equals(direction(), other.direction()) + && Objects.equals(missing(), other.missing()) + && Objects.equals(agg, other.agg); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainer.java index 3dd1a2ac108..2e388f94af3 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainer.java @@ -16,13 +16,15 @@ import org.elasticsearch.xpack.sql.execution.search.FieldExtraction; import org.elasticsearch.xpack.sql.execution.search.SourceGenerator; import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.AttributeMap; -import org.elasticsearch.xpack.sql.expression.ExpressionId; +import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.FieldAttribute; -import org.elasticsearch.xpack.sql.expression.LiteralAttribute; -import org.elasticsearch.xpack.sql.expression.function.ScoreAttribute; -import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute; -import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; +import org.elasticsearch.xpack.sql.expression.function.Score; +import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.ConstantInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.ScorePipe; import org.elasticsearch.xpack.sql.querydsl.agg.Aggs; import org.elasticsearch.xpack.sql.querydsl.agg.GroupByKey; import org.elasticsearch.xpack.sql.querydsl.agg.LeafAgg; @@ -38,7 +40,6 @@ import java.util.AbstractMap; import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; -import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -67,11 +68,11 @@ public class QueryContainer { // for example in case of grouping or custom sorting, the response has extra columns // that is filtered before getting to the client - // the list contains both the field extraction and the id of its associated attribute (for custom sorting) - private final List> fields; + // the list contains both the field extraction and its id (for custom sorting) + private final List> fields; - // aliases (maps an alias to its actual resolved attribute) - private final Map aliases; + // aliases found in the tree + private final AttributeMap aliases; // pseudo functions (like count) - that are 'extracted' from other aggs private final Map pseudoFunctions; @@ -90,6 +91,9 @@ public class QueryContainer { // computed private Boolean aggsOnly; private Boolean customSort; + // associate Attributes with aliased FieldAttributes (since they map directly to ES fields) + private Map fieldAlias; + public QueryContainer() { this(null, null, null, null, null, null, null, -1, false, false, -1); @@ -97,9 +101,8 @@ public class QueryContainer { public QueryContainer(Query query, Aggs aggs, - List> fields, - Map aliases, + List> fields, + AttributeMap aliases, Map pseudoFunctions, AttributeMap scalarFunctions, Set sort, @@ -110,7 +113,7 @@ public class QueryContainer { this.query = query; this.aggs = aggs == null ? Aggs.EMPTY : aggs; this.fields = fields == null || fields.isEmpty() ? emptyList() : fields; - this.aliases = aliases == null || aliases.isEmpty() ? Collections.emptyMap() : aliases; + this.aliases = aliases == null || aliases.isEmpty() ? AttributeMap.emptyAttributeMap() : aliases; this.pseudoFunctions = pseudoFunctions == null || pseudoFunctions.isEmpty() ? emptyMap() : pseudoFunctions; this.scalarFunctions = scalarFunctions == null || scalarFunctions.isEmpty() ? AttributeMap.emptyAttributeMap() : scalarFunctions; this.sort = sort == null || sort.isEmpty() ? emptySet() : sort; @@ -136,31 +139,30 @@ public class QueryContainer { for (Sort s : sort) { Tuple tuple = new Tuple<>(Integer.valueOf(-1), null); - if (s instanceof AttributeSort) { - AttributeSort as = (AttributeSort) s; + if (s instanceof AggregateSort) { + AggregateSort as = (AggregateSort) s; // find the relevant column of each aggregate function - if (as.attribute() instanceof AggregateFunctionAttribute) { - aggSort = true; - AggregateFunctionAttribute afa = (AggregateFunctionAttribute) as.attribute(); - afa = (AggregateFunctionAttribute) aliases.getOrDefault(afa.innerId(), afa); - int atIndex = -1; - for (int i = 0; i < fields.size(); i++) { - Tuple field = fields.get(i); - if (field.v2().equals(afa.innerId())) { - atIndex = i; - break; - } - } + AggregateFunction af = as.agg(); - if (atIndex == -1) { - throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", afa.name()); - } - // assemble a comparator for it - Comparator comp = s.direction() == Sort.Direction.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder(); - comp = s.missing() == Sort.Missing.FIRST ? Comparator.nullsFirst(comp) : Comparator.nullsLast(comp); + aggSort = true; + int atIndex = -1; + String id = Expressions.id(af); - tuple = new Tuple<>(Integer.valueOf(atIndex), comp); + for (int i = 0; i < fields.size(); i++) { + Tuple field = fields.get(i); + if (field.v2().equals(id)) { + atIndex = i; + break; + } } + if (atIndex == -1) { + throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", s); + } + // assemble a comparator for it + Comparator comp = s.direction() == Sort.Direction.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder(); + comp = s.missing() == Sort.Missing.FIRST ? Comparator.nullsFirst(comp) : Comparator.nullsLast(comp); + + tuple = new Tuple<>(Integer.valueOf(atIndex), comp); } sortingColumns.add(tuple); } @@ -179,19 +181,20 @@ public class QueryContainer { */ public BitSet columnMask(List columns) { BitSet mask = new BitSet(fields.size()); + aliasName(columns.get(0)); + for (Attribute column : columns) { - Attribute alias = aliases.get(column.id()); + Expression expression = aliases.getOrDefault(column, column); + // find the column index + String id = Expressions.id(expression); int index = -1; - ExpressionId id = column instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) column).innerId() : column.id(); - ExpressionId aliasId = alias != null ? (alias instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) alias) - .innerId() : alias.id()) : null; for (int i = 0; i < fields.size(); i++) { - Tuple tuple = fields.get(i); + Tuple tuple = fields.get(i); // if the index is already set there is a collision, // so continue searching for the other tuple with the same id - if (mask.get(i)==false && (tuple.v2().equals(id) || (aliasId != null && tuple.v2().equals(aliasId)))) { + if (mask.get(i) == false && tuple.v2().equals(id)) { index = i; break; } @@ -214,11 +217,11 @@ public class QueryContainer { return aggs; } - public List> fields() { + public List> fields() { return fields; } - public Map aliases() { + public AttributeMap aliases() { return aliases; } @@ -267,12 +270,7 @@ public class QueryContainer { minPageSize); } - public QueryContainer withFields(List> f) { - return new QueryContainer(query, aggs, f, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits, includeFrozen, - minPageSize); - } - - public QueryContainer withAliases(Map a) { + public QueryContainer withAliases(AttributeMap a) { return new QueryContainer(query, aggs, fields, a, pseudoFunctions, scalarFunctions, sort, limit, trackHits, includeFrozen, minPageSize); } @@ -313,7 +311,16 @@ public class QueryContainer { } private String aliasName(Attribute attr) { - return aliases.getOrDefault(attr.id(), attr).name(); + if (fieldAlias == null) { + fieldAlias = new LinkedHashMap<>(); + for (Map.Entry entry : aliases.entrySet()) { + if (entry.getValue() instanceof FieldAttribute) { + fieldAlias.put(entry.getKey(), (FieldAttribute) entry.getValue()); + } + } + } + FieldAttribute fa = fieldAlias.get(attr); + return fa != null ? fa.name() : attr.name(); } // @@ -397,17 +404,8 @@ public class QueryContainer { } // replace function/operators's input with references - private Tuple resolvedTreeComputingRef(ScalarFunctionAttribute ta) { - Attribute attribute = aliases.getOrDefault(ta.id(), ta); - Pipe proc = scalarFunctions.get(attribute); - - // check the attribute itself - if (proc == null) { - if (attribute instanceof ScalarFunctionAttribute) { - ta = (ScalarFunctionAttribute) attribute; - } - proc = ta.asPipe(); - } + private Tuple resolvedTreeComputingRef(ScalarFunction function, Attribute attr) { + Pipe proc = scalarFunctions.computeIfAbsent(attr, v -> function.asPipe()); // find the processor inputs (Attributes) and convert them into references // no need to promote them to the top since the container doesn't have to be aware @@ -420,8 +418,7 @@ public class QueryContainer { @Override public FieldExtraction resolve(Attribute attribute) { - Attribute attr = aliases.getOrDefault(attribute.id(), attribute); - Tuple ref = container.toReference(attr); + Tuple ref = container.asFieldExtraction(attribute); container = ref.v1(); return ref.v2(); } @@ -430,42 +427,55 @@ public class QueryContainer { proc = proc.resolveAttributes(resolver); QueryContainer qContainer = resolver.container; - // update proc - Map procs = new LinkedHashMap<>(qContainer.scalarFunctions()); - procs.put(attribute, proc); - qContainer = qContainer.withScalarProcessors(new AttributeMap<>(procs)); + // update proc (if needed) + if (qContainer.scalarFunctions().size() != scalarFunctions.size()) { + Map procs = new LinkedHashMap<>(qContainer.scalarFunctions()); + procs.put(attr, proc); + qContainer = qContainer.withScalarProcessors(new AttributeMap<>(procs)); + } + return new Tuple<>(qContainer, new ComputedRef(proc)); } public QueryContainer addColumn(Attribute attr) { - Tuple tuple = toReference(attr); - return tuple.v1().addColumn(tuple.v2(), attr); + Expression expression = aliases.getOrDefault(attr, attr); + Tuple tuple = asFieldExtraction(attr); + return tuple.v1().addColumn(tuple.v2(), Expressions.id(expression)); } - private Tuple toReference(Attribute attr) { - if (attr instanceof FieldAttribute) { - FieldAttribute fa = (FieldAttribute) attr; + private Tuple asFieldExtraction(Attribute attr) { + // resolve it Expression + Expression expression = aliases.getOrDefault(attr, attr); + + if (expression instanceof FieldAttribute) { + FieldAttribute fa = (FieldAttribute) expression; if (fa.isNested()) { return nestedHitFieldRef(fa); } else { return new Tuple<>(this, topHitFieldRef(fa)); } } - if (attr instanceof ScalarFunctionAttribute) { - return resolvedTreeComputingRef((ScalarFunctionAttribute) attr); + + if (expression == null) { + throw new SqlIllegalArgumentException("Unknown output attribute {}", attr); } - if (attr instanceof LiteralAttribute) { - return new Tuple<>(this, new ComputedRef(((LiteralAttribute) attr).asPipe())); + + if (expression.foldable()) { + return new Tuple<>(this, new ComputedRef(new ConstantInput(expression.source(), expression, expression.fold()))); } - if (attr instanceof ScoreAttribute) { - return new Tuple<>(this, new ComputedRef(((ScoreAttribute) attr).asPipe())); + + if (expression instanceof Score) { + return new Tuple<>(this, new ComputedRef(new ScorePipe(expression.source(), expression))); + } + + if (expression instanceof ScalarFunction) { + return resolvedTreeComputingRef((ScalarFunction) expression, attr); } throw new SqlIllegalArgumentException("Unknown output attribute {}", attr); } - public QueryContainer addColumn(FieldExtraction ref, Attribute attr) { - ExpressionId id = attr instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) attr).innerId() : attr.id(); + public QueryContainer addColumn(FieldExtraction ref, String id) { return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, id)), aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits, includeFrozen, minPageSize); @@ -487,8 +497,8 @@ public class QueryContainer { return with(aggs.addGroups(values)); } - public GroupByKey findGroupForAgg(Attribute attr) { - return aggs.findGroupForAgg(attr); + public GroupByKey findGroupForAgg(String aggId) { + return aggs.findGroupForAgg(aggId); } public QueryContainer updateGroup(GroupByKey group) { diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/tree/Node.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/tree/Node.java index 2e40244a415..0d686fc5cf1 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/tree/Node.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/tree/Node.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import java.util.ArrayList; import java.util.BitSet; +import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.function.Consumer; @@ -378,7 +379,10 @@ public abstract class Node> { if (needsComma) { sb.append(","); } - String stringValue = Objects.toString(prop); + + String stringValue = toString(prop); + + //: Objects.toString(prop); if (maxWidth + stringValue.length() > TO_STRING_MAX_WIDTH) { int cutoff = Math.max(0, TO_STRING_MAX_WIDTH - maxWidth); sb.append(stringValue.substring(0, cutoff)); @@ -395,4 +399,28 @@ public abstract class Node> { return sb.toString(); } + + private String toString(Object obj) { + StringBuilder sb = new StringBuilder(); + toString(sb, obj); + return sb.toString(); + } + + private void toString(StringBuilder sb, Object obj) { + if (obj instanceof Iterable) { + sb.append("["); + for (Iterator it = ((Iterable) obj).iterator(); it.hasNext();) { + Object o = it.next(); + toString(sb, o); + if (it.hasNext() == true) { + sb.append(", "); + } + } + sb.append("]"); + } else if (obj instanceof Node) { + sb.append(((Node) obj).nodeString()); + } else { + sb.append(Objects.toString(obj)); + } + } } \ No newline at end of file diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java index 04e58d9fc87..08c2548a41b 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java @@ -437,7 +437,7 @@ public class VerifierErrorMessagesTests extends ESTestCase { } public void testGroupByOrderByFieldFromGroupByFunction() { - assertEquals("1:54: Cannot use non-grouped column [int], expected [ABS(int)]", + assertEquals("1:54: Cannot order by non-grouped column [int], expected [ABS(int)]", error("SELECT ABS(int) FROM test GROUP BY ABS(int) ORDER BY int")); } @@ -613,9 +613,9 @@ public class VerifierErrorMessagesTests extends ESTestCase { } public void testInvalidTypeForStringFunction_WithTwoArgs() { - assertEquals("1:8: first argument of [CONCAT] must be [string], found value [1] type [integer]", + assertEquals("1:8: first argument of [CONCAT(1, 'bar')] must be [string], found value [1] type [integer]", error("SELECT CONCAT(1, 'bar')")); - assertEquals("1:8: second argument of [CONCAT] must be [string], found value [2] type [integer]", + assertEquals("1:8: second argument of [CONCAT('foo', 2)] must be [string], found value [2] type [integer]", error("SELECT CONCAT('foo', 2)")); } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/execution/search/SourceGeneratorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/execution/search/SourceGeneratorTests.java index fce24758a3b..7efbea74241 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/execution/search/SourceGeneratorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/execution/search/SourceGeneratorTests.java @@ -14,7 +14,11 @@ import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.sql.expression.Attribute; +import org.elasticsearch.xpack.sql.expression.AttributeMap; +import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.FieldAttribute; +import org.elasticsearch.xpack.sql.expression.ReferenceAttribute; import org.elasticsearch.xpack.sql.expression.function.Score; import org.elasticsearch.xpack.sql.querydsl.agg.AvgAgg; import org.elasticsearch.xpack.sql.querydsl.agg.GroupByValue; @@ -27,6 +31,9 @@ import org.elasticsearch.xpack.sql.querydsl.query.MatchQuery; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.KeywordEsField; +import java.util.LinkedHashMap; +import java.util.Map; + import static java.util.Collections.singletonList; import static org.elasticsearch.index.query.QueryBuilders.boolQuery; import static org.elasticsearch.index.query.QueryBuilders.matchQuery; @@ -79,7 +86,11 @@ public class SourceGeneratorTests extends ESTestCase { } public void testSelectScoreForcesTrackingScore() { - QueryContainer container = new QueryContainer().addColumn(new Score(Source.EMPTY).toAttribute()); + Score score = new Score(Source.EMPTY); + ReferenceAttribute attr = new ReferenceAttribute(score.source(), "score", score.dataType()); + Map alias = new LinkedHashMap<>(); + alias.put(attr, score); + QueryContainer container = new QueryContainer().withAliases(new AttributeMap<>(alias)).addColumn(attr); SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10)); assertTrue(sourceBuilder.trackScores()); } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/AttributeMapTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/AttributeMapTests.java index f2a6045124e..ee977687d90 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/AttributeMapTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/AttributeMapTests.java @@ -55,7 +55,7 @@ public class AttributeMapTests extends ESTestCase { Attribute one = m.keySet().iterator().next(); assertThat(m.containsKey(one), is(true)); - assertThat(m.containsKey(a("one")), is(true)); + assertThat(m.containsKey(a("one")), is(false)); assertThat(m.containsValue("one"), is(true)); assertThat(m.containsValue("on"), is(false)); assertThat(m.attributeNames(), contains("one", "two", "three")); @@ -74,7 +74,7 @@ public class AttributeMapTests extends ESTestCase { assertThat(m.isEmpty(), is(false)); assertThat(m.containsKey(one), is(true)); - assertThat(m.containsKey(a("one")), is(true)); + assertThat(m.containsKey(a("one")), is(false)); assertThat(m.containsValue("one"), is(true)); assertThat(m.containsValue("on"), is(false)); } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/ExpressionIdTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/ExpressionIdTests.java index 3efa228f7cc..dfbe3410434 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/ExpressionIdTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/ExpressionIdTests.java @@ -10,11 +10,11 @@ import java.util.concurrent.atomic.AtomicLong; public class ExpressionIdTests extends ESTestCase { /** - * Each {@link ExpressionId} should be unique. Technically + * Each {@link NameId} should be unique. Technically * you can roll the {@link AtomicLong} that backs them but * that is not going to happen within a single query. */ public void testUnique() { - assertNotEquals(new ExpressionId(), new ExpressionId()); + assertNotEquals(new NameId(), new NameId()); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/LiteralTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/LiteralTests.java index 2d36cb1e1e5..cd5e736c47c 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/LiteralTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/LiteralTests.java @@ -61,7 +61,7 @@ public class LiteralTests extends AbstractNodeTestCase { @Override protected Literal copy(Literal instance) { - return new Literal(instance.source(), instance.name(), instance.value(), instance.dataType()); + return new Literal(instance.source(), instance.value(), instance.dataType()); } @Override diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttributeTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttributeTests.java index 4deca1d1f63..a40e7661dc0 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttributeTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/UnresolvedAttributeTests.java @@ -18,7 +18,7 @@ public class UnresolvedAttributeTests extends AbstractNodeTestCase Objects.equals(v, a.qualifier()) ? newQualifier : v, Object.class)); - ExpressionId newId = new ExpressionId(); + NameId newId = new NameId(); assertEquals(new UnresolvedAttribute(a.source(), a.name(), a.qualifier(), newId, a.unresolvedMessage(), a.resolutionMetadata()), a.transformPropertiesOnly(v -> Objects.equals(v, a.id()) ? newId : v, Object.class)); diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/DatabaseFunctionTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/DatabaseFunctionTests.java index 0156d8fdfb5..8ad04d83c4c 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/DatabaseFunctionTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/DatabaseFunctionTests.java @@ -11,6 +11,8 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer; import org.elasticsearch.xpack.sql.analysis.analyzer.Verifier; import org.elasticsearch.xpack.sql.analysis.index.EsIndex; import org.elasticsearch.xpack.sql.analysis.index.IndexResolution; +import org.elasticsearch.xpack.sql.expression.Alias; +import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.parser.SqlParser; import org.elasticsearch.xpack.sql.plan.logical.Project; @@ -38,7 +40,9 @@ public class DatabaseFunctionTests extends ESTestCase { ); Project result = (Project) analyzer.analyze(parser.createStatement("SELECT DATABASE()"), true); - assertTrue(result.projections().get(0) instanceof Database); - assertEquals(clusterName, ((Database) result.projections().get(0)).fold()); + NamedExpression ne = result.projections().get(0); + assertTrue(ne instanceof Alias); + assertTrue(((Alias) ne).child() instanceof Database); + assertEquals(clusterName, ((Database) ((Alias) ne).child()).fold()); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/UserFunctionTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/UserFunctionTests.java index f8b3ed19764..a6e8d83a336 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/UserFunctionTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/UserFunctionTests.java @@ -11,6 +11,8 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer; import org.elasticsearch.xpack.sql.analysis.analyzer.Verifier; import org.elasticsearch.xpack.sql.analysis.index.EsIndex; import org.elasticsearch.xpack.sql.analysis.index.IndexResolution; +import org.elasticsearch.xpack.sql.expression.Alias; +import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.parser.SqlParser; import org.elasticsearch.xpack.sql.plan.logical.Project; @@ -28,9 +30,9 @@ public class UserFunctionTests extends ESTestCase { EsIndex test = new EsIndex("test", TypesTests.loadMapping("mapping-basic.json", true)); Analyzer analyzer = new Analyzer( new Configuration(DateUtils.UTC, Protocol.FETCH_SIZE, Protocol.REQUEST_TIMEOUT, - Protocol.PAGE_TIMEOUT, null, - randomFrom(Mode.values()), randomAlphaOfLength(10), - null, randomAlphaOfLengthBetween(1, 15), + Protocol.PAGE_TIMEOUT, null, + randomFrom(Mode.values()), randomAlphaOfLength(10), + null, randomAlphaOfLengthBetween(1, 15), randomBoolean(), randomBoolean()), new FunctionRegistry(), IndexResolution.valid(test), @@ -38,7 +40,9 @@ public class UserFunctionTests extends ESTestCase { ); Project result = (Project) analyzer.analyze(parser.createStatement("SELECT USER()"), true); - assertTrue(result.projections().get(0) instanceof User); - assertNull(((User) result.projections().get(0)).fold()); + NamedExpression ne = result.projections().get(0); + assertTrue(ne instanceof Alias); + assertTrue(((Alias) ne).child() instanceof User); + assertNull(((User) ((Alias) ne).child()).fold()); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java index 00004598f5c..899da8049b9 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.sql.expression.predicate.conditional; import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.Expression.TypeResolution; import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.function.scalar.FunctionTestUtils; import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Equals; @@ -21,7 +22,6 @@ import java.util.Collections; import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.sql.expression.Expression.TypeResolution; import static org.elasticsearch.xpack.sql.expression.function.scalar.FunctionTestUtils.randomIntLiteral; import static org.elasticsearch.xpack.sql.expression.function.scalar.FunctionTestUtils.randomStringLiteral; import static org.elasticsearch.xpack.sql.tree.Source.EMPTY; @@ -69,10 +69,6 @@ public class CaseTests extends AbstractNodeTestCase { Source newSource = randomValueOtherThan(c.source(), SourceTests::randomSource); assertEquals(new Case(c.source(), c.children()), c.transformPropertiesOnly(p -> Objects.equals(p, c.source()) ? newSource: p, Object.class)); - - String newName = randomValueOtherThan(c.name(), () -> randomAlphaOfLength(5)); - assertEquals(new Case(c.source(), c.children()), - c.transformPropertiesOnly(p -> Objects.equals(p, c.name()) ? newName : p, Object.class)); } @Override diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/IifTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/IifTests.java index a07663b188d..6b468fcb8fb 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/IifTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/IifTests.java @@ -62,10 +62,6 @@ public class IifTests extends AbstractNodeTestCase { Source newSource = randomValueOtherThan(iif.source(), SourceTests::randomSource); assertEquals(new Iif(iif.source(), iif.conditions().get(0).condition(), iif.conditions().get(0).result(), iif.elseResult()), iif.transformPropertiesOnly(p -> Objects.equals(p, iif.source()) ? newSource: p, Object.class)); - - String newName = randomValueOtherThan(iif.name(), () -> randomAlphaOfLength(5)); - assertEquals(new Iif(iif.source(), iif.conditions().get(0).condition(), iif.conditions().get(0).result(), iif.elseResult()), - iif.transformPropertiesOnly(p -> Objects.equals(p, iif.name()) ? newName : p, Object.class)); } @Override diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java index 8efb6874289..c81b376c0a5 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.sql.expression.Order.OrderDirection; import org.elasticsearch.xpack.sql.expression.function.Function; import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg; -import org.elasticsearch.xpack.sql.expression.function.aggregate.Count; import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStats; import org.elasticsearch.xpack.sql.expression.function.aggregate.First; import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate; @@ -101,11 +100,11 @@ import org.elasticsearch.xpack.sql.optimizer.Optimizer.CombineProjections; import org.elasticsearch.xpack.sql.optimizer.Optimizer.ConstantFolding; import org.elasticsearch.xpack.sql.optimizer.Optimizer.FoldNull; import org.elasticsearch.xpack.sql.optimizer.Optimizer.PropagateEquals; -import org.elasticsearch.xpack.sql.optimizer.Optimizer.PruneDuplicateFunctions; import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceAggsWithExtendedStats; import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceAggsWithStats; import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceFoldableAttributes; import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceMinMaxWithTopHits; +import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceReferenceAttributeWithSource; import org.elasticsearch.xpack.sql.optimizer.Optimizer.RewritePivot; import org.elasticsearch.xpack.sql.optimizer.Optimizer.SimplifyCase; import org.elasticsearch.xpack.sql.optimizer.Optimizer.SimplifyConditional; @@ -144,7 +143,7 @@ import static org.elasticsearch.xpack.sql.expression.Literal.of; import static org.elasticsearch.xpack.sql.tree.Source.EMPTY; import static org.elasticsearch.xpack.sql.util.DateUtils.UTC; import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.Matchers.is; public class OptimizerTests extends ESTestCase { @@ -210,6 +209,10 @@ public class OptimizerTests extends ESTestCase { return of(EMPTY, value); } + private static Alias a(String name, Expression e) { + return new Alias(e.source(), name, e); + } + private static FieldAttribute getFieldAttribute() { return getFieldAttribute("a"); } @@ -225,20 +228,6 @@ public class OptimizerTests extends ESTestCase { assertEquals(result, s); } - public void testDuplicateFunctions() { - AggregateFunction f1 = new Count(EMPTY, TRUE, false); - AggregateFunction f2 = new Count(EMPTY, TRUE, false); - - assertTrue(f1.functionEquals(f2)); - - Project p = new Project(EMPTY, FROM(), Arrays.asList(f1, f2)); - LogicalPlan result = new PruneDuplicateFunctions().apply(p); - assertTrue(result instanceof Project); - List projections = ((Project) result).projections(); - assertEquals(2, projections.size()); - assertEquals(projections.get(0), projections.get(1)); - } - public void testCombineProjections() { // a Alias a = new Alias(EMPTY, "a", FIVE); @@ -338,17 +327,17 @@ public class OptimizerTests extends ESTestCase { } public void testConstantFoldingBinaryLogic_WithNullHandling() { - assertEquals(NULL, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical()); - assertEquals(NULL, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical().nullable()); assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, NULL, FALSE)).canonical()); assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, FALSE, NULL)).canonical()); - assertEquals(NULL, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical().nullable()); assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, TRUE)).canonical()); assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, TRUE, NULL)).canonical()); - assertEquals(NULL, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical()); - assertEquals(NULL, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical()); - assertEquals(NULL, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical().nullable()); } public void testConstantFoldingRange() { @@ -393,13 +382,15 @@ public class OptimizerTests extends ESTestCase { } public void testConstantFoldingIn_LeftValueNotFoldable() { - Project p = new Project(EMPTY, FROM(), Collections.singletonList( - new In(EMPTY, getFieldAttribute(), - Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))))); + In in = new In(EMPTY, getFieldAttribute(), + Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))); + Alias as = new Alias(in.source(), in.sourceText(), in); + Project p = new Project(EMPTY, FROM(), Collections.singletonList(as)); p = (Project) new ConstantFolding().apply(p); assertEquals(1, p.projections().size()); - In in = (In) p.projections().get(0); - assertThat(Foldables.valuesOf(in.list(), DataType.INTEGER), contains(1 ,2 ,3 ,4)); + Alias a = (Alias) p.projections().get(0); + In i = (In) a.child(); + assertThat(Foldables.valuesOf(i.list(), DataType.INTEGER), contains(1 ,2 ,3 ,4)); } public void testConstantFoldingIn_RightValueIsNull() { @@ -672,47 +663,12 @@ public class OptimizerTests extends ESTestCase { new IfConditional(EMPTY, new GreaterThan(EMPTY, getFieldAttribute(), ONE), Literal.of(EMPTY, "foo2")), Literal.of(EMPTY, "default"))); assertFalse(c.foldable()); - Expression e = new SimplifyCase().rule(c); assertEquals(Case.class, e.getClass()); c = (Case) e; assertEquals(2, c.conditions().size()); - assertThat(c.conditions().get(0).condition().toString(), startsWith("Equals[a{f}#")); - assertThat(c.conditions().get(1).condition().toString(), startsWith("GreaterThan[a{f}#")); - assertFalse(c.foldable()); - assertEquals(TypeResolution.TYPE_RESOLVED, c.typeResolved()); - } - - public void testSimplifyCaseConditionsFoldWhenTrue() { - // CASE WHEN a = 1 THEN 'foo1' - // WHEN 1 = 1 THEN 'bar1' - // WHEN 2 = 1 THEN 'bar2' - // WHEN a > 1 THEN 'foo2' - // ELSE 'default' - // END - // - // ==> - // - // CASE WHEN a = 1 THEN 'foo1' - // WHEN 1 = 1 THEN 'bar1' - // ELSE 'default' - // END - - Case c = new Case(EMPTY, Arrays.asList( - new IfConditional(EMPTY, new Equals(EMPTY, getFieldAttribute(), ONE), Literal.of(EMPTY, "foo1")), - new IfConditional(EMPTY, new Equals(EMPTY, ONE, ONE), Literal.of(EMPTY, "bar1")), - new IfConditional(EMPTY, new Equals(EMPTY, TWO, ONE), Literal.of(EMPTY, "bar2")), - new IfConditional(EMPTY, new GreaterThan(EMPTY, getFieldAttribute(), ONE), Literal.of(EMPTY, "foo2")), - Literal.of(EMPTY, "default"))); - assertFalse(c.foldable()); - - SimplifyCase rule = new SimplifyCase(); - Expression e = rule.rule(c); - assertEquals(Case.class, e.getClass()); - c = (Case) e; - assertEquals(2, c.conditions().size()); - assertThat(c.conditions().get(0).condition().toString(), startsWith("Equals[a{f}#")); - assertThat(c.conditions().get(1).condition().toString(), startsWith("Equals[=1,=1]#")); + assertThat(c.conditions().get(0).condition().getClass(), is(Equals.class)); + assertThat(c.conditions().get(1).condition().getClass(), is(GreaterThan.class)); assertFalse(c.foldable()); assertEquals(TypeResolution.TYPE_RESOLVED, c.typeResolved()); } @@ -738,7 +694,7 @@ public class OptimizerTests extends ESTestCase { assertEquals(Case.class, e.getClass()); c = (Case) e; assertEquals(1, c.conditions().size()); - assertThat(c.conditions().get(0).condition().toString(), startsWith("Equals[=1,=1]#")); + assertThat(c.conditions().get(0).condition().nodeString(), is("1[INTEGER] == 1[INTEGER]")); assertTrue(c.foldable()); assertEquals("foo2", c.fold()); assertEquals(TypeResolution.TYPE_RESOLVED, c.typeResolved()); @@ -822,7 +778,7 @@ public class OptimizerTests extends ESTestCase { assertFalse(iif.foldable()); assertEquals("myField", Expressions.name(iif.elseResult())); } - + // // Logical simplifications // @@ -854,13 +810,11 @@ public class OptimizerTests extends ESTestCase { assertEquals(IsNull.class, e.getClass()); IsNull isNull = (IsNull) e; assertEquals(source, isNull.source()); - assertEquals("IS_NULL(a)", isNull.name()); e = bcSimpl.rule(swapLiteralsToRight.rule(new NullEquals(source, NULL, fa))); assertEquals(IsNull.class, e.getClass()); isNull = (IsNull) e; assertEquals(source, isNull.source()); - assertEquals("IS_NULL(a)", isNull.name()); } public void testLiteralsOnTheRight() { @@ -1500,7 +1454,8 @@ public class OptimizerTests extends ESTestCase { Min min1 = new Min(EMPTY, new FieldAttribute(EMPTY, "str", new EsField("str", DataType.KEYWORD, emptyMap(), true))); Min min2 = new Min(EMPTY, getFieldAttribute()); - OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), Arrays.asList(min1, min2)), + OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), + Arrays.asList(a("min1", min1), a("min2", min2))), Arrays.asList( new Order(EMPTY, min1, OrderDirection.ASC, Order.NullsPosition.LAST), new Order(EMPTY, min2, OrderDirection.ASC, Order.NullsPosition.LAST))); @@ -1515,16 +1470,17 @@ public class OptimizerTests extends ESTestCase { assertTrue(((OrderBy) result).child() instanceof Aggregate); List aggregates = ((Aggregate) ((OrderBy) result).child()).aggregates(); assertEquals(2, aggregates.size()); - assertEquals(First.class, aggregates.get(0).getClass()); - assertSame(first, aggregates.get(0)); - assertEquals(min2, aggregates.get(1)); + assertEquals(Alias.class, aggregates.get(0).getClass()); + assertEquals(Alias.class, aggregates.get(1).getClass()); + assertSame(first, ((Alias) aggregates.get(0)).child()); + assertEquals(min2, ((Alias) aggregates.get(1)).child()); } public void testTranslateMaxToLast() { Max max1 = new Max(EMPTY, new FieldAttribute(EMPTY, "str", new EsField("str", DataType.KEYWORD, emptyMap(), true))); Max max2 = new Max(EMPTY, getFieldAttribute()); - OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), Arrays.asList(max1, max2)), + OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), Arrays.asList(a("max1", max1), a("max2", max2))), Arrays.asList( new Order(EMPTY, max1, OrderDirection.ASC, Order.NullsPosition.LAST), new Order(EMPTY, max2, OrderDirection.ASC, Order.NullsPosition.LAST))); @@ -1538,9 +1494,10 @@ public class OptimizerTests extends ESTestCase { assertTrue(((OrderBy) result).child() instanceof Aggregate); List aggregates = ((Aggregate) ((OrderBy) result).child()).aggregates(); assertEquals(2, aggregates.size()); - assertEquals(Last.class, aggregates.get(0).getClass()); - assertSame(last, aggregates.get(0)); - assertEquals(max2, aggregates.get(1)); + assertEquals(Alias.class, aggregates.get(0).getClass()); + assertEquals(Alias.class, aggregates.get(1).getClass()); + assertSame(last, ((Alias) aggregates.get(0)).child()); + assertEquals(max2, ((Alias) aggregates.get(1)).child()); } public void testSortAggregateOnOrderByWithTwoFields() { @@ -1551,12 +1508,12 @@ public class OptimizerTests extends ESTestCase { Alias secondAlias = new Alias(EMPTY, "second_alias", secondField); Order firstOrderBy = new Order(EMPTY, firstField, OrderDirection.ASC, Order.NullsPosition.LAST); Order secondOrderBy = new Order(EMPTY, secondField, OrderDirection.ASC, Order.NullsPosition.LAST); - + OrderBy orderByPlan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), Arrays.asList(secondField, firstField), Arrays.asList(secondAlias, firstAlias)), Arrays.asList(firstOrderBy, secondOrderBy)); LogicalPlan result = new SortAggregateOnOrderBy().apply(orderByPlan); - + assertTrue(result instanceof OrderBy); List order = ((OrderBy) result).order(); assertEquals(2, order.size()); @@ -1564,7 +1521,7 @@ public class OptimizerTests extends ESTestCase { assertTrue(order.get(1).child() instanceof FieldAttribute); assertEquals("first_field", ((FieldAttribute) order.get(0).child()).name()); assertEquals("second_field", ((FieldAttribute) order.get(1).child()).name()); - + assertTrue(((OrderBy) result).child() instanceof Aggregate); Aggregate agg = (Aggregate) ((OrderBy) result).child(); List groupings = agg.groupings(); @@ -1583,12 +1540,12 @@ public class OptimizerTests extends ESTestCase { Alias secondAlias = new Alias(EMPTY, "second_alias", secondField); Order firstOrderBy = new Order(EMPTY, firstAlias, OrderDirection.ASC, Order.NullsPosition.LAST); Order secondOrderBy = new Order(EMPTY, secondAlias, OrderDirection.ASC, Order.NullsPosition.LAST); - + OrderBy orderByPlan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), Arrays.asList(secondAlias, firstAlias), Arrays.asList(secondAlias, firstAlias)), Arrays.asList(firstOrderBy, secondOrderBy)); LogicalPlan result = new SortAggregateOnOrderBy().apply(orderByPlan); - + assertTrue(result instanceof OrderBy); List order = ((OrderBy) result).order(); assertEquals(2, order.size()); @@ -1596,7 +1553,7 @@ public class OptimizerTests extends ESTestCase { assertTrue(order.get(1).child() instanceof Alias); assertEquals("first_alias", ((Alias) order.get(0).child()).name()); assertEquals("second_alias", ((Alias) order.get(1).child()).name()); - + assertTrue(((OrderBy) result).child() instanceof Aggregate); Aggregate agg = (Aggregate) ((OrderBy) result).child(); List groupings = agg.groupings(); @@ -1611,7 +1568,7 @@ public class OptimizerTests extends ESTestCase { FieldAttribute column = getFieldAttribute("pivot"); FieldAttribute number = getFieldAttribute("number"); List values = Arrays.asList(new Alias(EMPTY, "ONE", L(1)), new Alias(EMPTY, "TWO", L(2))); - List aggs = Arrays.asList(new Avg(EMPTY, number)); + List aggs = Arrays.asList(new Alias(EMPTY, "AVG", new Avg(EMPTY, number))); Pivot pivot = new Pivot(EMPTY, new EsRelation(EMPTY, new EsIndex("table", emptyMap()), false), column, values, aggs); LogicalPlan result = new RewritePivot().apply(pivot); @@ -1657,8 +1614,8 @@ public class OptimizerTests extends ESTestCase { } AggregateFunction firstAggregate = randomFrom(aggregates); AggregateFunction secondAggregate = randomValueOtherThan(firstAggregate, () -> randomFrom(aggregates)); - Aggregate aggregatePlan = new Aggregate(EMPTY, filter, Collections.singletonList(matchField), - Arrays.asList(firstAggregate, secondAggregate)); + Aggregate aggregatePlan = new Aggregate(EMPTY, filter, singletonList(matchField), + Arrays.asList(new Alias(EMPTY, "first", firstAggregate), new Alias(EMPTY, "second", secondAggregate))); LogicalPlan result; if (isSimpleStats) { result = new ReplaceAggsWithStats().apply(aggregatePlan); @@ -1669,11 +1626,17 @@ public class OptimizerTests extends ESTestCase { assertTrue(result instanceof Aggregate); Aggregate resultAgg = (Aggregate) result; assertEquals(2, resultAgg.aggregates().size()); - assertTrue(resultAgg.aggregates().get(0) instanceof InnerAggregate); - assertTrue(resultAgg.aggregates().get(1) instanceof InnerAggregate); - InnerAggregate resultFirstAgg = (InnerAggregate) resultAgg.aggregates().get(0); - InnerAggregate resultSecondAgg = (InnerAggregate) resultAgg.aggregates().get(1); + NamedExpression one = resultAgg.aggregates().get(0); + assertTrue(one instanceof Alias); + assertTrue(((Alias) one).child() instanceof InnerAggregate); + + NamedExpression two = resultAgg.aggregates().get(1); + assertTrue(two instanceof Alias); + assertTrue(((Alias) two).child() instanceof InnerAggregate); + + InnerAggregate resultFirstAgg = (InnerAggregate) ((Alias) one).child(); + InnerAggregate resultSecondAgg = (InnerAggregate) ((Alias) two).child(); assertEquals(resultFirstAgg.inner(), firstAggregate); assertEquals(resultSecondAgg.inner(), secondAggregate); if (isSimpleStats) { @@ -1691,4 +1654,34 @@ public class OptimizerTests extends ESTestCase { assertTrue(resultAgg.child() instanceof Filter); assertEquals(resultAgg.child(), filter); } -} + + public void testReplaceAttributesWithTarget() { + FieldAttribute a = getFieldAttribute("a"); + FieldAttribute b = getFieldAttribute("b"); + + Alias aAlias = new Alias(EMPTY, "aAlias", a); + Alias bAlias = new Alias(EMPTY, "bAlias", b); + + Project p = new Project(EMPTY, FROM(), Arrays.asList(aAlias, bAlias)); + Filter f = new Filter(EMPTY, p, + new And(EMPTY, new GreaterThan(EMPTY, aAlias.toAttribute(), L(1)), new GreaterThan(EMPTY, bAlias.toAttribute(), L(2)))); + + ReplaceReferenceAttributeWithSource rule = new ReplaceReferenceAttributeWithSource(); + Expression condition = f.condition(); + assertTrue(condition instanceof And); + And and = (And) condition; + assertTrue(and.left() instanceof GreaterThan); + GreaterThan gt = (GreaterThan) and.left(); + assertEquals(aAlias.toAttribute(), gt.left()); + + LogicalPlan plan = rule.apply(f); + + Filter filter = (Filter) plan; + condition = filter.condition(); + assertTrue(condition instanceof And); + and = (And) condition; + assertTrue(and.left() instanceof GreaterThan); + gt = (GreaterThan) and.left(); + assertEquals(a, gt.left()); + } +} \ No newline at end of file diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/ExpressionTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/ExpressionTests.java index c9fb153f57e..8d25901650b 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/ExpressionTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/ExpressionTests.java @@ -493,8 +493,8 @@ public class ExpressionTests extends ESTestCase { assertEquals(3, c.conditions().size()); IfConditional ifc = c.conditions().get(0); assertEquals("WHEN a = 1 THEN 'one'", ifc.sourceText()); - assertThat(ifc.condition().toString(), startsWith("Equals[?a,1]#")); - assertEquals("'one'=one", ifc.result().toString()); + assertThat(ifc.condition().toString(), startsWith("a = 1")); + assertEquals("one", ifc.result().toString()); assertEquals(Literal.NULL, c.elseResult()); expr = parser.createExpression( @@ -508,7 +508,7 @@ public class ExpressionTests extends ESTestCase { assertEquals(2, c.conditions().size()); ifc = c.conditions().get(0); assertEquals("WHEN a = 1 THEN 'one'", ifc.sourceText()); - assertEquals("'many'=many", c.elseResult().toString()); + assertEquals("many", c.elseResult().toString()); } public void testCaseWithOperand() { @@ -523,8 +523,8 @@ public class ExpressionTests extends ESTestCase { assertEquals(3, c.conditions().size()); IfConditional ifc = c.conditions().get(0); assertEquals("WHEN 1 THEN 'one'", ifc.sourceText()); - assertThat(ifc.condition().toString(), startsWith("Equals[?a,1]#")); - assertEquals("'one'=one", ifc.result().toString()); + assertThat(ifc.condition().toString(), startsWith("WHEN 1 THEN 'one'")); + assertEquals("one", ifc.result().toString()); assertEquals(Literal.NULL, c.elseResult()); expr = parser.createExpression( @@ -537,6 +537,6 @@ public class ExpressionTests extends ESTestCase { assertEquals(2, c.conditions().size()); ifc = c.conditions().get(0); assertEquals("WHEN 1 THEN 'one'", ifc.sourceText()); - assertEquals("'many'=many", c.elseResult().toString()); + assertEquals("many", c.elseResult().toString()); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java index ca31e32b2ed..1dc9567016e 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java @@ -12,14 +12,12 @@ import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.Order; +import org.elasticsearch.xpack.sql.expression.UnresolvedAlias; import org.elasticsearch.xpack.sql.expression.UnresolvedAttribute; -import org.elasticsearch.xpack.sql.expression.UnresolvedStar; import org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction; -import org.elasticsearch.xpack.sql.expression.function.scalar.Cast; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MultiMatchQueryPredicate; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.StringQueryPredicate; -import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.sql.plan.logical.Filter; import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.sql.plan.logical.OrderBy; @@ -40,7 +38,7 @@ import static org.hamcrest.Matchers.startsWith; public class SqlParserTests extends ESTestCase { public void testSelectStar() { - singleProjection(project(parseStatement("SELECT * FROM foo")), UnresolvedStar.class); + singleProjection(project(parseStatement("SELECT * FROM foo")), UnresolvedAlias.class); } private T singleProjection(Project project, Class type) { @@ -69,42 +67,44 @@ public class SqlParserTests extends ESTestCase { } public void testSelectField() { - UnresolvedAttribute a = singleProjection(project(parseStatement("SELECT bar FROM foo")), UnresolvedAttribute.class); - assertEquals("bar", a.name()); + UnresolvedAlias a = singleProjection(project(parseStatement("SELECT bar FROM foo")), UnresolvedAlias.class); + assertEquals("bar", a.sourceText()); } public void testSelectScore() { - UnresolvedFunction f = singleProjection(project(parseStatement("SELECT SCORE() FROM foo")), UnresolvedFunction.class); + UnresolvedAlias f = singleProjection(project(parseStatement("SELECT SCORE() FROM foo")), UnresolvedAlias.class); assertEquals("SCORE()", f.sourceText()); } public void testSelectCast() { - Cast f = singleProjection(project(parseStatement("SELECT CAST(POWER(languages, 2) AS DOUBLE) FROM foo")), Cast.class); + UnresolvedAlias f = singleProjection(project(parseStatement("SELECT CAST(POWER(languages, 2) AS DOUBLE) FROM foo")), + UnresolvedAlias.class); assertEquals("CAST(POWER(languages, 2) AS DOUBLE)", f.sourceText()); } public void testSelectCastOperator() { - Cast f = singleProjection(project(parseStatement("SELECT POWER(languages, 2)::DOUBLE FROM foo")), Cast.class); + UnresolvedAlias f = singleProjection(project(parseStatement("SELECT POWER(languages, 2)::DOUBLE FROM foo")), UnresolvedAlias.class); assertEquals("POWER(languages, 2)::DOUBLE", f.sourceText()); } public void testSelectCastWithSQLOperator() { - Cast f = singleProjection(project(parseStatement("SELECT CONVERT(POWER(languages, 2), SQL_DOUBLE) FROM foo")), Cast.class); + UnresolvedAlias f = singleProjection(project(parseStatement("SELECT CONVERT(POWER(languages, 2), SQL_DOUBLE) FROM foo")), + UnresolvedAlias.class); assertEquals("CONVERT(POWER(languages, 2), SQL_DOUBLE)", f.sourceText()); } public void testSelectCastToEsType() { - Cast f = singleProjection(project(parseStatement("SELECT CAST('0.' AS SCALED_FLOAT)")), Cast.class); + UnresolvedAlias f = singleProjection(project(parseStatement("SELECT CAST('0.' AS SCALED_FLOAT)")), UnresolvedAlias.class); assertEquals("CAST('0.' AS SCALED_FLOAT)", f.sourceText()); } public void testSelectAddWithParanthesis() { - Add f = singleProjection(project(parseStatement("SELECT (1 + 2)")), Add.class); - assertEquals("1 + 2", f.sourceText()); + UnresolvedAlias f = singleProjection(project(parseStatement("SELECT (1 + 2)")), UnresolvedAlias.class); + assertEquals("(1 + 2)", f.sourceText()); } public void testSelectRightFunction() { - UnresolvedFunction f = singleProjection(project(parseStatement("SELECT RIGHT()")), UnresolvedFunction.class); + UnresolvedAlias f = singleProjection(project(parseStatement("SELECT RIGHT()")), UnresolvedAlias.class); assertEquals("RIGHT()", f.sourceText()); } @@ -124,8 +124,8 @@ public class SqlParserTests extends ESTestCase { for (int i = 0; i < project.projections().size(); i++) { NamedExpression ne = project.projections().get(i); - assertEquals(UnresolvedAttribute.class, ne.getClass()); - assertEquals(reserved[i], ne.name()); + assertEquals(UnresolvedAlias.class, ne.getClass()); + assertEquals(reserved[i], ne.sourceText()); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryFolderTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryFolderTests.java index 11f6cc949de..18afb92b273 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryFolderTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryFolderTests.java @@ -12,8 +12,8 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Verifier; import org.elasticsearch.xpack.sql.analysis.index.EsIndex; import org.elasticsearch.xpack.sql.analysis.index.IndexResolution; import org.elasticsearch.xpack.sql.expression.Expressions; +import org.elasticsearch.xpack.sql.expression.ReferenceAttribute; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; -import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute; import org.elasticsearch.xpack.sql.optimizer.Optimizer; import org.elasticsearch.xpack.sql.parser.SqlParser; import org.elasticsearch.xpack.sql.plan.physical.EsQueryExec; @@ -70,7 +70,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecWithProjectAndLimit() { @@ -80,7 +80,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecWithProjectAndOrderBy() { @@ -90,7 +90,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecWithProjectAndOrderByAndLimit() { @@ -100,7 +100,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testLocalExecWithPrunedFilterWithFunction() { @@ -110,7 +110,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("E(){c}#")); + assertThat(ee.output().get(0).toString(), startsWith("E(){r}#")); } public void testLocalExecWithPrunedFilterWithFunctionAndAggregation() { @@ -120,7 +120,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("E(){c}#")); + assertThat(ee.output().get(0).toString(), startsWith("E(){r}#")); } public void testFoldingToLocalExecWithAggregationAndLimit() { @@ -130,7 +130,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(SingletonExecutable.class, le.executable().getClass()); SingletonExecutable ee = (SingletonExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("'foo'{c}#")); + assertThat(ee.output().get(0).toString(), startsWith("'foo'{r}#")); } public void testFoldingToLocalExecWithAggregationAndOrderBy() { @@ -140,7 +140,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(SingletonExecutable.class, le.executable().getClass()); SingletonExecutable ee = (SingletonExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("'foo'{c}#")); + assertThat(ee.output().get(0).toString(), startsWith("'foo'{r}#")); } public void testFoldingToLocalExecWithAggregationAndOrderByAndLimit() { @@ -150,7 +150,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(SingletonExecutable.class, le.executable().getClass()); SingletonExecutable ee = (SingletonExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("'foo'{c}#")); + assertThat(ee.output().get(0).toString(), startsWith("'foo'{r}#")); } public void testLocalExecWithoutFromClause() { @@ -160,9 +160,9 @@ public class QueryFolderTests extends ESTestCase { assertEquals(SingletonExecutable.class, le.executable().getClass()); SingletonExecutable ee = (SingletonExecutable) le.executable(); assertEquals(3, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("E(){c}#")); - assertThat(ee.output().get(1).toString(), startsWith("'foo'{c}#")); - assertThat(ee.output().get(2).toString(), startsWith("abs(10){c}#")); + assertThat(ee.output().get(0).toString(), startsWith("E(){r}#")); + assertThat(ee.output().get(1).toString(), startsWith("'foo'{r}#")); + assertThat(ee.output().get(2).toString(), startsWith("abs(10){r}#")); } public void testLocalExecWithoutFromClauseWithPrunedFilter() { @@ -172,7 +172,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("E(){c}#")); + assertThat(ee.output().get(0).toString(), startsWith("E(){r}#")); } public void testFoldingOfIsNull() { @@ -180,7 +180,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(LocalExec.class, p.getClass()); LocalExec ee = (LocalExec) p; assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecBooleanAndNull_WhereClause() { @@ -190,7 +190,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecBooleanAndNull_HavingClause() { @@ -200,8 +200,8 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); - assertThat(ee.output().get(1).toString(), startsWith("max(int){a->")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); + assertThat(ee.output().get(1).toString(), startsWith("max(int){r}")); } public void testFoldingBooleanOrNull_WhereClause() { @@ -211,7 +211,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals("{\"range\":{\"int\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":false,\"boost\":1.0}}}", ee.queryContainer().query().asBuilder().toString().replaceAll("\\s+", "")); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingBooleanOrNull_HavingClause() { @@ -222,8 +222,8 @@ public class QueryFolderTests extends ESTestCase { "\"script\":{\"source\":\"InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.gt(params.a0,params.v0))\"," + "\"lang\":\"painless\",\"params\":{\"v0\":10}},")); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); - assertThat(ee.output().get(1).toString(), startsWith("max(int){a->")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); + assertThat(ee.output().get(1).toString(), startsWith("max(int){r}")); } public void testFoldingOfIsNotNull() { @@ -231,7 +231,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EsQueryExec.class, p.getClass()); EsQueryExec ee = (EsQueryExec) p; assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecWithNullFilter() { @@ -241,7 +241,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecWithProject_FoldableIn() { @@ -251,7 +251,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecWithProject_WithOrderAndLimit() { @@ -261,7 +261,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingToLocalExecWithProjectWithGroupBy_WithOrderAndLimit() { @@ -271,8 +271,8 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); - assertThat(ee.output().get(1).toString(), startsWith("max(int){a->")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); + assertThat(ee.output().get(1).toString(), startsWith("max(int){r}")); } public void testFoldingToLocalExecWithProjectWithGroupBy_WithHaving_WithOrderAndLimit() { @@ -282,8 +282,8 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); - assertThat(ee.output().get(1).toString(), startsWith("max(int){a->")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); + assertThat(ee.output().get(1).toString(), startsWith("max(int){r}")); } public void testGroupKeyTypes_Boolean() { @@ -296,8 +296,8 @@ public class QueryFolderTests extends ESTestCase { "\"lang\":\"painless\",\"params\":{\"v0\":\"int\",\"v1\":10}},\"missing_bucket\":true," + "\"value_type\":\"boolean\",\"order\":\"asc\"}}}]}}}")); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("count(*){a->")); - assertThat(ee.output().get(1).toString(), startsWith("a{s->")); + assertThat(ee.output().get(0).toString(), startsWith("count(*){r}")); + assertThat(ee.output().get(1).toString(), startsWith("a{r}")); } public void testGroupKeyTypes_Integer() { @@ -310,8 +310,8 @@ public class QueryFolderTests extends ESTestCase { "\"lang\":\"painless\",\"params\":{\"v0\":\"int\",\"v1\":10}},\"missing_bucket\":true," + "\"value_type\":\"long\",\"order\":\"asc\"}}}]}}}")); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("count(*){a->")); - assertThat(ee.output().get(1).toString(), startsWith("a{s->")); + assertThat(ee.output().get(0).toString(), startsWith("count(*){r}")); + assertThat(ee.output().get(1).toString(), startsWith("a{r}")); } public void testGroupKeyTypes_Rational() { @@ -324,8 +324,8 @@ public class QueryFolderTests extends ESTestCase { "\"lang\":\"painless\",\"params\":{\"v0\":\"int\"}},\"missing_bucket\":true," + "\"value_type\":\"double\",\"order\":\"asc\"}}}]}}}")); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("count(*){a->")); - assertThat(ee.output().get(1).toString(), startsWith("a{s->")); + assertThat(ee.output().get(0).toString(), startsWith("count(*){r}")); + assertThat(ee.output().get(1).toString(), startsWith("a{r}")); } public void testGroupKeyTypes_String() { @@ -338,8 +338,8 @@ public class QueryFolderTests extends ESTestCase { "\"lang\":\"painless\",\"params\":{\"v0\":\"keyword\"}},\"missing_bucket\":true," + "\"value_type\":\"string\",\"order\":\"asc\"}}}]}}}")); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("count(*){a->")); - assertThat(ee.output().get(1).toString(), startsWith("a{s->")); + assertThat(ee.output().get(0).toString(), startsWith("count(*){r}#")); + assertThat(ee.output().get(1).toString(), startsWith("a{r}")); } public void testGroupKeyTypes_IP() { @@ -352,8 +352,8 @@ public class QueryFolderTests extends ESTestCase { "\"lang\":\"painless\",\"params\":{\"v0\":\"keyword\",\"v1\":\"IP\"}}," + "\"missing_bucket\":true,\"value_type\":\"ip\",\"order\":\"asc\"}}}]}}}")); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("count(*){a->")); - assertThat(ee.output().get(1).toString(), startsWith("a{s->")); + assertThat(ee.output().get(0).toString(), startsWith("count(*){r}#")); + assertThat(ee.output().get(1).toString(), startsWith("a{r}")); } public void testGroupKeyTypes_DateTime() { @@ -367,8 +367,8 @@ public class QueryFolderTests extends ESTestCase { "\"v0\":\"date\",\"v1\":\"P1Y2M\",\"v2\":\"INTERVAL_YEAR_TO_MONTH\"}},\"missing_bucket\":true," + "\"value_type\":\"long\",\"order\":\"asc\"}}}]}}}")); assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("count(*){a->")); - assertThat(ee.output().get(1).toString(), startsWith("a{s->")); + assertThat(ee.output().get(0).toString(), startsWith("count(*){r}#")); + assertThat(ee.output().get(1).toString(), startsWith("a{r}")); } public void testConcatIsNotFoldedForNull() { @@ -378,7 +378,7 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EmptyExecutable.class, le.executable().getClass()); EmptyExecutable ee = (EmptyExecutable) le.executable(); assertEquals(1, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("keyword{f}#")); + assertThat(ee.output().get(0).toString(), startsWith("test.keyword{f}#")); } public void testFoldingOfPercentileSecondArgument() { @@ -386,9 +386,8 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EsQueryExec.class, p.getClass()); EsQueryExec ee = (EsQueryExec) p; assertEquals(1, ee.output().size()); - assertEquals(AggregateFunctionAttribute.class, ee.output().get(0).getClass()); - AggregateFunctionAttribute afa = (AggregateFunctionAttribute) ee.output().get(0); - assertThat(afa.propertyPath(), endsWith("[3.0]")); + assertEquals(ReferenceAttribute.class, ee.output().get(0).getClass()); + assertTrue(ee.toString().contains("3.0")); } public void testFoldingOfPercentileRankSecondArgument() { @@ -396,9 +395,8 @@ public class QueryFolderTests extends ESTestCase { assertEquals(EsQueryExec.class, p.getClass()); EsQueryExec ee = (EsQueryExec) p; assertEquals(1, ee.output().size()); - assertEquals(AggregateFunctionAttribute.class, ee.output().get(0).getClass()); - AggregateFunctionAttribute afa = (AggregateFunctionAttribute) ee.output().get(0); - assertThat(afa.propertyPath(), endsWith("[3.0]")); + assertEquals(ReferenceAttribute.class, ee.output().get(0).getClass()); + assertTrue(ee.toString().contains("3.0")); } public void testFoldingOfPivot() { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java index 36722e6e1d0..41da90ad316 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer; import org.elasticsearch.xpack.sql.analysis.analyzer.Verifier; import org.elasticsearch.xpack.sql.analysis.index.EsIndex; import org.elasticsearch.xpack.sql.analysis.index.IndexResolution; +import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.expression.Literal; @@ -36,6 +37,7 @@ import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.sql.plan.logical.Project; import org.elasticsearch.xpack.sql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.sql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.sql.planner.QueryFolder.FoldAggregate.GroupingContext; import org.elasticsearch.xpack.sql.planner.QueryTranslator.QueryTranslation; import org.elasticsearch.xpack.sql.querydsl.agg.AggFilter; import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram; @@ -437,7 +439,7 @@ public class QueryTranslatorTests extends ESTestCase { assertTrue(p instanceof Filter); Expression condition = ((Filter) p).condition(); SqlIllegalArgumentException ex = expectThrows(SqlIllegalArgumentException.class, () -> QueryTranslator.toQuery(condition, false)); - assertEquals("Scalar function [LTRIM(keyword)] not allowed (yet) as argument for LIKE", ex.getMessage()); + assertEquals("Scalar function [LTRIM(keyword)] not allowed (yet) as argument for LTRIM(keyword) like '%a%'", ex.getMessage()); } public void testRLikeConstructsNotSupported() { @@ -447,7 +449,7 @@ public class QueryTranslatorTests extends ESTestCase { assertTrue(p instanceof Filter); Expression condition = ((Filter) p).condition(); SqlIllegalArgumentException ex = expectThrows(SqlIllegalArgumentException.class, () -> QueryTranslator.toQuery(condition, false)); - assertEquals("Scalar function [LTRIM(keyword)] not allowed (yet) as argument for RLIKE", ex.getMessage()); + assertEquals("Scalar function [LTRIM(keyword)] not allowed (yet) as argument for LTRIM(keyword) RLIKE '.*a.*'", ex.getMessage()); } public void testDifferentLikeAndNotLikePatterns() { @@ -592,7 +594,7 @@ public class QueryTranslatorTests extends ESTestCase { AggFilter aggFilter = translation.aggFilter; assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.isNull(params.a0))", aggFilter.scriptTemplate().toString()); - assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int){a->")); + assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int)")); } public void testTranslateIsNotNullExpression_HavingClause_Painless() { @@ -605,7 +607,7 @@ public class QueryTranslatorTests extends ESTestCase { AggFilter aggFilter = translation.aggFilter; assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.isNotNull(params.a0))", aggFilter.scriptTemplate().toString()); - assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int){a->")); + assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int)")); } public void testTranslateInExpression_WhereClause() { @@ -676,7 +678,7 @@ public class QueryTranslatorTests extends ESTestCase { AggFilter aggFilter = translation.aggFilter; assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", aggFilter.scriptTemplate().toString()); - assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int){a->")); + assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int)")); assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, 20]}]")); } @@ -690,7 +692,7 @@ public class QueryTranslatorTests extends ESTestCase { AggFilter aggFilter = translation.aggFilter; assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", aggFilter.scriptTemplate().toString()); - assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int){a->")); + assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int)")); assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10]}]")); } @@ -705,7 +707,7 @@ public class QueryTranslatorTests extends ESTestCase { AggFilter aggFilter = translation.aggFilter; assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", aggFilter.scriptTemplate().toString()); - assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int){a->")); + assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int)")); assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, null, 20, 30]}]")); } @@ -724,7 +726,7 @@ public class QueryTranslatorTests extends ESTestCase { assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.gt(InternalSqlScriptUtils." + operation.name().toLowerCase(Locale.ROOT) + "(params.a0),params.v0))", aggFilter.scriptTemplate().toString()); - assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int){a->")); + assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=max(int)")); assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=10}]")); } @@ -735,12 +737,12 @@ public class QueryTranslatorTests extends ESTestCase { assertEquals(1, ((Aggregate) p).groupings().size()); assertEquals(1, ((Aggregate) p).aggregates().size()); assertTrue(((Aggregate) p).groupings().get(0) instanceof Round); - assertTrue(((Aggregate) p).aggregates().get(0) instanceof Round); + assertTrue(((Alias) (((Aggregate) p).aggregates().get(0))).child() instanceof Round); Round groupingRound = (Round) ((Aggregate) p).groupings().get(0); assertEquals(1, groupingRound.children().size()); - QueryTranslator.GroupingContext groupingContext = QueryTranslator.groupBy(((Aggregate) p).groupings()); + GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); assertEquals("InternalSqlScriptUtils.round(InternalSqlScriptUtils.dateTimeChrono(InternalSqlScriptUtils.docValue(doc,params.v0), " @@ -756,14 +758,15 @@ public class QueryTranslatorTests extends ESTestCase { assertEquals(1, ((Aggregate) p).groupings().size()); assertEquals(1, ((Aggregate) p).aggregates().size()); assertTrue(((Aggregate) p).groupings().get(0) instanceof Round); - assertTrue(((Aggregate) p).aggregates().get(0) instanceof Round); + assertTrue(((Aggregate) p).aggregates().get(0) instanceof Alias); + assertTrue(((Alias) (((Aggregate) p).aggregates().get(0))).child() instanceof Round); Round groupingRound = (Round) ((Aggregate) p).groupings().get(0); assertEquals(2, groupingRound.children().size()); assertTrue(groupingRound.children().get(1) instanceof Literal); assertEquals(-2, ((Literal) groupingRound.children().get(1)).value()); - QueryTranslator.GroupingContext groupingContext = QueryTranslator.groupBy(((Aggregate) p).groupings()); + GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); assertEquals("InternalSqlScriptUtils.round(InternalSqlScriptUtils.dateTimeChrono(InternalSqlScriptUtils.docValue(doc,params.v0), " @@ -783,7 +786,7 @@ public class QueryTranslatorTests extends ESTestCase { assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.gt(InternalSqlScriptUtils.abs" + "(params.a0),params.v0))", aggFilter.scriptTemplate().toString()); - assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->")); + assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int)")); assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=10}]")); } @@ -895,7 +898,7 @@ public class QueryTranslatorTests extends ESTestCase { assertTrue(p instanceof Aggregate); Expression condition = ((Aggregate) p).groupings().get(0); assertFalse(condition.foldable()); - QueryTranslator.GroupingContext groupingContext = QueryTranslator.groupBy(((Aggregate) p).groupings()); + GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); assertEquals("InternalSqlScriptUtils.coalesce([InternalSqlScriptUtils.docValue(doc,params.v0),params.v1])", @@ -908,7 +911,7 @@ public class QueryTranslatorTests extends ESTestCase { assertTrue(p instanceof Aggregate); Expression condition = ((Aggregate) p).groupings().get(0); assertFalse(condition.foldable()); - QueryTranslator.GroupingContext groupingContext = QueryTranslator.groupBy(((Aggregate) p).groupings()); + GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); assertEquals("InternalSqlScriptUtils.nullif(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1)", @@ -921,7 +924,7 @@ public class QueryTranslatorTests extends ESTestCase { assertTrue(p instanceof Aggregate); Expression condition = ((Aggregate) p).groupings().get(0); assertFalse(condition.foldable()); - QueryTranslator.GroupingContext groupingContext = QueryTranslator.groupBy(((Aggregate) p).groupings()); + GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.gt(InternalSqlScriptUtils.docValue(" + "" + @@ -936,7 +939,7 @@ public class QueryTranslatorTests extends ESTestCase { assertTrue(p instanceof Aggregate); Expression condition = ((Aggregate) p).groupings().get(0); assertFalse(condition.foldable()); - QueryTranslator.GroupingContext groupingContext = QueryTranslator.groupBy(((Aggregate) p).groupings()); + GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings()); assertNotNull(groupingContext); ScriptTemplate scriptTemplate = groupingContext.tail.script(); assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.gt(" + @@ -1066,8 +1069,8 @@ public class QueryTranslatorTests extends ESTestCase { assertEquals(EsQueryExec.class, p.getClass()); EsQueryExec ee = (EsQueryExec) p; assertEquals(2, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("dkey{a->")); - assertThat(ee.output().get(1).toString(), startsWith("key{a->")); + assertThat(ee.output().get(0).toString(), startsWith("dkey{r}")); + assertThat(ee.output().get(1).toString(), startsWith("key{r}")); Collection subAggs = ee.queryContainer().aggs().asAggBuilder().getSubAggregations(); assertEquals(2, subAggs.size()); @@ -1092,12 +1095,12 @@ public class QueryTranslatorTests extends ESTestCase { assertEquals(EsQueryExec.class, p.getClass()); EsQueryExec ee = (EsQueryExec) p; assertEquals(6, ee.output().size()); - assertThat(ee.output().get(0).toString(), startsWith("AVG(int){a->")); - assertThat(ee.output().get(1).toString(), startsWith("ln{a->")); - assertThat(ee.output().get(2).toString(), startsWith("dln{a->")); - assertThat(ee.output().get(3).toString(), startsWith("fn{a->")); - assertThat(ee.output().get(4).toString(), startsWith("dfn{a->")); - assertThat(ee.output().get(5).toString(), startsWith("ccc{a->")); + assertThat(ee.output().get(0).toString(), startsWith("AVG(int){r}")); + assertThat(ee.output().get(1).toString(), startsWith("ln{r}")); + assertThat(ee.output().get(2).toString(), startsWith("dln{r}")); + assertThat(ee.output().get(3).toString(), startsWith("fn{r}")); + assertThat(ee.output().get(4).toString(), startsWith("dfn{r}")); + assertThat(ee.output().get(5).toString(), startsWith("ccc{r}")); Collection subAggs = ee.queryContainer().aggs().asAggBuilder().getSubAggregations(); assertEquals(5, subAggs.size()); diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainerTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainerTests.java index a23dc8a3f27..432be5fea4a 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainerTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/querydsl/container/QueryContainerTests.java @@ -8,7 +8,8 @@ package org.elasticsearch.xpack.sql.querydsl.container; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Attribute; -import org.elasticsearch.xpack.sql.expression.ExpressionId; +import org.elasticsearch.xpack.sql.expression.AttributeMap; +import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery; import org.elasticsearch.xpack.sql.querydsl.query.MatchAll; @@ -81,11 +82,11 @@ public class QueryContainerTests extends ESTestCase { Attribute fourth = new FieldAttribute(Source.EMPTY, "fourth", esField); Alias firstAliased = new Alias(Source.EMPTY, "firstAliased", first); - Map aliasesMap = new LinkedHashMap<>(); - aliasesMap.put(firstAliased.id(), first); + Map aliasesMap = new LinkedHashMap<>(); + aliasesMap.put(firstAliased.toAttribute(), first); QueryContainer queryContainer = new QueryContainer() - .withAliases(aliasesMap) + .withAliases(new AttributeMap<>(aliasesMap)) .addColumn(third) .addColumn(first) .addColumn(fourth)