SQL: Allow sorting of groups by aggregates (#38042)

Introduce client-side sorting of groups based on aggregate
functions. To allow this, the Analyzer has been extended to push down
to underlying Aggregate, aggregate function and the Querier has been
extended to identify the case and consume the results in order and sort
them based on the given columns.
The underlying QueryContainer has been slightly modified to allow a view
of the underlying values being extracted as the columns used for sorting
might not be requested by the user.

The PR also adds minor tweaks, mainly related to tree output.

Close #35118
This commit is contained in:
Costin Leau 2019-02-02 01:38:25 +02:00 committed by GitHub
parent 630889baec
commit 783c9ed372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
60 changed files with 1352 additions and 410 deletions

View File

@ -67,8 +67,18 @@ a field is an array (has multiple values) or not, so without reading all the dat
=== Sorting by aggregation === Sorting by aggregation
When doing aggregations (`GROUP BY`) {es-sql} relies on {es}'s `composite` aggregation for its support for paginating results. When doing aggregations (`GROUP BY`) {es-sql} relies on {es}'s `composite` aggregation for its support for paginating results.
But this type of aggregation does come with a limitation: sorting can only be applied on the key used for the aggregation's buckets. This However this type of aggregation does come with a limitation: sorting can only be applied on the key used for the aggregation's buckets.
means that queries like `SELECT * FROM test GROUP BY age ORDER BY COUNT(*)` are not possible. {es-sql} overcomes this limitation by doing client-side sorting however as a safety measure, allows only up to *512* rows.
It is recommended to use `LIMIT` for queries that use sorting by aggregation, essentially indicating the top N results that are desired:
[source, sql]
--------------------------------------------------
SELECT * FROM test GROUP BY age ORDER BY COUNT(*) LIMIT 100;
--------------------------------------------------
It is possible to run the same queries without a `LIMIT` however in that case if the maximum size (*512*) is passed, an exception will be
returned as {es-sql} is unable to track (and sort) all the results returned.
[float] [float]
=== Using aggregation functions on top of scalar functions === Using aggregation functions on top of scalar functions

View File

@ -20,7 +20,7 @@ public class CliExplainIT extends CliIntegrationTestCase {
assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("With[{}]")); assertThat(readLine(), startsWith("With[{}]"));
assertThat(readLine(), startsWith("\\_Project[[?*]]")); assertThat(readLine(), startsWith("\\_Project[[?*]]"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]")); assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
assertEquals("", readLine()); assertEquals("", readLine());
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test"), containsString("plan")); assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test"), containsString("plan"));
@ -64,22 +64,22 @@ public class CliExplainIT extends CliIntegrationTestCase {
assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("With[{}]")); assertThat(readLine(), startsWith("With[{}]"));
assertThat(readLine(), startsWith("\\_Project[[?*]]")); assertThat(readLine(), startsWith("\\_Project[[?*]]"));
assertThat(readLine(), startsWith(" \\_Filter[i = 2#")); assertThat(readLine(), startsWith(" \\_Filter[Equals[?i,2"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]")); assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
assertEquals("", readLine()); assertEquals("", readLine());
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test WHERE i = 2"), assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test WHERE i = 2"),
containsString("plan")); containsString("plan"));
assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Project[[i{f}#")); assertThat(readLine(), startsWith("Project[[i{f}#"));
assertThat(readLine(), startsWith("\\_Filter[i = 2#")); assertThat(readLine(), startsWith("\\_Filter[Equals[i"));
assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#")); assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#"));
assertEquals("", readLine()); assertEquals("", readLine());
assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT * FROM test WHERE i = 2"), containsString("plan")); assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT * FROM test WHERE i = 2"), containsString("plan"));
assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Project[[i{f}#")); assertThat(readLine(), startsWith("Project[[i{f}#"));
assertThat(readLine(), startsWith("\\_Filter[i = 2#")); assertThat(readLine(), startsWith("\\_Filter[Equals[i"));
assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#")); assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#"));
assertEquals("", readLine()); assertEquals("", readLine());
@ -123,20 +123,20 @@ public class CliExplainIT extends CliIntegrationTestCase {
assertThat(command("EXPLAIN (PLAN PARSED) SELECT COUNT(*) FROM test"), containsString("plan")); assertThat(command("EXPLAIN (PLAN PARSED) SELECT COUNT(*) FROM test"), containsString("plan"));
assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("With[{}]")); assertThat(readLine(), startsWith("With[{}]"));
assertThat(readLine(), startsWith("\\_Project[[?COUNT(*)]]")); assertThat(readLine(), startsWith("\\_Project[[?COUNT[?*]]]"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]")); assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
assertEquals("", readLine()); assertEquals("", readLine());
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT COUNT(*) FROM test"), assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT COUNT(*) FROM test"),
containsString("plan")); containsString("plan"));
assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)#")); assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1"));
assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#")); assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#"));
assertEquals("", readLine()); assertEquals("", readLine());
assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT COUNT(*) FROM test"), containsString("plan")); assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT COUNT(*) FROM test"), containsString("plan"));
assertThat(readLine(), startsWith("----------")); assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)#")); assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1"));
assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#")); assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#"));
assertEquals("", readLine()); assertEquals("", readLine());

View File

@ -73,7 +73,7 @@ public abstract class ErrorsTestCase extends CliIntegrationTestCase implements o
public void testSelectOrderByScoreInAggContext() throws Exception { public void testSelectOrderByScoreInAggContext() throws Exception {
index("test", body -> body.field("foo", 1)); index("test", body -> body.field("foo", 1));
assertFoundOneProblem(command("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()")); assertFoundOneProblem(command("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()"));
assertEquals("line 1:54: Cannot order by non-grouped column [SCORE()], expected [foo]" + END, readLine()); assertEquals("line 1:54: Cannot order by non-grouped column [SCORE()], expected [foo] or an aggregate function" + END, readLine());
} }
@Override @Override

View File

@ -81,7 +81,9 @@ public class ErrorsTestCase extends JdbcIntegrationTestCase implements org.elast
try (Connection c = esJdbc()) { try (Connection c = esJdbc()) {
SQLException e = expectThrows(SQLException.class, () -> SQLException e = expectThrows(SQLException.class, () ->
c.prepareStatement("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()").executeQuery()); c.prepareStatement("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()").executeQuery());
assertEquals("Found 1 problem(s)\nline 1:54: Cannot order by non-grouped column [SCORE()], expected [foo]", e.getMessage()); assertEquals(
"Found 1 problem(s)\nline 1:54: Cannot order by non-grouped column [SCORE()], expected [foo] or an aggregate function",
e.getMessage());
} }
} }

View File

@ -38,6 +38,7 @@ public abstract class SqlSpecTestCase extends SpecBaseIntegrationTestCase {
tests.addAll(readScriptSpec("/datetime.sql-spec", parser)); tests.addAll(readScriptSpec("/datetime.sql-spec", parser));
tests.addAll(readScriptSpec("/math.sql-spec", parser)); tests.addAll(readScriptSpec("/math.sql-spec", parser));
tests.addAll(readScriptSpec("/agg.sql-spec", parser)); tests.addAll(readScriptSpec("/agg.sql-spec", parser));
tests.addAll(readScriptSpec("/agg-ordering.sql-spec", parser));
tests.addAll(readScriptSpec("/arithmetic.sql-spec", parser)); tests.addAll(readScriptSpec("/arithmetic.sql-spec", parser));
tests.addAll(readScriptSpec("/string-functions.sql-spec", parser)); tests.addAll(readScriptSpec("/string-functions.sql-spec", parser));
tests.addAll(readScriptSpec("/case-functions.sql-spec", parser)); tests.addAll(readScriptSpec("/case-functions.sql-spec", parser));

View File

@ -0,0 +1,87 @@
//
// Custom sorting/ordering on aggregates
//
countWithImplicitGroupBy
SELECT MAX(salary) AS m FROM test_emp ORDER BY COUNT(*);
countWithImplicitGroupByWithHaving
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*);
countAndMaxWithImplicitGroupBy
SELECT MAX(salary) AS m FROM test_emp ORDER BY MAX(salary), COUNT(*);
maxWithAliasWithImplicitGroupBy
SELECT MAX(salary) AS m FROM test_emp ORDER BY m;
maxWithAliasWithImplicitGroupByAndHaving
SELECT MAX(salary) AS m FROM test_emp HAVING COUNT(*) > 1 ORDER BY m;
multipleOrderWithImplicitGroupByWithHaving
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*), m DESC;
multipleOrderWithImplicitGroupByWithoutAlias
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*), MIN(salary) DESC;
multipleOrderWithImplicitGroupByOfOrdinals
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp HAVING MIN(salary) > 1 ORDER BY 1, COUNT(*), 2 DESC;
aggWithoutAlias
SELECT MAX(salary) AS max FROM test_emp GROUP BY gender ORDER BY MAX(salary);
aggWithAlias
SELECT MAX(salary) AS m FROM test_emp GROUP BY gender ORDER BY m;
multipleAggsThatGetRewrittenWithoutAlias
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY MAX(salary);
multipleAggsThatGetRewrittenWithAliasDesc
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY 1 DESC;
multipleAggsThatGetRewrittenWithAlias
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY max;
aggNotSpecifiedInTheAggregate
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender ORDER BY MAX(salary);
aggNotSpecifiedInTheAggregatePlusOrdinal
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender ORDER BY MAX(salary), 2 DESC;
aggNotSpecifiedInTheAggregateWithHaving
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary);
aggNotSpecifiedInTheAggregateWithHavingDesc
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary) DESC;
aggNotSpecifiedInTheAggregateAndGroupWithHaving
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary), gender;
groupAndAggNotSpecifiedInTheAggregateWithHaving
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender, MAX(salary);
multipleAggsThatGetRewrittenWithAliasOnAMediumGroupBy
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages ORDER BY max;
multipleAggsThatGetRewrittenWithAliasOnALargeGroupBy
SELECT emp_no, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY emp_no ORDER BY max;
multipleAggsThatGetRewrittenWithAliasOnAMediumGroupByWithHaving
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages HAVING min BETWEEN 1000 AND 99999 ORDER BY max;
aggNotSpecifiedInTheAggregatemultipleAggsThatGetRewrittenWithAliasOnALargeGroupBy
SELECT emp_no, MIN(salary) AS min FROM test_emp GROUP BY emp_no ORDER BY MAX(salary);
aggNotSpecifiedWithHavingOnLargeGroupBy
SELECT MAX(salary) AS max FROM test_emp GROUP BY emp_no HAVING AVG(salary) > 1000 ORDER BY MIN(salary);
aggWithTieBreakerDescAsc
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MIN(languages) DESC NULLS FIRST, emp_no ASC;
aggWithTieBreakerDescDesc
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MIN(languages) DESC NULLS FIRST, emp_no DESC;
aggWithTieBreakerAscDesc
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MAX(languages) ASC NULLS FIRST, emp_no DESC;
aggWithMixOfOrdinals
SELECT gender AS g, MAX(salary) AS m FROM test_emp GROUP BY gender ORDER BY 2 DESC LIMIT 3;

View File

@ -52,6 +52,8 @@ import org.elasticsearch.xpack.sql.type.DataTypeConversion;
import org.elasticsearch.xpack.sql.type.DataTypes; import org.elasticsearch.xpack.sql.type.DataTypes;
import org.elasticsearch.xpack.sql.type.InvalidMappedField; import org.elasticsearch.xpack.sql.type.InvalidMappedField;
import org.elasticsearch.xpack.sql.type.UnsupportedEsField; import org.elasticsearch.xpack.sql.type.UnsupportedEsField;
import org.elasticsearch.xpack.sql.util.CollectionUtils;
import org.elasticsearch.xpack.sql.util.Holder;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -106,7 +108,8 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
new ResolveFunctions(), new ResolveFunctions(),
new ResolveAliases(), new ResolveAliases(),
new ProjectedAggregations(), new ProjectedAggregations(),
new ResolveAggsInHaving() new ResolveAggsInHaving(),
new ResolveAggsInOrderBy()
//new ImplicitCasting() //new ImplicitCasting()
); );
Batch finish = new Batch("Finish Analysis", Batch finish = new Batch("Finish Analysis",
@ -926,7 +929,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
// Handle aggs in HAVING. To help folding any aggs not found in Aggregation // Handle aggs in HAVING. To help folding any aggs not found in Aggregation
// will be pushed down to the Aggregate and then projected. This also simplifies the Verifier's job. // will be pushed down to the Aggregate and then projected. This also simplifies the Verifier's job.
// //
private class ResolveAggsInHaving extends AnalyzeRule<LogicalPlan> { private class ResolveAggsInHaving extends AnalyzeRule<Filter> {
@Override @Override
protected boolean skipResolved() { protected boolean skipResolved() {
@ -934,10 +937,8 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
} }
@Override @Override
protected LogicalPlan rule(LogicalPlan plan) { protected LogicalPlan rule(Filter f) {
// HAVING = Filter followed by an Agg // HAVING = Filter followed by an Agg
if (plan instanceof Filter) {
Filter f = (Filter) plan;
if (f.child() instanceof Aggregate && f.child().resolved()) { if (f.child() instanceof Aggregate && f.child().resolved()) {
Aggregate agg = (Aggregate) f.child(); Aggregate agg = (Aggregate) f.child();
@ -962,7 +963,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
.get(tryResolvingCondition.aggregates().size() - 1)).child(); .get(tryResolvingCondition.aggregates().size() - 1)).child();
} else { } else {
// else bail out // else bail out
return plan; return f;
} }
} }
@ -978,10 +979,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
return new Filter(f.source(), f.child(), condition); return new Filter(f.source(), f.child(), condition);
} }
return plan; return f;
}
return plan;
} }
private Set<NamedExpression> findMissingAggregate(Aggregate target, Expression from) { private Set<NamedExpression> findMissingAggregate(Aggregate target, Expression from) {
@ -1001,6 +999,66 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
} }
} }
//
// Handle aggs in ORDER BY. To help folding any aggs not found in Aggregation
// will be pushed down to the Aggregate and then projected. This also simplifies the Verifier's job.
// Similar to Having however using a different matching pattern since HAVING is always Filter with Agg,
// while an OrderBy can have multiple intermediate nodes (Filter,Project, etc...)
//
private static class ResolveAggsInOrderBy extends AnalyzeRule<OrderBy> {
@Override
protected boolean skipResolved() {
return false;
}
@Override
protected LogicalPlan rule(OrderBy ob) {
List<Order> orders = ob.order();
// 1. collect aggs inside an order by
List<NamedExpression> aggs = new ArrayList<>();
for (Order order : orders) {
if (Functions.isAggregate(order.child())) {
aggs.add(Expressions.wrapAsNamed(order.child()));
}
}
if (aggs.isEmpty()) {
return ob;
}
// 2. find first Aggregate child and update it
final Holder<Boolean> found = new Holder<>(Boolean.FALSE);
LogicalPlan plan = ob.transformDown(a -> {
if (found.get() == Boolean.FALSE) {
found.set(Boolean.TRUE);
List<NamedExpression> missing = new ArrayList<>();
for (NamedExpression orderedAgg : aggs) {
if (Expressions.anyMatch(a.aggregates(), e -> Expressions.equalsAsAttribute(e, orderedAgg)) == false) {
missing.add(orderedAgg);
}
}
// agg already contains all aggs
if (missing.isEmpty() == false) {
// save aggregates
return new Aggregate(a.source(), a.child(), a.groupings(), CollectionUtils.combine(a.aggregates(), missing));
}
}
return a;
}, Aggregate.class);
// if the plan was updated, project the initial aggregates
if (plan != ob) {
return new Project(ob.source(), plan, ob.output());
}
return ob;
}
}
private class PruneDuplicateFunctions extends AnalyzeRule<LogicalPlan> { private class PruneDuplicateFunctions extends AnalyzeRule<LogicalPlan> {
@Override @Override

View File

@ -54,8 +54,8 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import static java.lang.String.format;
import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toMap;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.stats.FeatureMetric.COMMAND; import static org.elasticsearch.xpack.sql.stats.FeatureMetric.COMMAND;
import static org.elasticsearch.xpack.sql.stats.FeatureMetric.GROUPBY; import static org.elasticsearch.xpack.sql.stats.FeatureMetric.GROUPBY;
import static org.elasticsearch.xpack.sql.stats.FeatureMetric.HAVING; import static org.elasticsearch.xpack.sql.stats.FeatureMetric.HAVING;
@ -118,7 +118,7 @@ public final class Verifier {
} }
private static Failure fail(Node<?> source, String message, Object... args) { private static Failure fail(Node<?> source, String message, Object... args) {
return new Failure(source, format(Locale.ROOT, message, args)); return new Failure(source, format(message, args));
} }
public Map<Node<?>, String> verifyFailures(LogicalPlan plan) { public Map<Node<?>, String> verifyFailures(LogicalPlan plan) {
@ -314,11 +314,12 @@ public final class Verifier {
Aggregate a = (Aggregate) child; Aggregate a = (Aggregate) child;
Map<Expression, Node<?>> missing = new LinkedHashMap<>(); Map<Expression, Node<?>> missing = new LinkedHashMap<>();
o.order().forEach(oe -> { o.order().forEach(oe -> {
Expression e = oe.child(); Expression e = oe.child();
// cannot order by aggregates (not supported by composite)
if (Functions.isAggregate(e)) { // aggregates are allowed
missing.put(e, oe); if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) {
return; return;
} }
@ -352,7 +353,8 @@ public final class Verifier {
String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY; String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY;
// get the location of the first missing expression as the order by might be on a different line // get the location of the first missing expression as the order by might be on a different line
localFailures.add( localFailures.add(
fail(missing.values().iterator().next(), "Cannot order by non-grouped column" + plural + " %s, expected %s", fail(missing.values().iterator().next(),
"Cannot order by non-grouped column" + plural + " {}, expected {} or an aggregate function",
Expressions.names(missing.keySet()), Expressions.names(missing.keySet()),
Expressions.names(a.groupings()))); Expressions.names(a.groupings())));
groupingFailures.add(a); groupingFailures.add(a);
@ -379,7 +381,7 @@ public final class Verifier {
if (!missing.isEmpty()) { if (!missing.isEmpty()) {
String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY; String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY;
localFailures.add( localFailures.add(
fail(condition, "Cannot use HAVING filter on non-aggregate" + plural + " %s; use WHERE instead", fail(condition, "Cannot use HAVING filter on non-aggregate" + plural + " {}; use WHERE instead",
Expressions.names(missing))); Expressions.names(missing)));
groupingFailures.add(a); groupingFailures.add(a);
return false; return false;
@ -388,7 +390,7 @@ public final class Verifier {
if (!unsupported.isEmpty()) { if (!unsupported.isEmpty()) {
String plural = unsupported.size() > 1 ? "s" : StringUtils.EMPTY; String plural = unsupported.size() > 1 ? "s" : StringUtils.EMPTY;
localFailures.add( localFailures.add(
fail(condition, "HAVING filter is unsupported for function" + plural + " %s", fail(condition, "HAVING filter is unsupported for function" + plural + " {}",
Expressions.names(unsupported))); Expressions.names(unsupported)));
groupingFailures.add(a); groupingFailures.add(a);
return false; return false;
@ -480,7 +482,7 @@ public final class Verifier {
e.collectFirstChildren(c -> { e.collectFirstChildren(c -> {
if (Functions.isGrouping(c)) { if (Functions.isGrouping(c)) {
localFailures.add(fail(c, localFailures.add(fail(c,
"Cannot combine [%s] grouping function inside GROUP BY, found [%s];" "Cannot combine [{}] grouping function inside GROUP BY, found [{}];"
+ " consider moving the expression inside the histogram", + " consider moving the expression inside the histogram",
Expressions.name(c), Expressions.name(e))); Expressions.name(c), Expressions.name(e)));
return true; return true;
@ -509,7 +511,7 @@ public final class Verifier {
if (!missing.isEmpty()) { if (!missing.isEmpty()) {
String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY; String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY;
localFailures.add(fail(missing.values().iterator().next(), "Cannot use non-grouped column" + plural + " %s, expected %s", localFailures.add(fail(missing.values().iterator().next(), "Cannot use non-grouped column" + plural + " {}, expected {}",
Expressions.names(missing.keySet()), Expressions.names(missing.keySet()),
Expressions.names(a.groupings()))); Expressions.names(a.groupings())));
return false; return false;
@ -592,7 +594,7 @@ public final class Verifier {
filter.condition().forEachDown(e -> { filter.condition().forEachDown(e -> {
if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) { if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) {
localFailures.add( localFailures.add(
fail(e, "Cannot use WHERE filtering on aggregate function [%s], use HAVING instead", Expressions.name(e))); fail(e, "Cannot use WHERE filtering on aggregate function [{}], use HAVING instead", Expressions.name(e)));
} }
}, Expression.class); }, Expression.class);
} }
@ -606,7 +608,7 @@ public final class Verifier {
filter.condition().forEachDown(e -> { filter.condition().forEachDown(e -> {
if (Functions.isGrouping(e) || e instanceof GroupingFunctionAttribute) { if (Functions.isGrouping(e) || e instanceof GroupingFunctionAttribute) {
localFailures localFailures
.add(fail(e, "Cannot filter on grouping function [%s], use its argument instead", Expressions.name(e))); .add(fail(e, "Cannot filter on grouping function [{}], use its argument instead", Expressions.name(e)));
} }
}, Expression.class); }, Expression.class);
} }
@ -659,7 +661,7 @@ public final class Verifier {
DataType dt = in.value().dataType(); DataType dt = in.value().dataType();
for (Expression value : in.list()) { for (Expression value : in.list()) {
if (areTypesCompatible(dt, value.dataType()) == false) { if (areTypesCompatible(dt, value.dataType()) == false) {
localFailures.add(fail(value, "expected data type [%s], value provided is of type [%s]", localFailures.add(fail(value, "expected data type [{}], value provided is of type [{}]",
dt.esType, value.dataType().esType)); dt.esType, value.dataType().esType));
return; return;
} }
@ -680,7 +682,7 @@ public final class Verifier {
} }
} else { } else {
if (areTypesCompatible(dt, child.dataType()) == false) { if (areTypesCompatible(dt, child.dataType()) == false) {
localFailures.add(fail(child, "expected data type [%s], value provided is of type [%s]", localFailures.add(fail(child, "expected data type [{}], value provided is of type [{}]",
dt.esType, child.dataType().esType)); dt.esType, child.dataType().esType));
return; return;
} }

View File

@ -60,7 +60,7 @@ public class PlanExecutor {
} }
private SqlSession newSession(Configuration cfg) { private SqlSession newSession(Configuration cfg) {
return new SqlSession(cfg, client, functionRegistry, indexResolver, preAnalyzer, verifier, optimizer, planner); return new SqlSession(cfg, client, functionRegistry, indexResolver, preAnalyzer, verifier, optimizer, planner, this);
} }
public void searchSource(Configuration cfg, String sql, List<SqlTypedParamValue> params, ActionListener<SearchSourceBuilder> listener) { public void searchSource(Configuration cfg, String sql, List<SqlTypedParamValue> params, ActionListener<SearchSourceBuilder> listener) {
@ -68,15 +68,20 @@ public class PlanExecutor {
if (exec instanceof EsQueryExec) { if (exec instanceof EsQueryExec) {
EsQueryExec e = (EsQueryExec) exec; EsQueryExec e = (EsQueryExec) exec;
listener.onResponse(SourceGenerator.sourceBuilder(e.queryContainer(), cfg.filter(), cfg.pageSize())); listener.onResponse(SourceGenerator.sourceBuilder(e.queryContainer(), cfg.filter(), cfg.pageSize()));
} else if (exec instanceof LocalExec) { }
listener.onFailure(new PlanningException("Cannot generate a query DSL for an SQL query that either " + // try to provide a better resolution of what failed
"its WHERE clause evaluates to FALSE or doesn't operate on a table (missing a FROM clause), sql statement: [{}]", else {
sql)); String message = null;
if (exec instanceof LocalExec) {
message = "Cannot generate a query DSL for an SQL query that either " +
"its WHERE clause evaluates to FALSE or doesn't operate on a table (missing a FROM clause)";
} else if (exec instanceof CommandExec) { } else if (exec instanceof CommandExec) {
listener.onFailure(new PlanningException("Cannot generate a query DSL for a special SQL command " + message = "Cannot generate a query DSL for a special SQL command " +
"(e.g.: DESCRIBE, SHOW), sql statement: [{}]", sql)); "(e.g.: DESCRIBE, SHOW)";
} else { } else {
listener.onFailure(new PlanningException("Cannot generate a query DSL, sql statement: [{}]", sql)); message = "Cannot generate a query DSL";
}
listener.onFailure(new PlanningException(message + ", sql statement: [{}]", sql));
} }
}, listener::onFailure)); }, listener::onFailure));
} }

View File

@ -32,6 +32,7 @@ import org.elasticsearch.xpack.sql.util.StringUtils;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.BitSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -49,12 +50,14 @@ public class CompositeAggregationCursor implements Cursor {
private final String[] indices; private final String[] indices;
private final byte[] nextQuery; private final byte[] nextQuery;
private final List<BucketExtractor> extractors; private final List<BucketExtractor> extractors;
private final BitSet mask;
private final int limit; private final int limit;
CompositeAggregationCursor(byte[] next, List<BucketExtractor> exts, int remainingLimit, String... indices) { CompositeAggregationCursor(byte[] next, List<BucketExtractor> exts, BitSet mask, int remainingLimit, String... indices) {
this.indices = indices; this.indices = indices;
this.nextQuery = next; this.nextQuery = next;
this.extractors = exts; this.extractors = exts;
this.mask = mask;
this.limit = remainingLimit; this.limit = remainingLimit;
} }
@ -64,6 +67,7 @@ public class CompositeAggregationCursor implements Cursor {
limit = in.readVInt(); limit = in.readVInt();
extractors = in.readNamedWriteableList(BucketExtractor.class); extractors = in.readNamedWriteableList(BucketExtractor.class);
mask = BitSet.valueOf(in.readByteArray());
} }
@Override @Override
@ -73,6 +77,7 @@ public class CompositeAggregationCursor implements Cursor {
out.writeVInt(limit); out.writeVInt(limit);
out.writeNamedWriteableList(extractors); out.writeNamedWriteableList(extractors);
out.writeByteArray(mask.toByteArray());
} }
@Override @Override
@ -88,6 +93,10 @@ public class CompositeAggregationCursor implements Cursor {
return nextQuery; return nextQuery;
} }
BitSet mask() {
return mask;
}
List<BucketExtractor> extractors() { List<BucketExtractor> extractors() {
return extractors; return extractors;
} }
@ -125,7 +134,7 @@ public class CompositeAggregationCursor implements Cursor {
} }
updateCompositeAfterKey(r, query); updateCompositeAfterKey(r, query);
CompositeAggsRowSet rowSet = new CompositeAggsRowSet(extractors, r, limit, serializeQuery(query), indices); CompositeAggsRowSet rowSet = new CompositeAggsRowSet(extractors, mask, r, limit, serializeQuery(query), indices);
listener.onResponse(rowSet); listener.onResponse(rowSet);
} catch (Exception ex) { } catch (Exception ex) {
listener.onFailure(ex); listener.onFailure(ex);

View File

@ -8,10 +8,10 @@ package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation; import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractor;
import org.elasticsearch.xpack.sql.session.AbstractRowSet;
import org.elasticsearch.xpack.sql.session.Cursor; import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet; import org.elasticsearch.xpack.sql.session.RowSet;
import java.util.BitSet;
import java.util.List; import java.util.List;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
@ -19,8 +19,7 @@ import static java.util.Collections.emptyList;
/** /**
* {@link RowSet} specific to (GROUP BY) aggregation. * {@link RowSet} specific to (GROUP BY) aggregation.
*/ */
class CompositeAggsRowSet extends AbstractRowSet { class CompositeAggsRowSet extends ResultRowSet<BucketExtractor> {
private final List<BucketExtractor> exts;
private final List<? extends CompositeAggregation.Bucket> buckets; private final List<? extends CompositeAggregation.Bucket> buckets;
@ -29,8 +28,8 @@ class CompositeAggsRowSet extends AbstractRowSet {
private final int size; private final int size;
private int row = 0; private int row = 0;
CompositeAggsRowSet(List<BucketExtractor> exts, SearchResponse response, int limit, byte[] next, String... indices) { CompositeAggsRowSet(List<BucketExtractor> exts, BitSet mask, SearchResponse response, int limit, byte[] next, String... indices) {
this.exts = exts; super(exts, mask);
CompositeAggregation composite = CompositeAggregationCursor.getComposite(response); CompositeAggregation composite = CompositeAggregationCursor.getComposite(response);
if (composite != null) { if (composite != null) {
@ -54,19 +53,14 @@ class CompositeAggsRowSet extends AbstractRowSet {
if (next == null || size == 0 || remainingLimit == 0) { if (next == null || size == 0 || remainingLimit == 0) {
cursor = Cursor.EMPTY; cursor = Cursor.EMPTY;
} else { } else {
cursor = new CompositeAggregationCursor(next, exts, remainingLimit, indices); cursor = new CompositeAggregationCursor(next, exts, mask, remainingLimit, indices);
} }
} }
} }
@Override @Override
protected Object getColumn(int column) { protected Object extractValue(BucketExtractor e) {
return exts.get(column).extract(buckets.get(row)); return e.extract(buckets.get(row));
}
@Override
public int columnCount() {
return exts.size();
} }
@Override @Override

View File

@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.sql.session.Configuration;
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static java.util.Collections.emptyList;
public class PagingListCursor implements Cursor {
public static final String NAME = "p";
private final List<List<?>> data;
private final int columnCount;
private final int pageSize;
PagingListCursor(List<List<?>> data, int columnCount, int pageSize) {
this.data = data;
this.columnCount = columnCount;
this.pageSize = pageSize;
}
@SuppressWarnings("unchecked")
public PagingListCursor(StreamInput in) throws IOException {
data = (List<List<?>>) in.readGenericValue();
columnCount = in.readVInt();
pageSize = in.readVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeGenericValue(data);
out.writeVInt(columnCount);
out.writeVInt(pageSize);
}
@Override
public String getWriteableName() {
return NAME;
}
List<List<?>> data() {
return data;
}
int columnCount() {
return columnCount;
}
int pageSize() {
return pageSize;
}
@Override
public void nextPage(Configuration cfg, Client client, NamedWriteableRegistry registry, ActionListener<RowSet> listener) {
// the check is really a safety measure since the page initialization handles it already (by returning an empty cursor)
List<List<?>> nextData = data.size() > pageSize ? data.subList(pageSize, data.size()) : emptyList();
listener.onResponse(new PagingListRowSet(nextData, columnCount, pageSize));
}
@Override
public void clear(Configuration cfg, Client client, ActionListener<Boolean> listener) {
listener.onResponse(true);
}
@Override
public int hashCode() {
return Objects.hash(data, columnCount, pageSize);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PagingListCursor other = (PagingListCursor) obj;
return Objects.equals(pageSize, other.pageSize)
&& Objects.equals(columnCount, other.columnCount)
&& Objects.equals(data, other.data);
}
@Override
public String toString() {
return "cursor for paging list";
}
}

View File

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.ListRowSet;
import org.elasticsearch.xpack.sql.type.Schema;
import java.util.List;
class PagingListRowSet extends ListRowSet {
private final int pageSize;
private final int columnCount;
private final Cursor cursor;
PagingListRowSet(List<List<?>> list, int columnCount, int pageSize) {
this(Schema.EMPTY, list, columnCount, pageSize);
}
PagingListRowSet(Schema schema, List<List<?>> list, int columnCount, int pageSize) {
super(schema, list);
this.columnCount = columnCount;
this.pageSize = Math.min(pageSize, list.size());
this.cursor = list.size() > pageSize ? new PagingListCursor(list, columnCount, pageSize) : Cursor.EMPTY;
}
@Override
public int size() {
return pageSize;
}
@Override
public int columnCount() {
return columnCount;
}
@Override
public Cursor nextPageCursor() {
return cursor;
}
}

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.sql.execution.search;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.PriorityQueue;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
@ -14,6 +15,7 @@ import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
@ -25,6 +27,7 @@ import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Buck
import org.elasticsearch.search.aggregations.bucket.filter.Filters; import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.PlanExecutor;
import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.CompositeKeyExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.CompositeKeyExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.ComputingExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.ComputingExtractor;
@ -33,11 +36,14 @@ import org.elasticsearch.xpack.sql.execution.search.extractor.FieldHitExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.MetricAggExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.MetricAggExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.TopHitsAggExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.TopHitsAggExtractor;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.ExpressionId;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggExtractorInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggExtractorInput;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggPathInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggPathInput;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.HitExtractorInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.HitExtractorInput;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.ReferenceInput; import org.elasticsearch.xpack.sql.expression.gen.pipeline.ReferenceInput;
import org.elasticsearch.xpack.sql.planner.PlanningException;
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs; import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
import org.elasticsearch.xpack.sql.querydsl.container.ComputedRef; import org.elasticsearch.xpack.sql.querydsl.container.ComputedRef;
import org.elasticsearch.xpack.sql.querydsl.container.GlobalCountRef; import org.elasticsearch.xpack.sql.querydsl.container.GlobalCountRef;
@ -48,16 +54,23 @@ import org.elasticsearch.xpack.sql.querydsl.container.ScriptFieldRef;
import org.elasticsearch.xpack.sql.querydsl.container.SearchHitFieldRef; import org.elasticsearch.xpack.sql.querydsl.container.SearchHitFieldRef;
import org.elasticsearch.xpack.sql.querydsl.container.TopHitsAggRef; import org.elasticsearch.xpack.sql.querydsl.container.TopHitsAggRef;
import org.elasticsearch.xpack.sql.session.Configuration; import org.elasticsearch.xpack.sql.session.Configuration;
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet;
import org.elasticsearch.xpack.sql.session.Rows; import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.session.SchemaRowSet; import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.session.SqlSession;
import org.elasticsearch.xpack.sql.type.Schema; import org.elasticsearch.xpack.sql.type.Schema;
import org.elasticsearch.xpack.sql.util.StringUtils; import org.elasticsearch.xpack.sql.util.StringUtils;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet;
import java.util.Comparator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
// TODO: add retry/back-off // TODO: add retry/back-off
@ -65,25 +78,25 @@ public class Querier {
private final Logger log = LogManager.getLogger(getClass()); private final Logger log = LogManager.getLogger(getClass());
private final PlanExecutor planExecutor;
private final Configuration cfg;
private final TimeValue keepAlive, timeout; private final TimeValue keepAlive, timeout;
private final int size; private final int size;
private final Client client; private final Client client;
@Nullable @Nullable
private final QueryBuilder filter; private final QueryBuilder filter;
public Querier(Client client, Configuration cfg) { public Querier(SqlSession sqlSession) {
this(client, cfg.requestTimeout(), cfg.pageTimeout(), cfg.filter(), cfg.pageSize()); this.planExecutor = sqlSession.planExecutor();
this.client = sqlSession.client();
this.cfg = sqlSession.configuration();
this.keepAlive = cfg.requestTimeout();
this.timeout = cfg.pageTimeout();
this.filter = cfg.filter();
this.size = cfg.pageSize();
} }
public Querier(Client client, TimeValue keepAlive, TimeValue timeout, QueryBuilder filter, int size) { public void query(List<Attribute> output, QueryContainer query, String index, ActionListener<SchemaRowSet> listener) {
this.client = client;
this.keepAlive = keepAlive;
this.timeout = timeout;
this.filter = filter;
this.size = size;
}
public void query(Schema schema, QueryContainer query, String index, ActionListener<SchemaRowSet> listener) {
// prepare the request // prepare the request
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(query, filter, size); SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(query, filter, size);
// set query timeout // set query timeout
@ -97,16 +110,21 @@ public class Querier {
SearchRequest search = prepareRequest(client, sourceBuilder, timeout, Strings.commaDelimitedListToStringArray(index)); SearchRequest search = prepareRequest(client, sourceBuilder, timeout, Strings.commaDelimitedListToStringArray(index));
ActionListener<SearchResponse> l; @SuppressWarnings("rawtypes")
List<Tuple<Integer, Comparator>> sortingColumns = query.sortingColumns();
listener = sortingColumns.isEmpty() ? listener : new LocalAggregationSorterListener(listener, sortingColumns, query.limit());
ActionListener<SearchResponse> l = null;
if (query.isAggsOnly()) { if (query.isAggsOnly()) {
if (query.aggs().useImplicitGroupBy()) { if (query.aggs().useImplicitGroupBy()) {
l = new ImplicitGroupActionListener(listener, client, timeout, schema, query, search); l = new ImplicitGroupActionListener(listener, client, timeout, output, query, search);
} else { } else {
l = new CompositeActionListener(listener, client, timeout, schema, query, search); l = new CompositeActionListener(listener, client, timeout, output, query, search);
} }
} else { } else {
search.scroll(keepAlive); search.scroll(keepAlive);
l = new ScrollActionListener(listener, client, timeout, schema, query); l = new ScrollActionListener(listener, client, timeout, output, query);
} }
client.search(search, l); client.search(search, l);
@ -123,6 +141,141 @@ public class Querier {
return search; return search;
} }
/**
* Listener used for local sorting (typically due to aggregations used inside `ORDER BY`).
*
* This listener consumes the whole result set, sorts it in memory then sends the paginated
* results back to the client.
*/
@SuppressWarnings("rawtypes")
class LocalAggregationSorterListener implements ActionListener<SchemaRowSet> {
private final ActionListener<SchemaRowSet> listener;
// keep the top N entries.
private final PriorityQueue<Tuple<List<?>, Integer>> data;
private final AtomicInteger counter = new AtomicInteger();
private volatile Schema schema;
private static final int MAXIMUM_SIZE = 512;
private final boolean noLimit;
LocalAggregationSorterListener(ActionListener<SchemaRowSet> listener, List<Tuple<Integer, Comparator>> sortingColumns, int limit) {
this.listener = listener;
int size = MAXIMUM_SIZE;
if (limit < 0) {
noLimit = true;
} else {
noLimit = false;
if (limit > MAXIMUM_SIZE) {
throw new PlanningException("The maximum LIMIT for aggregate sorting is [{}], received [{}]", limit, MAXIMUM_SIZE);
} else {
size = limit;
}
}
this.data = new PriorityQueue<Tuple<List<?>, Integer>>(size) {
// compare row based on the received attribute sort
// if a sort item is not in the list, it is assumed the sorting happened in ES
// and the results are left as is (by using the row ordering), otherwise it is sorted based on the given criteria.
//
// Take for example ORDER BY a, x, b, y
// a, b - are sorted in ES
// x, y - need to be sorted client-side
// sorting on x kicks in, only if the values for a are equal.
// thanks to @jpountz for the row ordering idea as a way to preserve ordering
@SuppressWarnings("unchecked")
@Override
protected boolean lessThan(Tuple<List<?>, Integer> l, Tuple<List<?>, Integer> r) {
for (Tuple<Integer, Comparator> tuple : sortingColumns) {
int i = tuple.v1().intValue();
Comparator comparator = tuple.v2();
Object vl = l.v1().get(i);
Object vr = r.v1().get(i);
if (comparator != null) {
int result = comparator.compare(vl, vr);
// if things are equals, move to the next comparator
if (result != 0) {
return result < 0;
}
}
// no comparator means the existing order needs to be preserved
else {
// check the values - if they are equal move to the next comparator
// otherwise return the row order
if (Objects.equals(vl, vr) == false) {
return l.v2().compareTo(r.v2()) < 0;
}
}
}
// everything is equal, fall-back to the row order
return l.v2().compareTo(r.v2()) < 0;
}
};
}
@Override
public void onResponse(SchemaRowSet schemaRowSet) {
schema = schemaRowSet.schema();
doResponse(schemaRowSet);
}
private void doResponse(RowSet rowSet) {
// 1. consume all pages received
if (consumeRowSet(rowSet) == false) {
return;
}
Cursor cursor = rowSet.nextPageCursor();
// 1a. trigger a next call if there's still data
if (cursor != Cursor.EMPTY) {
// trigger a next call
planExecutor.nextPage(cfg, cursor, ActionListener.wrap(this::doResponse, this::onFailure));
// make sure to bail out afterwards as we'll get called by a different thread
return;
}
// no more data available, the last thread sends the response
// 2. send the in-memory view to the client
sendResponse();
}
private boolean consumeRowSet(RowSet rowSet) {
// use a synchronized block for visibility purposes (there's no concurrency)
ResultRowSet<?> rrs = (ResultRowSet<?>) rowSet;
synchronized (data) {
for (boolean hasRows = rrs.hasCurrentRow(); hasRows; hasRows = rrs.advanceRow()) {
List<Object> row = new ArrayList<>(rrs.columnCount());
rrs.forEachResultColumn(row::add);
// if the queue overflows and no limit was specified, bail out
if (data.insertWithOverflow(new Tuple<>(row, counter.getAndIncrement())) != null && noLimit) {
onFailure(new SqlIllegalArgumentException(
"The default limit [{}] for aggregate sorting has been reached; please specify a LIMIT"));
return false;
}
}
}
return true;
}
private void sendResponse() {
List<List<?>> list = new ArrayList<>(data.size());
Tuple<List<?>, Integer> pop = null;
while ((pop = data.pop()) != null) {
list.add(pop.v1());
}
listener.onResponse(new PagingListRowSet(schema, list, schema.size(), cfg.pageSize()));
}
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
}
/** /**
* Dedicated listener for implicit/default group-by queries that return only _one_ result. * Dedicated listener for implicit/default group-by queries that return only _one_ result.
*/ */
@ -156,9 +309,9 @@ public class Querier {
} }
}); });
ImplicitGroupActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, Schema schema, ImplicitGroupActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, List<Attribute> output,
QueryContainer query, SearchRequest request) { QueryContainer query, SearchRequest request) {
super(listener, client, keepAlive, schema, query, request); super(listener, client, keepAlive, output, query, request);
} }
@Override @Override
@ -182,9 +335,12 @@ public class Querier {
if (buckets.size() == 1) { if (buckets.size() == 1) {
Bucket implicitGroup = buckets.get(0); Bucket implicitGroup = buckets.get(0);
List<BucketExtractor> extractors = initBucketExtractors(response); List<BucketExtractor> extractors = initBucketExtractors(response);
Object[] values = new Object[extractors.size()];
for (int i = 0; i < values.length; i++) { Object[] values = new Object[mask.cardinality()];
values[i] = extractors.get(i).extract(implicitGroup);
int index = 0;
for (int i = mask.nextSetBit(0); i >= 0; i = mask.nextSetBit(i + 1)) {
values[index++] = extractors.get(i).extract(implicitGroup);
} }
listener.onResponse(Rows.singleton(schema, values)); listener.onResponse(Rows.singleton(schema, values));
@ -205,8 +361,8 @@ public class Querier {
static class CompositeActionListener extends BaseAggActionListener { static class CompositeActionListener extends BaseAggActionListener {
CompositeActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, CompositeActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive,
Schema schema, QueryContainer query, SearchRequest request) { List<Attribute> output, QueryContainer query, SearchRequest request) {
super(listener, client, keepAlive, schema, query, request); super(listener, client, keepAlive, output, query, request);
} }
@ -232,7 +388,7 @@ public class Querier {
} }
listener.onResponse( listener.onResponse(
new SchemaCompositeAggsRowSet(schema, initBucketExtractors(response), response, query.limit(), new SchemaCompositeAggsRowSet(schema, initBucketExtractors(response), mask, response, query.limit(),
nextSearch, nextSearch,
request.indices())); request.indices()));
} }
@ -246,23 +402,25 @@ public class Querier {
abstract static class BaseAggActionListener extends BaseActionListener { abstract static class BaseAggActionListener extends BaseActionListener {
final QueryContainer query; final QueryContainer query;
final SearchRequest request; final SearchRequest request;
final BitSet mask;
BaseAggActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, Schema schema, BaseAggActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, List<Attribute> output,
QueryContainer query, SearchRequest request) { QueryContainer query, SearchRequest request) {
super(listener, client, keepAlive, schema); super(listener, client, keepAlive, output);
this.query = query; this.query = query;
this.request = request; this.request = request;
this.mask = query.columnMask(output);
} }
protected List<BucketExtractor> initBucketExtractors(SearchResponse response) { protected List<BucketExtractor> initBucketExtractors(SearchResponse response) {
// create response extractors for the first time // create response extractors for the first time
List<FieldExtraction> refs = query.columns(); List<Tuple<FieldExtraction, ExpressionId>> refs = query.fields();
List<BucketExtractor> exts = new ArrayList<>(refs.size()); List<BucketExtractor> exts = new ArrayList<>(refs.size());
ConstantExtractor totalCount = new ConstantExtractor(response.getHits().getTotalHits().value); ConstantExtractor totalCount = new ConstantExtractor(response.getHits().getTotalHits().value);
for (FieldExtraction ref : refs) { for (Tuple<FieldExtraction, ExpressionId> ref : refs) {
exts.add(createExtractor(ref, totalCount)); exts.add(createExtractor(ref.v1(), totalCount));
} }
return exts; return exts;
} }
@ -308,11 +466,13 @@ public class Querier {
*/ */
static class ScrollActionListener extends BaseActionListener { static class ScrollActionListener extends BaseActionListener {
private final QueryContainer query; private final QueryContainer query;
private final BitSet mask;
ScrollActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, ScrollActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive,
Schema schema, QueryContainer query) { List<Attribute> output, QueryContainer query) {
super(listener, client, keepAlive, schema); super(listener, client, keepAlive, output);
this.query = query; this.query = query;
this.mask = query.columnMask(output);
} }
@Override @Override
@ -320,17 +480,17 @@ public class Querier {
SearchHit[] hits = response.getHits().getHits(); SearchHit[] hits = response.getHits().getHits();
// create response extractors for the first time // create response extractors for the first time
List<FieldExtraction> refs = query.columns(); List<Tuple<FieldExtraction, ExpressionId>> refs = query.fields();
List<HitExtractor> exts = new ArrayList<>(refs.size()); List<HitExtractor> exts = new ArrayList<>(refs.size());
for (FieldExtraction ref : refs) { for (Tuple<FieldExtraction, ExpressionId> ref : refs) {
exts.add(createExtractor(ref)); exts.add(createExtractor(ref.v1()));
} }
// there are some results // there are some results
if (hits.length > 0) { if (hits.length > 0) {
String scrollId = response.getScrollId(); String scrollId = response.getScrollId();
SchemaSearchHitRowSet hitRowSet = new SchemaSearchHitRowSet(schema, exts, hits, query.limit(), scrollId); SchemaSearchHitRowSet hitRowSet = new SchemaSearchHitRowSet(schema, exts, mask, hits, query.limit(), scrollId);
// if there's an id, try to setup next scroll // if there's an id, try to setup next scroll
if (scrollId != null && if (scrollId != null &&
@ -340,7 +500,7 @@ public class Querier {
|| hitRowSet.isLimitReached())) { || hitRowSet.isLimitReached())) {
// if so, clear the scroll // if so, clear the scroll
clear(response.getScrollId(), ActionListener.wrap( clear(response.getScrollId(), ActionListener.wrap(
succeeded -> listener.onResponse(new SchemaSearchHitRowSet(schema, exts, hits, query.limit(), null)), succeeded -> listener.onResponse(new SchemaSearchHitRowSet(schema, exts, mask, hits, query.limit(), null)),
listener::onFailure)); listener::onFailure));
} else { } else {
listener.onResponse(hitRowSet); listener.onResponse(hitRowSet);
@ -401,12 +561,12 @@ public class Querier {
final TimeValue keepAlive; final TimeValue keepAlive;
final Schema schema; final Schema schema;
BaseActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, Schema schema) { BaseActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, List<Attribute> output) {
this.listener = listener; this.listener = listener;
this.client = client; this.client = client;
this.keepAlive = keepAlive; this.keepAlive = keepAlive;
this.schema = schema; this.schema = Rows.schema(output);
} }
// TODO: need to handle rejections plus check failures (shard size, etc...) // TODO: need to handle rejections plus check failures (shard size, etc...)

View File

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.session.AbstractRowSet;
import org.elasticsearch.xpack.sql.util.Check;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
abstract class ResultRowSet<E extends NamedWriteable> extends AbstractRowSet {
private final List<E> extractors;
private final BitSet mask;
ResultRowSet(List<E> extractors, BitSet mask) {
this.extractors = extractors;
this.mask = mask;
Check.isTrue(mask.length() <= extractors.size(), "Invalid number of extracted columns specified");
}
@Override
public final int columnCount() {
return mask.cardinality();
}
@Override
protected Object getColumn(int column) {
return extractValue(userExtractor(column));
}
List<E> extractors() {
return extractors;
}
BitSet mask() {
return mask;
}
E userExtractor(int column) {
int i = -1;
// find the nth set bit
for (i = mask.nextSetBit(0); i >= 0; i = mask.nextSetBit(i + 1)) {
if (column-- == 0) {
return extractors.get(i);
}
}
throw new SqlIllegalArgumentException("Cannot find column [{}]", column);
}
Object resultColumn(int column) {
return extractValue(extractors().get(column));
}
int resultColumnCount() {
return extractors.size();
}
void forEachResultColumn(Consumer<? super Object> action) {
Objects.requireNonNull(action);
int rowSize = resultColumnCount();
for (int i = 0; i < rowSize; i++) {
action.accept(resultColumn(i));
}
}
protected abstract Object extractValue(E e);
}

View File

@ -11,6 +11,7 @@ import org.elasticsearch.xpack.sql.session.RowSet;
import org.elasticsearch.xpack.sql.session.SchemaRowSet; import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.type.Schema; import org.elasticsearch.xpack.sql.type.Schema;
import java.util.BitSet;
import java.util.List; import java.util.List;
/** /**
@ -21,9 +22,10 @@ class SchemaCompositeAggsRowSet extends CompositeAggsRowSet implements SchemaRow
private final Schema schema; private final Schema schema;
SchemaCompositeAggsRowSet(Schema schema, List<BucketExtractor> exts, SearchResponse response, int limitAggs, byte[] next, SchemaCompositeAggsRowSet(Schema schema, List<BucketExtractor> exts, BitSet mask, SearchResponse response, int limitAggs,
byte[] next,
String... indices) { String... indices) {
super(exts, response, limitAggs, next, indices); super(exts, mask, response, limitAggs, next, indices);
this.schema = schema; this.schema = schema;
} }

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor;
import org.elasticsearch.xpack.sql.session.SchemaRowSet; import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.type.Schema; import org.elasticsearch.xpack.sql.type.Schema;
import java.util.BitSet;
import java.util.List; import java.util.List;
/** /**
@ -20,8 +21,8 @@ import java.util.List;
class SchemaSearchHitRowSet extends SearchHitRowSet implements SchemaRowSet { class SchemaSearchHitRowSet extends SearchHitRowSet implements SchemaRowSet {
private final Schema schema; private final Schema schema;
SchemaSearchHitRowSet(Schema schema, List<HitExtractor> exts, SearchHit[] hits, int limitHits, String scrollId) { SchemaSearchHitRowSet(Schema schema, List<HitExtractor> exts, BitSet mask, SearchHit[] hits, int limitHits, String scrollId) {
super(exts, hits, limitHits, scrollId); super(exts, mask, hits, limitHits, scrollId);
this.schema = schema; this.schema = schema;
} }

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet; import org.elasticsearch.xpack.sql.session.RowSet;
import java.io.IOException; import java.io.IOException;
import java.util.BitSet;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -34,11 +35,13 @@ public class ScrollCursor implements Cursor {
private final String scrollId; private final String scrollId;
private final List<HitExtractor> extractors; private final List<HitExtractor> extractors;
private final BitSet mask;
private final int limit; private final int limit;
public ScrollCursor(String scrollId, List<HitExtractor> extractors, int limit) { public ScrollCursor(String scrollId, List<HitExtractor> extractors, BitSet mask, int limit) {
this.scrollId = scrollId; this.scrollId = scrollId;
this.extractors = extractors; this.extractors = extractors;
this.mask = mask;
this.limit = limit; this.limit = limit;
} }
@ -47,6 +50,7 @@ public class ScrollCursor implements Cursor {
limit = in.readVInt(); limit = in.readVInt();
extractors = in.readNamedWriteableList(HitExtractor.class); extractors = in.readNamedWriteableList(HitExtractor.class);
mask = BitSet.valueOf(in.readByteArray());
} }
@Override @Override
@ -55,6 +59,7 @@ public class ScrollCursor implements Cursor {
out.writeVInt(limit); out.writeVInt(limit);
out.writeNamedWriteableList(extractors); out.writeNamedWriteableList(extractors);
out.writeByteArray(mask.toByteArray());
} }
@Override @Override
@ -66,6 +71,10 @@ public class ScrollCursor implements Cursor {
return scrollId; return scrollId;
} }
BitSet mask() {
return mask;
}
List<HitExtractor> extractors() { List<HitExtractor> extractors() {
return extractors; return extractors;
} }
@ -79,7 +88,7 @@ public class ScrollCursor implements Cursor {
SearchScrollRequest request = new SearchScrollRequest(scrollId).scroll(cfg.pageTimeout()); SearchScrollRequest request = new SearchScrollRequest(scrollId).scroll(cfg.pageTimeout());
client.searchScroll(request, ActionListener.wrap((SearchResponse response) -> { client.searchScroll(request, ActionListener.wrap((SearchResponse response) -> {
SearchHitRowSet rowSet = new SearchHitRowSet(extractors, response.getHits().getHits(), SearchHitRowSet rowSet = new SearchHitRowSet(extractors, mask, response.getHits().getHits(),
limit, response.getScrollId()); limit, response.getScrollId());
if (rowSet.nextPageCursor() == Cursor.EMPTY ) { if (rowSet.nextPageCursor() == Cursor.EMPTY ) {
// we are finished with this cursor, let's clean it before continuing // we are finished with this cursor, let's clean it before continuing

View File

@ -9,10 +9,10 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor; import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor;
import org.elasticsearch.xpack.sql.session.AbstractRowSet;
import org.elasticsearch.xpack.sql.session.Cursor; import org.elasticsearch.xpack.sql.session.Cursor;
import java.util.Arrays; import java.util.Arrays;
import java.util.BitSet;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -20,10 +20,9 @@ import java.util.Set;
/** /**
* Extracts rows from an array of {@link SearchHit}. * Extracts rows from an array of {@link SearchHit}.
*/ */
class SearchHitRowSet extends AbstractRowSet { class SearchHitRowSet extends ResultRowSet<HitExtractor> {
private final SearchHit[] hits; private final SearchHit[] hits;
private final Cursor cursor; private final Cursor cursor;
private final List<HitExtractor> extractors;
private final Set<String> innerHits = new LinkedHashSet<>(); private final Set<String> innerHits = new LinkedHashSet<>();
private final String innerHit; private final String innerHit;
@ -31,10 +30,10 @@ class SearchHitRowSet extends AbstractRowSet {
private final int[] indexPerLevel; private final int[] indexPerLevel;
private int row = 0; private int row = 0;
SearchHitRowSet(List<HitExtractor> exts, SearchHit[] hits, int limit, String scrollId) { SearchHitRowSet(List<HitExtractor> exts, BitSet mask, SearchHit[] hits, int limit, String scrollId) {
super(exts, mask);
this.hits = hits; this.hits = hits;
this.extractors = exts;
// Since the results might contain nested docs, the iteration is similar to that of Aggregation // Since the results might contain nested docs, the iteration is similar to that of Aggregation
// namely it discovers the nested docs and then, for iteration, increments the deepest level first // namely it discovers the nested docs and then, for iteration, increments the deepest level first
@ -85,7 +84,7 @@ class SearchHitRowSet extends AbstractRowSet {
if (size == 0 || remainingLimit == 0) { if (size == 0 || remainingLimit == 0) {
cursor = Cursor.EMPTY; cursor = Cursor.EMPTY;
} else { } else {
cursor = new ScrollCursor(scrollId, extractors, remainingLimit); cursor = new ScrollCursor(scrollId, extractors(), mask, remainingLimit);
} }
} }
} }
@ -95,13 +94,7 @@ class SearchHitRowSet extends AbstractRowSet {
} }
@Override @Override
public int columnCount() { protected Object extractValue(HitExtractor e) {
return extractors.size();
}
@Override
protected Object getColumn(int column) {
HitExtractor e = extractors.get(column);
int extractorLevel = e.hitName() == null ? 0 : 1; int extractorLevel = e.hitName() == null ? 0 : 1;
SearchHit hit = null; SearchHit hit = null;

View File

@ -58,7 +58,7 @@ public abstract class SourceGenerator {
// need to be retrieved from the result documents // need to be retrieved from the result documents
// NB: the sortBuilder takes care of eliminating duplicates // NB: the sortBuilder takes care of eliminating duplicates
container.columns().forEach(cr -> cr.collectFields(sortBuilder)); container.fields().forEach(f -> f.v1().collectFields(sortBuilder));
sortBuilder.build(source); sortBuilder.build(source);
optimize(sortBuilder, source); optimize(sortBuilder, source);

View File

@ -11,7 +11,6 @@ import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -21,7 +20,7 @@ import static java.util.Collections.singletonMap;
import static java.util.Collections.unmodifiableCollection; import static java.util.Collections.unmodifiableCollection;
import static java.util.Collections.unmodifiableSet; import static java.util.Collections.unmodifiableSet;
public class AttributeMap<E> { public class AttributeMap<E> implements Map<Attribute, E> {
static class AttributeWrapper { static class AttributeWrapper {
@ -120,8 +119,9 @@ public class AttributeMap<E> {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <A> A[] toArray(A[] a) { public <A> A[] toArray(A[] a) {
// collection is immutable so use that to our advantage // collection is immutable so use that to our advantage
if (a.length < size()) if (a.length < size()) {
a = (A[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size()); a = (A[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size());
}
int i = 0; int i = 0;
Object[] result = a; Object[] result = a;
for (U u : this) { for (U u : this) {
@ -140,6 +140,14 @@ public class AttributeMap<E> {
} }
} }
@SuppressWarnings("rawtypes")
public static final AttributeMap EMPTY = new AttributeMap<>();
@SuppressWarnings("unchecked")
public static final <E> AttributeMap<E> emptyAttributeMap() {
return EMPTY;
}
private final Map<AttributeWrapper, E> delegate; private final Map<AttributeWrapper, E> delegate;
private Set<Attribute> keySet = null; private Set<Attribute> keySet = null;
private Collection<E> values = null; private Collection<E> values = null;
@ -175,6 +183,14 @@ public class AttributeMap<E> {
delegate.putAll(other.delegate); delegate.putAll(other.delegate);
} }
public AttributeMap<E> combine(AttributeMap<E> other) {
AttributeMap<E> combine = new AttributeMap<>();
combine.addAll(this);
combine.addAll(other);
return combine;
}
public AttributeMap<E> subtract(AttributeMap<E> other) { public AttributeMap<E> subtract(AttributeMap<E> other) {
AttributeMap<E> diff = new AttributeMap<>(); AttributeMap<E> diff = new AttributeMap<>();
for (Entry<AttributeWrapper, E> entry : this.delegate.entrySet()) { for (Entry<AttributeWrapper, E> entry : this.delegate.entrySet()) {
@ -222,14 +238,17 @@ public class AttributeMap<E> {
return s; return s;
} }
@Override
public int size() { public int size() {
return delegate.size(); return delegate.size();
} }
@Override
public boolean isEmpty() { public boolean isEmpty() {
return delegate.isEmpty(); return delegate.isEmpty();
} }
@Override
public boolean containsKey(Object key) { public boolean containsKey(Object key) {
if (key instanceof NamedExpression) { if (key instanceof NamedExpression) {
return delegate.keySet().contains(new AttributeWrapper(((NamedExpression) key).toAttribute())); return delegate.keySet().contains(new AttributeWrapper(((NamedExpression) key).toAttribute()));
@ -237,10 +256,12 @@ public class AttributeMap<E> {
return false; return false;
} }
@Override
public boolean containsValue(Object value) { public boolean containsValue(Object value) {
return delegate.values().contains(value); return delegate.values().contains(value);
} }
@Override
public E get(Object key) { public E get(Object key) {
if (key instanceof NamedExpression) { if (key instanceof NamedExpression) {
return delegate.get(new AttributeWrapper(((NamedExpression) key).toAttribute())); return delegate.get(new AttributeWrapper(((NamedExpression) key).toAttribute()));
@ -248,6 +269,7 @@ public class AttributeMap<E> {
return null; return null;
} }
@Override
public E getOrDefault(Object key, E defaultValue) { public E getOrDefault(Object key, E defaultValue) {
E e; E e;
return (((e = get(key)) != null) || containsKey(key)) return (((e = get(key)) != null) || containsKey(key))
@ -255,6 +277,27 @@ public class AttributeMap<E> {
: defaultValue; : defaultValue;
} }
@Override
public E put(Attribute key, E value) {
throw new UnsupportedOperationException();
}
@Override
public E remove(Object key) {
throw new UnsupportedOperationException();
}
@Override
public void putAll(Map<? extends Attribute, ? extends E> m) {
throw new UnsupportedOperationException();
}
@Override
public void clear() {
throw new UnsupportedOperationException();
}
@Override
public Set<Attribute> keySet() { public Set<Attribute> keySet() {
if (keySet == null) { if (keySet == null) {
keySet = new UnwrappingSet<AttributeWrapper, Attribute>(delegate.keySet()) { keySet = new UnwrappingSet<AttributeWrapper, Attribute>(delegate.keySet()) {
@ -267,6 +310,7 @@ public class AttributeMap<E> {
return keySet; return keySet;
} }
@Override
public Collection<E> values() { public Collection<E> values() {
if (values == null) { if (values == null) {
values = unmodifiableCollection(delegate.values()); values = unmodifiableCollection(delegate.values());
@ -274,6 +318,7 @@ public class AttributeMap<E> {
return values; return values;
} }
@Override
public Set<Entry<Attribute, E>> entrySet() { public Set<Entry<Attribute, E>> entrySet() {
if (entrySet == null) { if (entrySet == null) {
entrySet = new UnwrappingSet<Entry<AttributeWrapper, E>, Entry<Attribute, E>>(delegate.entrySet()) { entrySet = new UnwrappingSet<Entry<AttributeWrapper, E>, Entry<Attribute, E>>(delegate.entrySet()) {
@ -301,6 +346,7 @@ public class AttributeMap<E> {
return entrySet; return entrySet;
} }
@Override
public void forEach(BiConsumer<? super Attribute, ? super E> action) { public void forEach(BiConsumer<? super Attribute, ? super E> action) {
delegate.forEach((k, v) -> action.accept(k.attr, v)); delegate.forEach((k, v) -> action.accept(k.attr, v));
} }

View File

@ -57,6 +57,10 @@ public class AttributeSet implements Set<Attribute> {
delegate.addAll(other.delegate); delegate.addAll(other.delegate);
} }
public AttributeSet combine(AttributeSet other) {
return new AttributeSet(delegate.combine(other.delegate));
}
public AttributeSet subtract(AttributeSet other) { public AttributeSet subtract(AttributeSet other) {
return new AttributeSet(delegate.subtract(other.delegate)); return new AttributeSet(delegate.subtract(other.delegate));
} }

View File

@ -8,8 +8,8 @@ package org.elasticsearch.xpack.sql.expression;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.capabilities.Resolvable; import org.elasticsearch.xpack.sql.capabilities.Resolvable;
import org.elasticsearch.xpack.sql.capabilities.Resolvables; import org.elasticsearch.xpack.sql.capabilities.Resolvables;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.Node; import org.elasticsearch.xpack.sql.tree.Node;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.util.StringUtils; import org.elasticsearch.xpack.sql.util.StringUtils;
@ -64,6 +64,7 @@ public abstract class Expression extends Node<Expression> implements Resolvable
private TypeResolution lazyTypeResolution = null; private TypeResolution lazyTypeResolution = null;
private Boolean lazyChildrenResolved = null; private Boolean lazyChildrenResolved = null;
private Expression lazyCanonical = null; private Expression lazyCanonical = null;
private AttributeSet lazyReferences = null;
public Expression(Source source, List<Expression> children) { public Expression(Source source, List<Expression> children) {
super(source, children); super(source, children);
@ -82,7 +83,10 @@ public abstract class Expression extends Node<Expression> implements Resolvable
// the references/inputs/leaves of the expression tree // the references/inputs/leaves of the expression tree
public AttributeSet references() { public AttributeSet references() {
return Expressions.references(children()); if (lazyReferences == null) {
lazyReferences = Expressions.references(children());
}
return lazyReferences;
} }
public boolean childrenResolved() { public boolean childrenResolved() {

View File

@ -36,7 +36,7 @@ public final class Expressions {
private Expressions() {} private Expressions() {}
public static NamedExpression wrapAsNamed(Expression exp) { public static NamedExpression wrapAsNamed(Expression exp) {
return exp instanceof NamedExpression ? (NamedExpression) exp : new Alias(exp.source(), exp.nodeName(), exp); return exp instanceof NamedExpression ? (NamedExpression) exp : new Alias(exp.source(), exp.sourceText(), exp);
} }
public static List<Attribute> asAttributes(List<? extends NamedExpression> named) { public static List<Attribute> asAttributes(List<? extends NamedExpression> named) {

View File

@ -91,4 +91,9 @@ public abstract class NamedExpression extends Expression {
&& Objects.equals(name, other.name) && Objects.equals(name, other.name)
&& Objects.equals(children(), other.children()); && Objects.equals(children(), other.children());
} }
@Override
public String toString() {
return super.toString() + "#" + id();
}
} }

View File

@ -49,11 +49,6 @@ public abstract class Function extends NamedExpression {
return Expressions.nullable(children()); return Expressions.nullable(children());
} }
@Override
public String toString() {
return sourceText() + "#" + id();
}
public String functionName() { public String functionName() {
return functionName; return functionName;
} }

View File

@ -166,7 +166,7 @@ public class UnresolvedFunction extends Function implements Unresolvable {
@Override @Override
public String toString() { public String toString() {
return UNRESOLVED_PREFIX + sourceText(); return UNRESOLVED_PREFIX + name + children();
} }
@Override @Override

View File

@ -52,7 +52,7 @@ public abstract class AggregateFunction extends Function {
public AggregateFunctionAttribute toAttribute() { public AggregateFunctionAttribute toAttribute() {
if (lazyAttribute == null) { if (lazyAttribute == null) {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl) // this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
lazyAttribute = new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), null); lazyAttribute = new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId());
} }
return lazyAttribute; return lazyAttribute;
} }

View File

@ -18,23 +18,36 @@ import java.util.Objects;
public class AggregateFunctionAttribute extends FunctionAttribute { public class AggregateFunctionAttribute extends FunctionAttribute {
// used when dealing with a inner agg (avg -> stats) to keep track of
// packed id
// used since the functionId points to the compoundAgg
private final ExpressionId innerId;
private final String propertyPath; private final String propertyPath;
AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, String functionId) {
String functionId, String propertyPath) { this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, null, null);
this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, propertyPath);
} }
public AggregateFunctionAttribute(Source source, String name, DataType dataType, String qualifier, AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, String functionId, ExpressionId innerId,
Nullability nullability, ExpressionId id, boolean synthetic, String functionId, String propertyPath) { String propertyPath) {
this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, innerId, propertyPath);
}
public AggregateFunctionAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability,
ExpressionId id, boolean synthetic, String functionId, ExpressionId innerId, String propertyPath) {
super(source, name, dataType, qualifier, nullability, id, synthetic, functionId); super(source, name, dataType, qualifier, nullability, id, synthetic, functionId);
this.innerId = innerId;
this.propertyPath = propertyPath; this.propertyPath = propertyPath;
} }
@Override @Override
protected NodeInfo<AggregateFunctionAttribute> info() { protected NodeInfo<AggregateFunctionAttribute> info() {
return NodeInfo.create(this, AggregateFunctionAttribute::new, return NodeInfo.create(this, AggregateFunctionAttribute::new, name(), dataType(), qualifier(), nullable(), id(), synthetic(),
name(), dataType(), qualifier(), nullable(), id(), synthetic(), functionId(), propertyPath); functionId(), innerId, propertyPath);
}
public ExpressionId innerId() {
return innerId != null ? innerId : id();
} }
public String propertyPath() { public String propertyPath() {
@ -43,33 +56,38 @@ public class AggregateFunctionAttribute extends FunctionAttribute {
@Override @Override
protected Expression canonicalize() { protected Expression canonicalize() {
return new AggregateFunctionAttribute(source(), "<none>", dataType(), null, Nullability.TRUE, id(), false, "<none>", null); return new AggregateFunctionAttribute(source(), "<none>", dataType(), null, Nullability.TRUE, id(), false, "<none>", null, null);
} }
@Override @Override
protected Attribute clone(Source source, String name, String qualifier, Nullability nullability, ExpressionId id, boolean synthetic) { protected Attribute clone(Source source, String name, String qualifier, Nullability nullability, ExpressionId id, boolean synthetic) {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl) // this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
// that is the functionId is actually derived from the expression id to easily track it across contexts // that is the functionId is actually derived from the expression id to easily track it across contexts
return new AggregateFunctionAttribute(source, name, dataType(), qualifier, nullability, id, synthetic, functionId(), propertyPath); return new AggregateFunctionAttribute(source, name, dataType(), qualifier, nullability, id, synthetic, functionId(), innerId,
propertyPath);
} }
public AggregateFunctionAttribute withFunctionId(String functionId, String propertyPath) { public AggregateFunctionAttribute withFunctionId(String functionId, String propertyPath) {
return new AggregateFunctionAttribute(source(), name(), dataType(), qualifier(), nullable(), return new AggregateFunctionAttribute(source(), name(), dataType(), qualifier(), nullable(), id(), synthetic(), functionId, innerId,
id(), synthetic(), functionId, propertyPath); propertyPath);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(super.hashCode(), propertyPath); return Objects.hash(super.hashCode(), innerId, propertyPath);
} }
@Override @Override
public boolean equals(Object obj) { public boolean equals(Object obj) {
return super.equals(obj) && Objects.equals(propertyPath(), ((AggregateFunctionAttribute) obj).propertyPath()); if (super.equals(obj)) {
AggregateFunctionAttribute other = (AggregateFunctionAttribute) obj;
return Objects.equals(innerId, other.innerId) && Objects.equals(propertyPath, other.propertyPath);
}
return false;
} }
@Override @Override
protected String label() { protected String label() {
return "a->" + functionId(); return "a->" + innerId();
} }
} }

View File

@ -77,11 +77,11 @@ public class Count extends AggregateFunction {
public AggregateFunctionAttribute toAttribute() { public AggregateFunctionAttribute toAttribute() {
// COUNT(*) gets its value from the parent aggregation on which _count is called // COUNT(*) gets its value from the parent aggregation on which _count is called
if (field() instanceof Literal) { if (field() instanceof Literal) {
return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), "_count"); return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), id(), "_count");
} }
// COUNT(column) gets its value from a sibling aggregation (an exists filter agg) by calling its id and then _count on it // COUNT(column) gets its value from a sibling aggregation (an exists filter agg) by calling its id and then _count on it
if (!distinct()) { if (!distinct()) {
return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), functionId() + "._count"); return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), id(), functionId() + "._count");
} }
return super.toAttribute(); return super.toAttribute();
} }

View File

@ -7,8 +7,8 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.function.Function; import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.type.DataType;
import java.util.List; import java.util.List;
@ -17,7 +17,7 @@ public class InnerAggregate extends AggregateFunction {
private final AggregateFunction inner; private final AggregateFunction inner;
private final CompoundNumericAggregate outer; private final CompoundNumericAggregate outer;
private final String innerId; private final String innerName;
// used when the result needs to be extracted from a map (like in MatrixAggs or Percentiles) // used when the result needs to be extracted from a map (like in MatrixAggs or Percentiles)
private final Expression innerKey; private final Expression innerKey;
@ -29,7 +29,7 @@ public class InnerAggregate extends AggregateFunction {
super(source, outer.field(), outer.arguments()); super(source, outer.field(), outer.arguments());
this.inner = inner; this.inner = inner;
this.outer = outer; this.outer = outer;
this.innerId = ((EnclosedAgg) inner).innerName(); this.innerName = ((EnclosedAgg) inner).innerName();
this.innerKey = innerKey; this.innerKey = innerKey;
} }
@ -55,8 +55,8 @@ public class InnerAggregate extends AggregateFunction {
return outer; return outer;
} }
public String innerId() { public String innerName() {
return innerId; return innerName;
} }
public Expression innerKey() { public Expression innerKey() {
@ -77,10 +77,10 @@ public class InnerAggregate extends AggregateFunction {
public AggregateFunctionAttribute toAttribute() { public AggregateFunctionAttribute toAttribute() {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl) // this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
return new AggregateFunctionAttribute(source(), name(), dataType(), outer.id(), functionId(), return new AggregateFunctionAttribute(source(), name(), dataType(), outer.id(), functionId(),
aggMetricValue(functionId(), innerId)); inner.id(), aggMetricValue(functionId(), innerName));
} }
public static String aggMetricValue(String aggPath, String valueName) { private static String aggMetricValue(String aggPath, String valueName) {
// handle aggPath inconsistency (for percentiles and percentileRanks) percentile[99.9] (valid) vs percentile.99.9 (invalid) // handle aggPath inconsistency (for percentiles and percentileRanks) percentile[99.9] (valid) vs percentile.99.9 (invalid)
return aggPath + "[" + valueName + "]"; return aggPath + "[" + valueName + "]";
} }
@ -98,4 +98,9 @@ public class InnerAggregate extends AggregateFunction {
public String name() { public String name() {
return inner.name(); return inner.name();
} }
@Override
public String toString() {
return nodeName() + "[" + outer + ">" + inner.nodeName() + "#" + inner.id() + "]";
}
} }

View File

@ -10,7 +10,6 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer.CleanAliases;
import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Alias;
import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.AttributeMap; import org.elasticsearch.xpack.sql.expression.AttributeMap;
import org.elasticsearch.xpack.sql.expression.AttributeSet;
import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.ExpressionId; import org.elasticsearch.xpack.sql.expression.ExpressionId;
import org.elasticsearch.xpack.sql.expression.ExpressionSet; import org.elasticsearch.xpack.sql.expression.ExpressionSet;
@ -78,22 +77,23 @@ import org.elasticsearch.xpack.sql.rule.Rule;
import org.elasticsearch.xpack.sql.rule.RuleExecutor; import org.elasticsearch.xpack.sql.rule.RuleExecutor;
import org.elasticsearch.xpack.sql.session.EmptyExecutable; import org.elasticsearch.xpack.sql.session.EmptyExecutable;
import org.elasticsearch.xpack.sql.session.SingletonExecutable; import org.elasticsearch.xpack.sql.session.SingletonExecutable;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.util.CollectionUtils; import org.elasticsearch.xpack.sql.util.CollectionUtils;
import org.elasticsearch.xpack.sql.util.Holder;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer; import java.util.function.Consumer;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.xpack.sql.expression.Literal.FALSE; import static org.elasticsearch.xpack.sql.expression.Literal.FALSE;
import static org.elasticsearch.xpack.sql.expression.Literal.TRUE; import static org.elasticsearch.xpack.sql.expression.Literal.TRUE;
import static org.elasticsearch.xpack.sql.expression.predicate.Predicates.combineAnd; import static org.elasticsearch.xpack.sql.expression.predicate.Predicates.combineAnd;
@ -117,19 +117,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
@Override @Override
protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() { protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
Batch aggregate = new Batch("Aggregation",
new PruneDuplicatesInGroupBy(),
new ReplaceDuplicateAggsWithReferences(),
new ReplaceAggsWithMatrixStats(),
new ReplaceAggsWithExtendedStats(),
new ReplaceAggsWithStats(),
new PromoteStatsToExtendedStats(),
new ReplaceAggsWithPercentiles(),
new ReplaceAggsWithPercentileRanks(),
new ReplaceMinMaxWithTopHits()
);
Batch operators = new Batch("Operator Optimization", Batch operators = new Batch("Operator Optimization",
new PruneDuplicatesInGroupBy(),
// combining // combining
new CombineProjections(), new CombineProjections(),
// folding // folding
@ -157,6 +146,16 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
//new PruneDuplicateFunctions() //new PruneDuplicateFunctions()
); );
Batch aggregate = new Batch("Aggregation Rewrite",
//new ReplaceDuplicateAggsWithReferences(),
new ReplaceAggsWithMatrixStats(),
new ReplaceAggsWithExtendedStats(),
new ReplaceAggsWithStats(),
new PromoteStatsToExtendedStats(),
new ReplaceAggsWithPercentiles(),
new ReplaceAggsWithPercentileRanks()
);
Batch local = new Batch("Skip Elasticsearch", Batch local = new Batch("Skip Elasticsearch",
new SkipQueryOnLimitZero(), new SkipQueryOnLimitZero(),
new SkipQueryIfFoldingProjection() new SkipQueryIfFoldingProjection()
@ -253,7 +252,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
seen.put(argument, matrixStats); seen.put(argument, matrixStats);
} }
InnerAggregate ia = new InnerAggregate(f.source(), f, matrixStats, f.field()); InnerAggregate ia = new InnerAggregate(f.source(), f, matrixStats, argument);
promotedIds.putIfAbsent(f.functionId(), ia.toAttribute()); promotedIds.putIfAbsent(f.functionId(), ia.toAttribute());
return ia; return ia;
} }
@ -310,8 +309,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
private static class Match { private static class Match {
final Stats stats; final Stats stats;
int count = 1; private final Set<Class<? extends AggregateFunction>> functionTypes = new LinkedHashSet<>();
final Set<Class<? extends AggregateFunction>> functionTypes = new LinkedHashSet<>(); private Map<Class<? extends AggregateFunction>, InnerAggregate> innerAggs = null;
Match(Stats stats) { Match(Stats stats) {
this.stats = stats; this.stats = stats;
@ -321,6 +320,22 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
public String toString() { public String toString() {
return stats.toString(); return stats.toString();
} }
public void add(Class<? extends AggregateFunction> aggType) {
functionTypes.add(aggType);
}
// if the stat has at least two different functions for it, promote it as stat
// also keep the promoted function around for reuse
public AggregateFunction maybePromote(AggregateFunction agg) {
if (functionTypes.size() > 1) {
if (innerAggs == null) {
innerAggs = new LinkedHashMap<>();
}
return innerAggs.computeIfAbsent(agg.getClass(), k -> new InnerAggregate(agg, stats));
}
return agg;
}
} }
@Override @Override
@ -359,15 +374,10 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
Match match = seen.get(argument); Match match = seen.get(argument);
if (match == null) { if (match == null) {
match = new Match(new Stats(f.source(), argument)); match = new Match(new Stats(new Source(f.sourceLocation(), "STATS(" + Expressions.name(argument) + ")"), argument));
match.functionTypes.add(f.getClass());
seen.put(argument, match); seen.put(argument, match);
} }
else { match.add(f.getClass());
if (match.functionTypes.add(f.getClass())) {
match.count++;
}
}
} }
return e; return e;
@ -378,13 +388,14 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
AggregateFunction f = (AggregateFunction) e; AggregateFunction f = (AggregateFunction) e;
Expression argument = f.field(); Expression argument = f.field();
Match counter = seen.get(argument); Match match = seen.get(argument);
// if the stat has at least two different functions for it, promote it as stat if (match != null) {
if (counter != null && counter.count > 1) { AggregateFunction inner = match.maybePromote(f);
InnerAggregate innerAgg = new InnerAggregate(f, counter.stats); if (inner != f) {
attrs.putIfAbsent(f.functionId(), innerAgg.toAttribute()); attrs.putIfAbsent(f.functionId(), inner.toAttribute());
return innerAgg; }
return inner;
} }
} }
return e; return e;
@ -819,31 +830,23 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
@Override @Override
protected LogicalPlan rule(OrderBy ob) { protected LogicalPlan rule(OrderBy ob) {
List<Order> order = ob.order(); Holder<Boolean> foundAggregate = new Holder<>(Boolean.FALSE);
Holder<Boolean> foundImplicitGroupBy = new Holder<>(Boolean.FALSE);
// remove constants // if the first found aggregate has no grouping, there's no need to do ordering
List<Order> nonConstant = order.stream().filter(o -> !o.child().foldable()).collect(toList()); ob.forEachDown(a -> {
// take into account
if (nonConstant.isEmpty()) { if (foundAggregate.get() == Boolean.TRUE) {
return ob.child(); return;
} }
foundAggregate.set(Boolean.TRUE);
// if the sort points to an agg, consider it only if there's grouping
if (ob.child() instanceof Aggregate) {
Aggregate a = (Aggregate) ob.child();
if (a.groupings().isEmpty()) { if (a.groupings().isEmpty()) {
AttributeSet aggsAttr = new AttributeSet(Expressions.asAttributes(a.aggregates())); foundImplicitGroupBy.set(Boolean.TRUE);
List<Order> nonAgg = nonConstant.stream().filter(o -> {
if (o.child() instanceof NamedExpression) {
return !aggsAttr.contains(((NamedExpression) o.child()).toAttribute());
} }
return true; }, Aggregate.class);
}).collect(toList());
return nonAgg.isEmpty() ? ob.child() : new OrderBy(ob.source(), ob.child(), nonAgg); if (foundImplicitGroupBy.get() == Boolean.TRUE) {
} return ob.child();
} }
return ob; return ob;
} }
@ -858,34 +861,43 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
protected LogicalPlan rule(OrderBy ob) { protected LogicalPlan rule(OrderBy ob) {
List<Order> order = ob.order(); List<Order> order = ob.order();
// remove constants // remove constants and put the items in reverse order so the iteration happens back to front
List<Order> nonConstant = order.stream().filter(o -> !o.child().foldable()).collect(toList()); List<Order> nonConstant = new LinkedList<>();
for (Order o : order) {
if (o.child().foldable() == false) {
nonConstant.add(0, o);
}
}
// if the sort points to an agg, change the agg order based on the order Holder<Boolean> foundAggregate = new Holder<>(Boolean.FALSE);
if (ob.child() instanceof Aggregate) {
Aggregate a = (Aggregate) ob.child();
List<Expression> groupings = new ArrayList<>(a.groupings());
boolean orderChanged = false;
for (int orderIndex = 0; orderIndex < nonConstant.size(); orderIndex++) { // if the first found aggregate has no grouping, there's no need to do ordering
Order o = nonConstant.get(orderIndex); return ob.transformDown(a -> {
// take into account
if (foundAggregate.get() == Boolean.TRUE) {
return a;
}
foundAggregate.set(Boolean.TRUE);
List<Expression> groupings = new LinkedList<>(a.groupings());
for (Order o : nonConstant) {
Expression fieldToOrder = o.child(); Expression fieldToOrder = o.child();
for (Expression group : a.groupings()) { for (Expression group : a.groupings()) {
if (Expressions.equalsAsAttribute(fieldToOrder, group)) { if (Expressions.equalsAsAttribute(fieldToOrder, group)) {
// move grouping in front // move grouping in front
groupings.remove(group); groupings.remove(group);
groupings.add(orderIndex, group); groupings.add(0, group);
orderChanged = true;
} }
} }
} }
if (orderChanged) { if (groupings.equals(a.groupings()) == false) {
Aggregate newAgg = new Aggregate(a.source(), a.child(), groupings, a.aggregates()); return new Aggregate(a.source(), a.child(), groupings, a.aggregates());
return new OrderBy(ob.source(), newAgg, ob.order());
} }
}
return ob; return a;
}, Aggregate.class);
} }
} }
@ -1017,6 +1029,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// eliminate lower project but first replace the aliases in the upper one // eliminate lower project but first replace the aliases in the upper one
return new Project(p.source(), p.child(), combineProjections(project.projections(), p.projections())); return new Project(p.source(), p.child(), combineProjections(project.projections(), p.projections()));
} }
if (child instanceof Aggregate) { if (child instanceof Aggregate) {
Aggregate a = (Aggregate) child; Aggregate a = (Aggregate) child;
return new Aggregate(a.source(), a.child(), a.groupings(), combineProjections(project.projections(), a.aggregates())); return new Aggregate(a.source(), a.child(), a.groupings(), combineProjections(project.projections(), a.aggregates()));
@ -1029,23 +1042,25 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// that might be reused by the upper one, these need to be replaced. // that might be reused by the upper one, these need to be replaced.
// for example an alias defined in the lower list might be referred in the upper - without replacing it the alias becomes invalid // for example an alias defined in the lower list might be referred in the upper - without replacing it the alias becomes invalid
private List<NamedExpression> combineProjections(List<? extends NamedExpression> upper, List<? extends NamedExpression> lower) { private List<NamedExpression> combineProjections(List<? extends NamedExpression> upper, List<? extends NamedExpression> lower) {
//TODO: this need rewriting when moving functions of NamedExpression
// collect aliases in the lower list // collect aliases in the lower list
Map<Attribute, Alias> map = new LinkedHashMap<>(); Map<Attribute, NamedExpression> map = new LinkedHashMap<>();
for (NamedExpression ne : lower) { for (NamedExpression ne : lower) {
if (ne instanceof Alias) { if ((ne instanceof Attribute) == false) {
Alias a = (Alias) ne; map.put(ne.toAttribute(), ne);
map.put(a.toAttribute(), a);
} }
} }
AttributeMap<Alias> aliases = new AttributeMap<>(map); AttributeMap<NamedExpression> aliases = new AttributeMap<>(map);
List<NamedExpression> replaced = new ArrayList<>(); List<NamedExpression> replaced = new ArrayList<>();
// replace any matching attribute with a lower alias (if there's a match) // replace any matching attribute with a lower alias (if there's a match)
// but clean-up non-top aliases at the end // but clean-up non-top aliases at the end
for (NamedExpression ne : upper) { for (NamedExpression ne : upper) {
NamedExpression replacedExp = (NamedExpression) ne.transformUp(a -> { NamedExpression replacedExp = (NamedExpression) ne.transformUp(a -> {
Alias as = aliases.get(a); NamedExpression as = aliases.get(a);
return as != null ? as : a; return as != null ? as : a;
}, Attribute.class); }, Attribute.class);
@ -1088,12 +1103,12 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
return plan; return plan;
} }
AtomicBoolean stop = new AtomicBoolean(false); Holder<Boolean> stop = new Holder<>(Boolean.FALSE);
// propagate folding up to unary nodes // propagate folding up to unary nodes
// anything higher and the propagate stops // anything higher and the propagate stops
plan = plan.transformUp(p -> { plan = plan.transformUp(p -> {
if (stop.get() == false && canPropagateFoldable(p)) { if (stop.get() == Boolean.FALSE && canPropagateFoldable(p)) {
return p.transformExpressionsDown(e -> { return p.transformExpressionsDown(e -> {
if (e instanceof Attribute && attrs.contains(e)) { if (e instanceof Attribute && attrs.contains(e)) {
Alias as = aliases.get(e); Alias as = aliases.get(e);
@ -1108,7 +1123,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
} }
if (p.children().size() > 1) { if (p.children().size() > 1) {
stop.set(true); stop.set(Boolean.TRUE);
} }
return p; return p;

View File

@ -60,13 +60,11 @@ public class TableIdentifier {
@Override @Override
public String toString() { public String toString() {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
builder.append("[");
if (cluster != null) { if (cluster != null) {
builder.append(cluster); builder.append(cluster);
builder.append(":");
} }
builder.append("][index=");
builder.append(index); builder.append(index);
builder.append("]");
return builder.toString(); return builder.toString();
} }
} }

View File

@ -8,13 +8,15 @@ package org.elasticsearch.xpack.sql.plan.logical;
import org.elasticsearch.xpack.sql.capabilities.Unresolvable; import org.elasticsearch.xpack.sql.capabilities.Unresolvable;
import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.plan.TableIdentifier; import org.elasticsearch.xpack.sql.plan.TableIdentifier;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import static java.util.Collections.singletonList;
public class UnresolvedRelation extends LeafPlan implements Unresolvable { public class UnresolvedRelation extends LeafPlan implements Unresolvable {
private final TableIdentifier table; private final TableIdentifier table;
@ -86,4 +88,14 @@ public class UnresolvedRelation extends LeafPlan implements Unresolvable {
&& Objects.equals(alias, other.alias) && Objects.equals(alias, other.alias)
&& unresolvedMsg.equals(other.unresolvedMsg); && unresolvedMsg.equals(other.unresolvedMsg);
} }
@Override
public List<Object> nodeProperties() {
return singletonList(table);
}
@Override
public String toString() {
return UNRESOLVED_PREFIX + table.index();
}
} }

View File

@ -9,11 +9,10 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.sql.execution.search.Querier; import org.elasticsearch.xpack.sql.execution.search.Querier;
import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.querydsl.container.QueryContainer; import org.elasticsearch.xpack.sql.querydsl.container.QueryContainer;
import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.session.SchemaRowSet; import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.session.SqlSession; import org.elasticsearch.xpack.sql.session.SqlSession;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -22,7 +21,6 @@ public class EsQueryExec extends LeafExec {
private final String index; private final String index;
private final List<Attribute> output; private final List<Attribute> output;
private final QueryContainer queryContainer; private final QueryContainer queryContainer;
public EsQueryExec(Source source, String index, List<Attribute> output, QueryContainer queryContainer) { public EsQueryExec(Source source, String index, List<Attribute> output, QueryContainer queryContainer) {
@ -56,8 +54,9 @@ public class EsQueryExec extends LeafExec {
@Override @Override
public void execute(SqlSession session, ActionListener<SchemaRowSet> listener) { public void execute(SqlSession session, ActionListener<SchemaRowSet> listener) {
Querier scroller = new Querier(session.client(), session.configuration()); Querier scroller = new Querier(session);
scroller.query(Rows.schema(output), queryContainer, index, listener);
scroller.query(output, queryContainer, index, listener);
} }
@Override @Override

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.AggRef; import org.elasticsearch.xpack.sql.execution.search.AggRef;
import org.elasticsearch.xpack.sql.expression.Alias; import org.elasticsearch.xpack.sql.expression.Alias;
import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.AttributeMap;
import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables; import org.elasticsearch.xpack.sql.expression.Foldables;
@ -146,8 +147,12 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
} }
} }
QueryContainer clone = new QueryContainer(queryC.query(), queryC.aggs(), queryC.columns(), aliases, QueryContainer clone = new QueryContainer(queryC.query(), queryC.aggs(), queryC.fields(),
queryC.pseudoFunctions(), processors, queryC.sort(), queryC.limit()); new AttributeMap<>(aliases),
queryC.pseudoFunctions(),
new AttributeMap<>(processors),
queryC.sort(),
queryC.limit());
return new EsQueryExec(exec.source(), exec.index(), project.output(), clone); return new EsQueryExec(exec.source(), exec.index(), project.output(), clone);
} }
return project; return project;
@ -170,7 +175,8 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
} }
Aggs aggs = addPipelineAggs(qContainer, qt, plan); Aggs aggs = addPipelineAggs(qContainer, qt, plan);
qContainer = new QueryContainer(query, aggs, qContainer.columns(), qContainer.aliases(), qContainer = new QueryContainer(query, aggs, qContainer.fields(),
qContainer.aliases(),
qContainer.pseudoFunctions(), qContainer.pseudoFunctions(),
qContainer.scalarFunctions(), qContainer.scalarFunctions(),
qContainer.sort(), qContainer.sort(),
@ -315,7 +321,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
} }
// add the computed column // add the computed column
queryC = qC.get().addColumn(new ComputedRef(proc)); queryC = qC.get().addColumn(new ComputedRef(proc), f.toAttribute());
// TODO: is this needed? // TODO: is this needed?
// redirect the alias to the scalar group id (changing the id altogether doesn't work it is // redirect the alias to the scalar group id (changing the id altogether doesn't work it is
@ -337,20 +343,22 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
// UTC is used since that's what the server uses and there's no conversion applied // UTC is used since that's what the server uses and there's no conversion applied
// (like for date histograms) // (like for date histograms)
ZoneId zi = child.dataType().isDateBased() ? DateUtils.UTC : null; ZoneId zi = child.dataType().isDateBased() ? DateUtils.UTC : null;
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi)); queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi), ((Attribute) child));
} }
// handle histogram // handle histogram
else if (child instanceof GroupingFunction) { else if (child instanceof GroupingFunction) {
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, null)); queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, null),
((GroupingFunction) child).toAttribute());
} }
// fallback to regular agg functions // fallback to regular agg functions
else { else {
// the only thing left is agg function // the only thing left is agg function
Check.isTrue(Functions.isAggregate(child), Check.isTrue(Functions.isAggregate(child),
"Expected aggregate function inside alias; got [{}]", child.nodeString()); "Expected aggregate function inside alias; got [{}]", child.nodeString());
Tuple<QueryContainer, AggPathInput> withAgg = addAggFunction(matchingGroup, AggregateFunction af = (AggregateFunction) child;
(AggregateFunction) child, compoundAggMap, queryC); Tuple<QueryContainer, AggPathInput> withAgg = addAggFunction(matchingGroup, af, compoundAggMap, queryC);
queryC = withAgg.v1().addColumn(withAgg.v2().context()); // make sure to add the inner id (to handle compound aggs)
queryC = withAgg.v1().addColumn(withAgg.v2().context(), af.toAttribute());
} }
} }
// not an Alias or Function means it's an Attribute so apply the same logic as above // not an Alias or Function means it's an Attribute so apply the same logic as above
@ -361,7 +369,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(ne)); Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(ne));
ZoneId zi = ne.dataType().isDateBased() ? DateUtils.UTC : null; ZoneId zi = ne.dataType().isDateBased() ? DateUtils.UTC : null;
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi)); queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi), ne.toAttribute());
} }
} }
} }
@ -369,7 +377,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
if (!aliases.isEmpty()) { if (!aliases.isEmpty()) {
Map<Attribute, Attribute> newAliases = new LinkedHashMap<>(queryC.aliases()); Map<Attribute, Attribute> newAliases = new LinkedHashMap<>(queryC.aliases());
newAliases.putAll(aliases); newAliases.putAll(aliases);
queryC = queryC.withAliases(newAliases); queryC = queryC.withAliases(new AttributeMap<>(newAliases));
} }
return new EsQueryExec(exec.source(), exec.index(), a.output(), queryC); return new EsQueryExec(exec.source(), exec.index(), a.output(), queryC);
} }
@ -420,7 +428,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
// FIXME: concern leak - hack around MatrixAgg which is not // FIXME: concern leak - hack around MatrixAgg which is not
// generalized (afaik) // generalized (afaik)
aggInput = new AggPathInput(f, aggInput = new AggPathInput(f,
new MetricAggRef(cAggPath, ia.innerId(), ia.innerKey() != null ? QueryTranslator.nameOf(ia.innerKey()) : null)); new MetricAggRef(cAggPath, ia.innerName(), ia.innerKey() != null ? QueryTranslator.nameOf(ia.innerKey()) : null));
} }
else { else {
LeafAgg leafAgg = toAgg(functionId, f); LeafAgg leafAgg = toAgg(functionId, f);
@ -474,19 +482,19 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
if (sfa.orderBy() instanceof NamedExpression) { if (sfa.orderBy() instanceof NamedExpression) {
Attribute at = ((NamedExpression) sfa.orderBy()).toAttribute(); Attribute at = ((NamedExpression) sfa.orderBy()).toAttribute();
at = qContainer.aliases().getOrDefault(at, at); at = qContainer.aliases().getOrDefault(at, at);
qContainer = qContainer.sort(new AttributeSort(at, direction, missing)); qContainer = qContainer.addSort(new AttributeSort(at, direction, missing));
} else if (!sfa.orderBy().foldable()) { } else if (!sfa.orderBy().foldable()) {
// ignore constant // ignore constant
throw new PlanningException("does not know how to order by expression {}", sfa.orderBy()); throw new PlanningException("does not know how to order by expression {}", sfa.orderBy());
} }
} else { } else {
// nope, use scripted sorting // nope, use scripted sorting
qContainer = qContainer.sort(new ScriptSort(sfa.script(), direction, missing)); qContainer = qContainer.addSort(new ScriptSort(sfa.script(), direction, missing));
} }
} else if (attr instanceof ScoreAttribute) { } else if (attr instanceof ScoreAttribute) {
qContainer = qContainer.sort(new ScoreSort(direction, missing)); qContainer = qContainer.addSort(new ScoreSort(direction, missing));
} else { } else {
qContainer = qContainer.sort(new AttributeSort(attr, direction, missing)); qContainer = qContainer.addSort(new AttributeSort(attr, direction, missing));
} }
} }
} }

View File

@ -12,6 +12,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBui
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.querydsl.container.Sort.Direction; import org.elasticsearch.xpack.sql.querydsl.container.Sort.Direction;
import org.elasticsearch.xpack.sql.util.StringUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
@ -21,7 +22,6 @@ import java.util.Objects;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine; import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine;
import static org.elasticsearch.xpack.sql.util.StringUtils.EMPTY;
/** /**
* SQL Aggregations associated with a query. * SQL Aggregations associated with a query.
@ -40,7 +40,7 @@ public class Aggs {
public static final String ROOT_GROUP_NAME = "groupby"; public static final String ROOT_GROUP_NAME = "groupby";
public static final GroupByKey IMPLICIT_GROUP_KEY = new GroupByKey(ROOT_GROUP_NAME, EMPTY, null, null) { public static final GroupByKey IMPLICIT_GROUP_KEY = new GroupByKey(ROOT_GROUP_NAME, StringUtils.EMPTY, null, null) {
@Override @Override
public CompositeValuesSourceBuilder<?> createSourceBuilder() { public CompositeValuesSourceBuilder<?> createSourceBuilder() {
@ -53,14 +53,12 @@ public class Aggs {
} }
}; };
public static final Aggs EMPTY = new Aggs(emptyList(), emptyList(), emptyList());
private final List<GroupByKey> groups; private final List<GroupByKey> groups;
private final List<LeafAgg> simpleAggs; private final List<LeafAgg> simpleAggs;
private final List<PipelineAgg> pipelineAggs; private final List<PipelineAgg> pipelineAggs;
public Aggs() {
this(emptyList(), emptyList(), emptyList());
}
public Aggs(List<GroupByKey> groups, List<LeafAgg> simpleAggs, List<PipelineAgg> pipelineAggs) { public Aggs(List<GroupByKey> groups, List<LeafAgg> simpleAggs, List<PipelineAgg> pipelineAggs) {
this.groups = groups; this.groups = groups;

View File

@ -15,9 +15,12 @@ import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.FieldExtraction; import org.elasticsearch.xpack.sql.execution.search.FieldExtraction;
import org.elasticsearch.xpack.sql.execution.search.SourceGenerator; import org.elasticsearch.xpack.sql.execution.search.SourceGenerator;
import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.AttributeMap;
import org.elasticsearch.xpack.sql.expression.ExpressionId;
import org.elasticsearch.xpack.sql.expression.FieldAttribute; import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.LiteralAttribute; import org.elasticsearch.xpack.sql.expression.LiteralAttribute;
import org.elasticsearch.xpack.sql.expression.function.ScoreAttribute; import org.elasticsearch.xpack.sql.expression.function.ScoreAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs; import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
@ -33,7 +36,9 @@ import org.elasticsearch.xpack.sql.tree.Source;
import java.io.IOException; import java.io.IOException;
import java.util.AbstractMap; import java.util.AbstractMap;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection; import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
@ -47,48 +52,142 @@ import static java.util.Collections.emptySet;
import static java.util.Collections.singletonMap; import static java.util.Collections.singletonMap;
import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine; import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine;
/**
* Container for various references of the built ES query.
* Useful to understanding how to interpret and navigate the
* returned result.
*/
public class QueryContainer { public class QueryContainer {
private final Aggs aggs; private final Aggs aggs;
private final Query query; private final Query query;
// final output seen by the client (hence the list or ordering) // fields extracted from the response - not necessarily what the client sees
// gets converted by the Scroller into Extractors for hits or actual results in case of aggregations // for example in case of grouping or custom sorting, the response has extra columns
private final List<FieldExtraction> columns; // that is filtered before getting to the client
// the list contains both the field extraction and the id of its associated attribute (for custom sorting)
private final List<Tuple<FieldExtraction, ExpressionId>> fields;
// aliases (maps an alias to its actual resolved attribute) // aliases (maps an alias to its actual resolved attribute)
private final Map<Attribute, Attribute> aliases; private final AttributeMap<Attribute> aliases;
// pseudo functions (like count) - that are 'extracted' from other aggs // pseudo functions (like count) - that are 'extracted' from other aggs
private final Map<String, GroupByKey> pseudoFunctions; private final Map<String, GroupByKey> pseudoFunctions;
// scalar function processors - recorded as functions get folded; // scalar function processors - recorded as functions get folded;
// at scrolling, their inputs (leaves) get updated // at scrolling, their inputs (leaves) get updated
private final Map<Attribute, Pipe> scalarFunctions; private final AttributeMap<Pipe> scalarFunctions;
private final Set<Sort> sort; private final Set<Sort> sort;
private final int limit; private final int limit;
// computed // computed
private final boolean aggsOnly; private Boolean aggsOnly;
private Boolean customSort;
public QueryContainer() { public QueryContainer() {
this(null, null, null, null, null, null, null, -1); this(null, null, null, null, null, null, null, -1);
} }
public QueryContainer(Query query, Aggs aggs, List<FieldExtraction> refs, Map<Attribute, Attribute> aliases, public QueryContainer(Query query,
Aggs aggs,
List<Tuple<FieldExtraction, ExpressionId>> fields,
AttributeMap<Attribute> aliases,
Map<String, GroupByKey> pseudoFunctions, Map<String, GroupByKey> pseudoFunctions,
Map<Attribute, Pipe> scalarFunctions, AttributeMap<Pipe> scalarFunctions,
Set<Sort> sort, int limit) { Set<Sort> sort,
int limit) {
this.query = query; this.query = query;
this.aggs = aggs == null ? new Aggs() : aggs; this.aggs = aggs == null ? Aggs.EMPTY : aggs;
this.aliases = aliases == null || aliases.isEmpty() ? emptyMap() : aliases; this.fields = fields == null || fields.isEmpty() ? emptyList() : fields;
this.aliases = aliases == null || aliases.isEmpty() ? AttributeMap.emptyAttributeMap() : aliases;
this.pseudoFunctions = pseudoFunctions == null || pseudoFunctions.isEmpty() ? emptyMap() : pseudoFunctions; this.pseudoFunctions = pseudoFunctions == null || pseudoFunctions.isEmpty() ? emptyMap() : pseudoFunctions;
this.scalarFunctions = scalarFunctions == null || scalarFunctions.isEmpty() ? emptyMap() : scalarFunctions; this.scalarFunctions = scalarFunctions == null || scalarFunctions.isEmpty() ? AttributeMap.emptyAttributeMap() : scalarFunctions;
this.columns = refs == null || refs.isEmpty() ? emptyList() : refs;
this.sort = sort == null || sort.isEmpty() ? emptySet() : sort; this.sort = sort == null || sort.isEmpty() ? emptySet() : sort;
this.limit = limit; this.limit = limit;
aggsOnly = columns.stream().allMatch(FieldExtraction::supportedByAggsOnlyQuery); }
/**
* If needed, create a comparator for each indicated column (which is indicated by an index pointing to the column number from the
* result set).
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
public List<Tuple<Integer, Comparator>> sortingColumns() {
if (customSort == Boolean.FALSE) {
return emptyList();
}
List<Tuple<Integer, Comparator>> sortingColumns = new ArrayList<>(sort.size());
boolean aggSort = false;
for (Sort s : sort) {
Tuple<Integer, Comparator> tuple = new Tuple<>(Integer.valueOf(-1), null);
if (s instanceof AttributeSort) {
AttributeSort as = (AttributeSort) s;
// find the relevant column of each aggregate function
if (as.attribute() instanceof AggregateFunctionAttribute) {
aggSort = true;
AggregateFunctionAttribute afa = (AggregateFunctionAttribute) as.attribute();
afa = (AggregateFunctionAttribute) aliases.getOrDefault(afa, afa);
int atIndex = -1;
for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, ExpressionId> field = fields.get(i);
if (field.v2().equals(afa.innerId())) {
atIndex = i;
break;
}
}
if (atIndex == -1) {
throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", afa.name());
}
// assemble a comparator for it
Comparator comp = s.direction() == Sort.Direction.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder();
comp = s.missing() == Sort.Missing.FIRST ? Comparator.nullsFirst(comp) : Comparator.nullsLast(comp);
tuple = new Tuple<>(Integer.valueOf(atIndex), comp);
}
}
sortingColumns.add(tuple);
}
if (customSort == null) {
customSort = Boolean.valueOf(aggSort);
}
return aggSort ? sortingColumns : emptyList();
}
/**
* Since the container contains both the field extractors and the visible columns,
* compact the information in the listener through a bitset that acts as a mask
* on what extractors are used for the visible columns.
*/
public BitSet columnMask(List<Attribute> columns) {
BitSet mask = new BitSet(fields.size());
for (Attribute column : columns) {
Attribute alias = aliases.get(column);
// find the column index
int index = -1;
ExpressionId id = column instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) column).innerId() : column.id();
ExpressionId aliasId = alias != null ? (alias instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) alias)
.innerId() : alias.id()) : null;
for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, ExpressionId> tuple = fields.get(i);
if (tuple.v2().equals(id) || (aliasId != null && tuple.v2().equals(aliasId))) {
index = i;
break;
}
}
if (index > -1) {
mask.set(index);
} else {
throw new SqlIllegalArgumentException("Cannot resolve field extractor index for column [{}]", column);
}
}
return mask;
} }
public Query query() { public Query query() {
@ -99,11 +198,11 @@ public class QueryContainer {
return aggs; return aggs;
} }
public List<FieldExtraction> columns() { public List<Tuple<FieldExtraction, ExpressionId>> fields() {
return columns; return fields;
} }
public Map<Attribute, Attribute> aliases() { public AttributeMap<Attribute> aliases() {
return aliases; return aliases;
} }
@ -120,11 +219,15 @@ public class QueryContainer {
} }
public boolean isAggsOnly() { public boolean isAggsOnly() {
return aggsOnly; if (aggsOnly == null) {
aggsOnly = Boolean.valueOf(this.fields.stream().allMatch(t -> t.v1().supportedByAggsOnlyQuery()));
}
return aggsOnly.booleanValue();
} }
public boolean hasColumns() { public boolean hasColumns() {
return !columns.isEmpty(); return fields.size() > 0;
} }
// //
@ -132,37 +235,33 @@ public class QueryContainer {
// //
public QueryContainer with(Query q) { public QueryContainer with(Query q) {
return new QueryContainer(q, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit); return new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
} }
public QueryContainer with(List<FieldExtraction> r) { public QueryContainer withAliases(AttributeMap<Attribute> a) {
return new QueryContainer(query, aggs, r, aliases, pseudoFunctions, scalarFunctions, sort, limit); return new QueryContainer(query, aggs, fields, a, pseudoFunctions, scalarFunctions, sort, limit);
}
public QueryContainer withAliases(Map<Attribute, Attribute> a) {
return new QueryContainer(query, aggs, columns, a, pseudoFunctions, scalarFunctions, sort, limit);
} }
public QueryContainer withPseudoFunctions(Map<String, GroupByKey> p) { public QueryContainer withPseudoFunctions(Map<String, GroupByKey> p) {
return new QueryContainer(query, aggs, columns, aliases, p, scalarFunctions, sort, limit); return new QueryContainer(query, aggs, fields, aliases, p, scalarFunctions, sort, limit);
} }
public QueryContainer with(Aggs a) { public QueryContainer with(Aggs a) {
return new QueryContainer(query, a, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit); return new QueryContainer(query, a, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
} }
public QueryContainer withLimit(int l) { public QueryContainer withLimit(int l) {
return l == limit ? this : new QueryContainer(query, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, l); return l == limit ? this : new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, l);
} }
public QueryContainer withScalarProcessors(Map<Attribute, Pipe> procs) { public QueryContainer withScalarProcessors(AttributeMap<Pipe> procs) {
return new QueryContainer(query, aggs, columns, aliases, pseudoFunctions, procs, sort, limit); return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, procs, sort, limit);
} }
public QueryContainer sort(Sort sortable) { public QueryContainer addSort(Sort sortable) {
Set<Sort> sort = new LinkedHashSet<>(this.sort); Set<Sort> sort = new LinkedHashSet<>(this.sort);
sort.add(sortable); sort.add(sortable);
return new QueryContainer(query, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit); return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
} }
private String aliasName(Attribute attr) { private String aliasName(Attribute attr) {
@ -188,7 +287,8 @@ public class QueryContainer {
attr.field().isAggregatable(), attr.parent().name()); attr.field().isAggregatable(), attr.parent().name());
nestedRefs.add(nestedFieldRef); nestedRefs.add(nestedFieldRef);
return new Tuple<>(new QueryContainer(q, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit), nestedFieldRef); return new Tuple<>(new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit),
nestedFieldRef);
} }
static Query rewriteToContainNestedField(@Nullable Query query, Source source, String path, String name, String format, static Query rewriteToContainNestedField(@Nullable Query query, Source source, String path, String name, String format,
@ -255,13 +355,13 @@ public class QueryContainer {
// update proc // update proc
Map<Attribute, Pipe> procs = new LinkedHashMap<>(qContainer.scalarFunctions()); Map<Attribute, Pipe> procs = new LinkedHashMap<>(qContainer.scalarFunctions());
procs.put(attribute, proc); procs.put(attribute, proc);
qContainer = qContainer.withScalarProcessors(procs); qContainer = qContainer.withScalarProcessors(new AttributeMap<>(procs));
return new Tuple<>(qContainer, new ComputedRef(proc)); return new Tuple<>(qContainer, new ComputedRef(proc));
} }
public QueryContainer addColumn(Attribute attr) { public QueryContainer addColumn(Attribute attr) {
Tuple<QueryContainer, FieldExtraction> tuple = toReference(attr); Tuple<QueryContainer, FieldExtraction> tuple = toReference(attr);
return tuple.v1().addColumn(tuple.v2()); return tuple.v1().addColumn(tuple.v2(), attr);
} }
private Tuple<QueryContainer, FieldExtraction> toReference(Attribute attr) { private Tuple<QueryContainer, FieldExtraction> toReference(Attribute attr) {
@ -286,11 +386,14 @@ public class QueryContainer {
throw new SqlIllegalArgumentException("Unknown output attribute {}", attr); throw new SqlIllegalArgumentException("Unknown output attribute {}", attr);
} }
public QueryContainer addColumn(FieldExtraction ref) { public QueryContainer addColumn(FieldExtraction ref, Attribute attr) {
return with(combine(columns, ref)); ExpressionId id = attr instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) attr).innerId() : attr.id();
return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, id)), aliases, pseudoFunctions,
scalarFunctions,
sort, limit);
} }
public Map<Attribute, Pipe> scalarFunctions() { public AttributeMap<Pipe> scalarFunctions() {
return scalarFunctions; return scalarFunctions;
} }
@ -298,11 +401,14 @@ public class QueryContainer {
// agg methods // agg methods
// //
public QueryContainer addAggCount(GroupByKey group, String functionId) { public QueryContainer addAggCount(GroupByKey group, ExpressionId functionId) {
FieldExtraction ref = group == null ? GlobalCountRef.INSTANCE : new GroupByRef(group.id(), Property.COUNT, null); FieldExtraction ref = group == null ? GlobalCountRef.INSTANCE : new GroupByRef(group.id(), Property.COUNT, null);
Map<String, GroupByKey> pseudoFunctions = new LinkedHashMap<>(this.pseudoFunctions); Map<String, GroupByKey> pseudoFunctions = new LinkedHashMap<>(this.pseudoFunctions);
pseudoFunctions.put(functionId, group); pseudoFunctions.put(functionId.toString(), group);
return new QueryContainer(query, aggs, combine(columns, ref), aliases, pseudoFunctions, scalarFunctions, sort, limit); return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, functionId)),
aliases,
pseudoFunctions,
scalarFunctions, sort, limit);
} }
public QueryContainer addAgg(String groupId, LeafAgg agg) { public QueryContainer addAgg(String groupId, LeafAgg agg) {
@ -327,7 +433,7 @@ public class QueryContainer {
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(query, aggs, columns, aliases); return Objects.hash(query, aggs, fields, aliases, sort, limit);
} }
@Override @Override
@ -343,7 +449,7 @@ public class QueryContainer {
QueryContainer other = (QueryContainer) obj; QueryContainer other = (QueryContainer) obj;
return Objects.equals(query, other.query) return Objects.equals(query, other.query)
&& Objects.equals(aggs, other.aggs) && Objects.equals(aggs, other.aggs)
&& Objects.equals(columns, other.columns) && Objects.equals(fields, other.fields)
&& Objects.equals(aliases, other.aliases) && Objects.equals(aliases, other.aliases)
&& Objects.equals(sort, other.sort) && Objects.equals(sort, other.sort)
&& Objects.equals(limit, other.limit); && Objects.equals(limit, other.limit);

View File

@ -9,7 +9,7 @@ import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.sql.expression.Order.NullsPosition; import org.elasticsearch.xpack.sql.expression.Order.NullsPosition;
import org.elasticsearch.xpack.sql.expression.Order.OrderDirection; import org.elasticsearch.xpack.sql.expression.Order.OrderDirection;
public class Sort { public abstract class Sort {
public enum Direction { public enum Direction {
ASC, DESC; ASC, DESC;

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.CompositeAggregationCursor; import org.elasticsearch.xpack.sql.execution.search.CompositeAggregationCursor;
import org.elasticsearch.xpack.sql.execution.search.PagingListCursor;
import org.elasticsearch.xpack.sql.execution.search.ScrollCursor; import org.elasticsearch.xpack.sql.execution.search.ScrollCursor;
import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractors; import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractors;
import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractors; import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractors;
@ -48,6 +49,7 @@ public final class Cursors {
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, ScrollCursor.NAME, ScrollCursor::new)); entries.add(new NamedWriteableRegistry.Entry(Cursor.class, ScrollCursor.NAME, ScrollCursor::new));
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, CompositeAggregationCursor.NAME, CompositeAggregationCursor::new)); entries.add(new NamedWriteableRegistry.Entry(Cursor.class, CompositeAggregationCursor.NAME, CompositeAggregationCursor::new));
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, TextFormatterCursor.NAME, TextFormatterCursor::new)); entries.add(new NamedWriteableRegistry.Entry(Cursor.class, TextFormatterCursor.NAME, TextFormatterCursor::new));
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, PagingListCursor.NAME, PagingListCursor::new));
// plus all their dependencies // plus all their dependencies
entries.addAll(Processors.getNamedWriteables()); entries.addAll(Processors.getNamedWriteables());

View File

@ -7,10 +7,10 @@ package org.elasticsearch.xpack.sql.session;
import org.elasticsearch.xpack.sql.type.Schema; import org.elasticsearch.xpack.sql.type.Schema;
class EmptyRowSetCursor extends AbstractRowSet implements SchemaRowSet { class EmptyRowSet extends AbstractRowSet implements SchemaRowSet {
private final Schema schema; private final Schema schema;
EmptyRowSetCursor(Schema schema) { EmptyRowSet(Schema schema) {
this.schema = schema; this.schema = schema;
} }

View File

@ -9,25 +9,25 @@ import org.elasticsearch.xpack.sql.type.Schema;
import java.util.List; import java.util.List;
class ListRowSetCursor extends AbstractRowSet implements SchemaRowSet { public class ListRowSet extends AbstractRowSet implements SchemaRowSet {
private final Schema schema; private final Schema schema;
private final List<List<?>> list; private final List<List<?>> list;
private int pos = 0; private int pos = 0;
ListRowSetCursor(Schema schema, List<List<?>> list) { protected ListRowSet(Schema schema, List<List<?>> list) {
this.schema = schema; this.schema = schema;
this.list = list; this.list = list;
} }
@Override @Override
protected boolean doHasCurrent() { protected boolean doHasCurrent() {
return pos < list.size(); return pos < size();
} }
@Override @Override
protected boolean doNext() { protected boolean doNext() {
if (pos + 1 < list.size()) { if (pos + 1 < size()) {
pos++; pos++;
return true; return true;
} }

View File

@ -36,7 +36,7 @@ public abstract class Rows {
} }
Schema schema = schema(attrs); Schema schema = schema(attrs);
return new ListRowSetCursor(schema, values); return new ListRowSet(schema, values);
} }
public static SchemaRowSet singleton(List<Attribute> attrs, Object... values) { public static SchemaRowSet singleton(List<Attribute> attrs, Object... values) {
@ -49,10 +49,10 @@ public abstract class Rows {
} }
public static SchemaRowSet empty(Schema schema) { public static SchemaRowSet empty(Schema schema) {
return new EmptyRowSetCursor(schema); return new EmptyRowSet(schema);
} }
public static SchemaRowSet empty(List<Attribute> attrs) { public static SchemaRowSet empty(List<Attribute> attrs) {
return new EmptyRowSetCursor(schema(attrs)); return new EmptyRowSet(schema(attrs));
} }
} }

View File

@ -19,6 +19,6 @@ public interface SchemaRowSet extends RowSet {
@Override @Override
default int columnCount() { default int columnCount() {
return schema().names().size(); return schema().size();
} }
} }

View File

@ -15,6 +15,7 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Verifier;
import org.elasticsearch.xpack.sql.analysis.index.IndexResolution; import org.elasticsearch.xpack.sql.analysis.index.IndexResolution;
import org.elasticsearch.xpack.sql.analysis.index.IndexResolver; import org.elasticsearch.xpack.sql.analysis.index.IndexResolver;
import org.elasticsearch.xpack.sql.analysis.index.MappingException; import org.elasticsearch.xpack.sql.analysis.index.MappingException;
import org.elasticsearch.xpack.sql.execution.PlanExecutor;
import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry;
import org.elasticsearch.xpack.sql.optimizer.Optimizer; import org.elasticsearch.xpack.sql.optimizer.Optimizer;
import org.elasticsearch.xpack.sql.parser.SqlParser; import org.elasticsearch.xpack.sql.parser.SqlParser;
@ -40,20 +41,17 @@ public class SqlSession {
private final Verifier verifier; private final Verifier verifier;
private final Optimizer optimizer; private final Optimizer optimizer;
private final Planner planner; private final Planner planner;
private final PlanExecutor planExecutor;
private final Configuration configuration; private final Configuration configuration;
public SqlSession(SqlSession other) {
this(other.configuration, other.client, other.functionRegistry, other.indexResolver,
other.preAnalyzer, other.verifier, other.optimizer, other.planner);
}
public SqlSession(Configuration configuration, Client client, FunctionRegistry functionRegistry, public SqlSession(Configuration configuration, Client client, FunctionRegistry functionRegistry,
IndexResolver indexResolver, IndexResolver indexResolver,
PreAnalyzer preAnalyzer, PreAnalyzer preAnalyzer,
Verifier verifier, Verifier verifier,
Optimizer optimizer, Optimizer optimizer,
Planner planner) { Planner planner,
PlanExecutor planExecutor) {
this.client = client; this.client = client;
this.functionRegistry = functionRegistry; this.functionRegistry = functionRegistry;
@ -64,6 +62,7 @@ public class SqlSession {
this.verifier = verifier; this.verifier = verifier;
this.configuration = configuration; this.configuration = configuration;
this.planExecutor = planExecutor;
} }
public FunctionRegistry functionRegistry() { public FunctionRegistry functionRegistry() {
@ -90,6 +89,10 @@ public class SqlSession {
return verifier; return verifier;
} }
public PlanExecutor planExecutor() {
return planExecutor;
}
private LogicalPlan doParse(String sql, List<SqlTypedParamValue> params) { private LogicalPlan doParse(String sql, List<SqlTypedParamValue> params) {
return new SqlParser().createStatement(sql, params); return new SqlParser().createStatement(sql, params);
} }

View File

@ -281,6 +281,14 @@ public abstract class Node<T extends Node<T>> {
return getClass().getSimpleName(); return getClass().getSimpleName();
} }
/**
* The values of all the properties that are important
* to this {@link Node}.
*/
public List<Object> nodeProperties() {
return info().properties();
}
public String nodeString() { public String nodeString() {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append(nodeName()); sb.append(nodeName());
@ -349,7 +357,6 @@ public abstract class Node<T extends Node<T>> {
* {@code [} and {@code ]} of the output of {@link #treeString}. * {@code [} and {@code ]} of the output of {@link #treeString}.
*/ */
public String propertiesToString(boolean skipIfChild) { public String propertiesToString(boolean skipIfChild) {
NodeInfo<? extends Node<T>> info = info();
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
List<?> children = children(); List<?> children = children();
@ -358,7 +365,7 @@ public abstract class Node<T extends Node<T>> {
int maxWidth = 0; int maxWidth = 0;
boolean needsComma = false; boolean needsComma = false;
List<Object> props = info.properties(); List<Object> props = nodeProperties();
for (Object prop : props) { for (Object prop : props) {
// consider a property if it is not ignored AND // consider a property if it is not ignored AND
// it's not a child (optional) // it's not a child (optional)
@ -388,12 +395,4 @@ public abstract class Node<T extends Node<T>> {
return sb.toString(); return sb.toString();
} }
/**
* The values of all the properties that are important
* to this {@link Node}.
*/
public List<Object> properties() {
return info().properties();
}
} }

View File

@ -349,6 +349,54 @@ public abstract class NodeInfo<T extends Node<?>> {
T apply(Source l, P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8); T apply(Source l, P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8);
} }
public static <T extends Node<?>, P1, P2, P3, P4, P5, P6, P7, P8, P9> NodeInfo<T> create(
T n, NodeCtor9<P1, P2, P3, P4, P5, P6, P7, P8, P9, T> ctor,
P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9) {
return new NodeInfo<T>(n) {
@Override
protected List<Object> innerProperties() {
return Arrays.asList(p1, p2, p3, p4, p5, p6, p7, p8, p9);
}
protected T innerTransform(Function<Object, Object> rule) {
boolean same = true;
@SuppressWarnings("unchecked")
P1 newP1 = (P1) rule.apply(p1);
same &= Objects.equals(p1, newP1);
@SuppressWarnings("unchecked")
P2 newP2 = (P2) rule.apply(p2);
same &= Objects.equals(p2, newP2);
@SuppressWarnings("unchecked")
P3 newP3 = (P3) rule.apply(p3);
same &= Objects.equals(p3, newP3);
@SuppressWarnings("unchecked")
P4 newP4 = (P4) rule.apply(p4);
same &= Objects.equals(p4, newP4);
@SuppressWarnings("unchecked")
P5 newP5 = (P5) rule.apply(p5);
same &= Objects.equals(p5, newP5);
@SuppressWarnings("unchecked")
P6 newP6 = (P6) rule.apply(p6);
same &= Objects.equals(p6, newP6);
@SuppressWarnings("unchecked")
P7 newP7 = (P7) rule.apply(p7);
same &= Objects.equals(p7, newP7);
@SuppressWarnings("unchecked")
P8 newP8 = (P8) rule.apply(p8);
same &= Objects.equals(p8, newP8);
@SuppressWarnings("unchecked")
P9 newP9 = (P9) rule.apply(p9);
same &= Objects.equals(p9, newP9);
return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8, newP9);
}
};
}
public interface NodeCtor9<P1, P2, P3, P4, P5, P6, P7, P8, P9, T> {
T apply(Source l, P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9);
}
public static <T extends Node<?>, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10> NodeInfo<T> create( public static <T extends Node<?>, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10> NodeInfo<T> create(
T n, NodeCtor10<P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, T> ctor, T n, NodeCtor10<P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, T> ctor,
P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9, P10 p10) { P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9, P10 p10) {

View File

@ -132,7 +132,7 @@ public abstract class Graphviz {
+ "</b></td></th>\n"); + "</b></td></th>\n");
indent(nodeInfo, currentIndent + NODE_LABEL_INDENT); indent(nodeInfo, currentIndent + NODE_LABEL_INDENT);
List<Object> props = n.properties(); List<Object> props = n.nodeProperties();
List<String> parsed = new ArrayList<>(props.size()); List<String> parsed = new ArrayList<>(props.size());
List<Node<?>> subTrees = new ArrayList<>(); List<Node<?>> subTrees = new ArrayList<>();

View File

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.util;
/**
* Simply utility class used for setting a state, typically
* for closures (which require outside variables to be final).
*/
public class Holder<T> {
private T value = null;
public Holder() {
}
public Holder(T value) {
this.value = value;
}
public void set(T value) {
this.value = value;
}
public T get() {
return value;
}
}

View File

@ -175,12 +175,10 @@ public class VerifierErrorMessagesTests extends ESTestCase {
} }
public void testMissingColumnInOrderBy() { public void testMissingColumnInOrderBy() {
// xxx offset is that of the order by field
assertEquals("1:29: Unknown column [xxx]", error("SELECT * FROM test ORDER BY xxx")); assertEquals("1:29: Unknown column [xxx]", error("SELECT * FROM test ORDER BY xxx"));
} }
public void testMissingColumnFunctionInOrderBy() { public void testMissingColumnFunctionInOrderBy() {
// xxx offset is that of the order by field
assertEquals("1:41: Unknown column [xxx]", error("SELECT * FROM test ORDER BY DAY_oF_YEAR(xxx)")); assertEquals("1:41: Unknown column [xxx]", error("SELECT * FROM test ORDER BY DAY_oF_YEAR(xxx)"));
} }
@ -208,7 +206,6 @@ public class VerifierErrorMessagesTests extends ESTestCase {
} }
public void testMultipleColumns() { public void testMultipleColumns() {
// xxx offset is that of the order by field
assertEquals("1:43: Unknown column [xxx]\nline 1:8: Unknown column [xxx]", assertEquals("1:43: Unknown column [xxx]\nline 1:8: Unknown column [xxx]",
error("SELECT xxx FROM test GROUP BY DAY_oF_YEAR(xxx)")); error("SELECT xxx FROM test GROUP BY DAY_oF_YEAR(xxx)"));
} }
@ -248,7 +245,7 @@ public class VerifierErrorMessagesTests extends ESTestCase {
} }
public void testGroupByOrderByScalarOverNonGrouped() { public void testGroupByOrderByScalarOverNonGrouped() {
assertEquals("1:50: Cannot order by non-grouped column [YEAR(date)], expected [text]", assertEquals("1:50: Cannot order by non-grouped column [YEAR(date)], expected [text] or an aggregate function",
error("SELECT MAX(int) FROM test GROUP BY text ORDER BY YEAR(date)")); error("SELECT MAX(int) FROM test GROUP BY text ORDER BY YEAR(date)"));
} }
@ -258,7 +255,7 @@ public class VerifierErrorMessagesTests extends ESTestCase {
} }
public void testGroupByOrderByScalarOverNonGrouped_WithHaving() { public void testGroupByOrderByScalarOverNonGrouped_WithHaving() {
assertEquals("1:71: Cannot order by non-grouped column [YEAR(date)], expected [text]", assertEquals("1:71: Cannot order by non-grouped column [YEAR(date)], expected [text] or an aggregate function",
error("SELECT MAX(int) FROM test GROUP BY text HAVING MAX(int) > 10 ORDER BY YEAR(date)")); error("SELECT MAX(int) FROM test GROUP BY text HAVING MAX(int) > 10 ORDER BY YEAR(date)"));
} }
@ -316,18 +313,25 @@ public class VerifierErrorMessagesTests extends ESTestCase {
error("SELECT * FROM test ORDER BY unsupported")); error("SELECT * FROM test ORDER BY unsupported"));
} }
public void testGroupByOrderByNonKey() { public void testGroupByOrderByAggregate() {
assertEquals("1:52: Cannot order by non-grouped column [a], expected [bool]", accept("SELECT AVG(int) a FROM test GROUP BY bool ORDER BY a");
error("SELECT AVG(int) a FROM test GROUP BY bool ORDER BY a"));
} }
public void testGroupByOrderByFunctionOverKey() { public void testGroupByOrderByAggs() {
assertEquals("1:44: Cannot order by non-grouped column [MAX(int)], expected [int]", accept("SELECT int FROM test GROUP BY int ORDER BY COUNT(*)");
error("SELECT int FROM test GROUP BY int ORDER BY MAX(int)")); }
public void testGroupByOrderByAggAndGroupedColumn() {
accept("SELECT int FROM test GROUP BY int ORDER BY int, MAX(int)");
}
public void testGroupByOrderByNonAggAndNonGroupedColumn() {
assertEquals("1:44: Cannot order by non-grouped column [bool], expected [int]",
error("SELECT int FROM test GROUP BY int ORDER BY bool"));
} }
public void testGroupByOrderByScore() { public void testGroupByOrderByScore() {
assertEquals("1:44: Cannot order by non-grouped column [SCORE()], expected [int]", assertEquals("1:44: Cannot order by non-grouped column [SCORE()], expected [int] or an aggregate function",
error("SELECT int FROM test GROUP BY int ORDER BY SCORE()")); error("SELECT int FROM test GROUP BY int ORDER BY SCORE()"));
} }

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.sql.session.Cursors;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet;
import java.util.List; import java.util.List;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -27,7 +28,9 @@ public class CompositeAggregationCursorTests extends AbstractWireSerializingTest
for (int i = 0; i < extractorsSize; i++) { for (int i = 0; i < extractorsSize; i++) {
extractors.add(randomBucketExtractor()); extractors.add(randomBucketExtractor());
} }
return new CompositeAggregationCursor(new byte[randomInt(256)], extractors, randomIntBetween(10, 1024), randomAlphaOfLength(5));
return new CompositeAggregationCursor(new byte[randomInt(256)], extractors, randomBitSet(extractorsSize),
randomIntBetween(10, 1024), randomAlphaOfLength(5));
} }
static BucketExtractor randomBucketExtractor() { static BucketExtractor randomBucketExtractor() {
@ -41,7 +44,9 @@ public class CompositeAggregationCursorTests extends AbstractWireSerializingTest
@Override @Override
protected CompositeAggregationCursor mutateInstance(CompositeAggregationCursor instance) throws IOException { protected CompositeAggregationCursor mutateInstance(CompositeAggregationCursor instance) throws IOException {
return new CompositeAggregationCursor(instance.next(), instance.extractors(), return new CompositeAggregationCursor(instance.next(), instance.extractors(),
randomValueOtherThan(instance.limit(), () -> randomIntBetween(1, 512)), instance.indices()); randomValueOtherThan(instance.mask(), () -> randomBitSet(instance.extractors().size())),
randomValueOtherThan(instance.limit(), () -> randomIntBetween(1, 512)),
instance.indices());
} }
@Override @Override
@ -68,4 +73,12 @@ public class CompositeAggregationCursorTests extends AbstractWireSerializingTest
} }
return (CompositeAggregationCursor) Cursors.decodeFromString(Cursors.encodeToString(version, instance)); return (CompositeAggregationCursor) Cursors.decodeFromString(Cursors.encodeToString(version, instance));
} }
static BitSet randomBitSet(int size) {
BitSet mask = new BitSet(size);
for (int i = 0; i < size; i++) {
mask.set(i, randomBoolean());
}
return mask;
}
} }

View File

@ -23,17 +23,18 @@ import org.elasticsearch.xpack.sql.session.Cursors;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.elasticsearch.action.support.PlainActionFuture.newFuture; import static org.elasticsearch.action.support.PlainActionFuture.newFuture;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.CLI;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.TEXT;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.CLI;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.TEXT;
public class CursorTests extends ESTestCase { public class CursorTests extends ESTestCase {
@ -51,7 +52,7 @@ public class CursorTests extends ESTestCase {
Client clientMock = mock(Client.class); Client clientMock = mock(Client.class);
ActionListener<Boolean> listenerMock = mock(ActionListener.class); ActionListener<Boolean> listenerMock = mock(ActionListener.class);
String cursorString = randomAlphaOfLength(10); String cursorString = randomAlphaOfLength(10);
Cursor cursor = new ScrollCursor(cursorString, Collections.emptyList(), randomInt()); Cursor cursor = new ScrollCursor(cursorString, Collections.emptyList(), new BitSet(0), randomInt());
cursor.clear(TestUtils.TEST_CFG, clientMock, listenerMock); cursor.clear(TestUtils.TEST_CFG, clientMock, listenerMock);

View File

@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.sql.session.Cursors;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class PagingListCursorTests extends AbstractWireSerializingTestCase<PagingListCursor> {
public static PagingListCursor randomPagingListCursor() {
int size = between(1, 20);
int depth = between(1, 20);
List<List<?>> values = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
values.add(Arrays.asList(randomArray(depth, s -> new Object[depth], () -> randomByte())));
}
return new PagingListCursor(values, depth, between(1, 20));
}
@Override
protected PagingListCursor mutateInstance(PagingListCursor instance) throws IOException {
return new PagingListCursor(instance.data(),
instance.columnCount(),
randomValueOtherThan(instance.pageSize(), () -> between(1, 20)));
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(Cursors.getNamedWriteables());
}
@Override
protected PagingListCursor createTestInstance() {
return randomPagingListCursor();
}
@Override
protected Reader<PagingListCursor> instanceReader() {
return PagingListCursor::new;
}
@Override
protected PagingListCursor copyInstance(PagingListCursor instance, Version version) throws IOException {
/* Randomly choose between internal protocol round trip and String based
* round trips used to toXContent. */
if (randomBoolean()) {
return super.copyInstance(instance, version);
}
return (PagingListCursor) Cursors.decodeFromString(Cursors.encodeToString(version, instance));
}
}

View File

@ -27,7 +27,8 @@ public class ScrollCursorTests extends AbstractWireSerializingTestCase<ScrollCur
for (int i = 0; i < extractorsSize; i++) { for (int i = 0; i < extractorsSize; i++) {
extractors.add(randomHitExtractor(0)); extractors.add(randomHitExtractor(0));
} }
return new ScrollCursor(randomAlphaOfLength(5), extractors, randomIntBetween(10, 1024)); return new ScrollCursor(randomAlphaOfLength(5), extractors, CompositeAggregationCursorTests.randomBitSet(extractorsSize),
randomIntBetween(10, 1024));
} }
static HitExtractor randomHitExtractor(int depth) { static HitExtractor randomHitExtractor(int depth) {
@ -43,6 +44,7 @@ public class ScrollCursorTests extends AbstractWireSerializingTestCase<ScrollCur
@Override @Override
protected ScrollCursor mutateInstance(ScrollCursor instance) throws IOException { protected ScrollCursor mutateInstance(ScrollCursor instance) throws IOException {
return new ScrollCursor(instance.scrollId(), instance.extractors(), return new ScrollCursor(instance.scrollId(), instance.extractors(),
randomValueOtherThan(instance.mask(), () -> CompositeAggregationCursorTests.randomBitSet(instance.extractors().size())),
randomValueOtherThan(instance.limit(), () -> randomIntBetween(1, 1024))); randomValueOtherThan(instance.limit(), () -> randomIntBetween(1, 1024)));
} }

View File

@ -85,7 +85,7 @@ public class SourceGeneratorTests extends ESTestCase {
public void testSortScoreSpecified() { public void testSortScoreSpecified() {
QueryContainer container = new QueryContainer() QueryContainer container = new QueryContainer()
.sort(new ScoreSort(Direction.DESC, null)); .addSort(new ScoreSort(Direction.DESC, null));
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10)); SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(scoreSort()), sourceBuilder.sorts()); assertEquals(singletonList(scoreSort()), sourceBuilder.sorts());
} }
@ -94,13 +94,13 @@ public class SourceGeneratorTests extends ESTestCase {
FieldSortBuilder sortField = fieldSort("test").unmappedType("keyword"); FieldSortBuilder sortField = fieldSort("test").unmappedType("keyword");
QueryContainer container = new QueryContainer() QueryContainer container = new QueryContainer()
.sort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.ASC, .addSort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.ASC,
Missing.LAST)); Missing.LAST));
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10)); SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(sortField.order(SortOrder.ASC).missing("_last")), sourceBuilder.sorts()); assertEquals(singletonList(sortField.order(SortOrder.ASC).missing("_last")), sourceBuilder.sorts());
container = new QueryContainer() container = new QueryContainer()
.sort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.DESC, .addSort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.DESC,
Missing.FIRST)); Missing.FIRST));
sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10)); sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(sortField.order(SortOrder.DESC).missing("_first")), sourceBuilder.sorts()); assertEquals(singletonList(sortField.order(SortOrder.DESC).missing("_first")), sourceBuilder.sorts());

View File

@ -306,7 +306,7 @@ public class SqlParserTests extends ESTestCase {
In in = (In) filter.condition(); In in = (In) filter.condition();
assertEquals("?a", in.value().toString()); assertEquals("?a", in.value().toString());
assertEquals(noChildren, in.list().size()); assertEquals(noChildren, in.list().size());
assertThat(in.list().get(0).toString(), startsWith("a + b#")); assertThat(in.list().get(0).toString(), startsWith("Add[?a,?b]"));
} }
public void testDecrementOfDepthCounter() { public void testDecrementOfDepthCounter() {

View File

@ -53,7 +53,7 @@ public class SysParserTests extends ESTestCase {
return Void.TYPE; return Void.TYPE;
}).when(resolver).resolveAsSeparateMappings(any(), any(), any()); }).when(resolver).resolveAsSeparateMappings(any(), any(), any());
SqlSession session = new SqlSession(TestUtils.TEST_CFG, null, null, resolver, null, null, null, null); SqlSession session = new SqlSession(TestUtils.TEST_CFG, null, null, resolver, null, null, null, null, null);
return new Tuple<>(cmd, session); return new Tuple<>(cmd, session);
} }

View File

@ -243,7 +243,7 @@ public class SysTablesTests extends ESTestCase {
IndexResolver resolver = mock(IndexResolver.class); IndexResolver resolver = mock(IndexResolver.class);
when(resolver.clusterName()).thenReturn(CLUSTER_NAME); when(resolver.clusterName()).thenReturn(CLUSTER_NAME);
SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null); SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null, null);
return new Tuple<>(cmd, session); return new Tuple<>(cmd, session);
} }

View File

@ -36,7 +36,7 @@ public class SysTypesTests extends ESTestCase {
Command cmd = (Command) analyzer.analyze(parser.createStatement(sql), false); Command cmd = (Command) analyzer.analyze(parser.createStatement(sql), false);
IndexResolver resolver = mock(IndexResolver.class); IndexResolver resolver = mock(IndexResolver.class);
SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null); SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null, null);
return new Tuple<>(cmd, session); return new Tuple<>(cmd, session);
} }

View File

@ -253,8 +253,8 @@ public class NodeSubclassTests<T extends B, B extends Node<B>> extends ESTestCas
* the one property of the node that we intended to transform. * the one property of the node that we intended to transform.
*/ */
assertEquals(node.source(), transformed.source()); assertEquals(node.source(), transformed.source());
List<Object> op = node.properties(); List<Object> op = node.nodeProperties();
List<Object> tp = transformed.properties(); List<Object> tp = transformed.nodeProperties();
for (int p = 0; p < op.size(); p++) { for (int p = 0; p < op.size(); p++) {
if (p == changedArgOffset - 1) { // -1 because location isn't in the list if (p == changedArgOffset - 1) { // -1 because location isn't in the list
assertEquals(changedArgValue, tp.get(p)); assertEquals(changedArgValue, tp.get(p));