SQL: Allow sorting of groups by aggregates (#38042)

Introduce client-side sorting of groups based on aggregate
functions. To allow this, the Analyzer has been extended to push down
to underlying Aggregate, aggregate function and the Querier has been
extended to identify the case and consume the results in order and sort
them based on the given columns.
The underlying QueryContainer has been slightly modified to allow a view
of the underlying values being extracted as the columns used for sorting
might not be requested by the user.

The PR also adds minor tweaks, mainly related to tree output.

Close #35118
This commit is contained in:
Costin Leau 2019-02-02 01:38:25 +02:00 committed by GitHub
parent 630889baec
commit 783c9ed372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
60 changed files with 1352 additions and 410 deletions

View File

@ -67,8 +67,18 @@ a field is an array (has multiple values) or not, so without reading all the dat
=== Sorting by aggregation
When doing aggregations (`GROUP BY`) {es-sql} relies on {es}'s `composite` aggregation for its support for paginating results.
But this type of aggregation does come with a limitation: sorting can only be applied on the key used for the aggregation's buckets. This
means that queries like `SELECT * FROM test GROUP BY age ORDER BY COUNT(*)` are not possible.
However this type of aggregation does come with a limitation: sorting can only be applied on the key used for the aggregation's buckets.
{es-sql} overcomes this limitation by doing client-side sorting however as a safety measure, allows only up to *512* rows.
It is recommended to use `LIMIT` for queries that use sorting by aggregation, essentially indicating the top N results that are desired:
[source, sql]
--------------------------------------------------
SELECT * FROM test GROUP BY age ORDER BY COUNT(*) LIMIT 100;
--------------------------------------------------
It is possible to run the same queries without a `LIMIT` however in that case if the maximum size (*512*) is passed, an exception will be
returned as {es-sql} is unable to track (and sort) all the results returned.
[float]
=== Using aggregation functions on top of scalar functions

View File

@ -20,7 +20,7 @@ public class CliExplainIT extends CliIntegrationTestCase {
assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("With[{}]"));
assertThat(readLine(), startsWith("\\_Project[[?*]]"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
assertEquals("", readLine());
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test"), containsString("plan"));
@ -64,22 +64,22 @@ public class CliExplainIT extends CliIntegrationTestCase {
assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("With[{}]"));
assertThat(readLine(), startsWith("\\_Project[[?*]]"));
assertThat(readLine(), startsWith(" \\_Filter[i = 2#"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]"));
assertThat(readLine(), startsWith(" \\_Filter[Equals[?i,2"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
assertEquals("", readLine());
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT * FROM test WHERE i = 2"),
containsString("plan"));
assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Project[[i{f}#"));
assertThat(readLine(), startsWith("\\_Filter[i = 2#"));
assertThat(readLine(), startsWith("\\_Filter[Equals[i"));
assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#"));
assertEquals("", readLine());
assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT * FROM test WHERE i = 2"), containsString("plan"));
assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Project[[i{f}#"));
assertThat(readLine(), startsWith("\\_Filter[i = 2#"));
assertThat(readLine(), startsWith("\\_Filter[Equals[i"));
assertThat(readLine(), startsWith(" \\_EsRelation[test][i{f}#"));
assertEquals("", readLine());
@ -123,20 +123,20 @@ public class CliExplainIT extends CliIntegrationTestCase {
assertThat(command("EXPLAIN (PLAN PARSED) SELECT COUNT(*) FROM test"), containsString("plan"));
assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("With[{}]"));
assertThat(readLine(), startsWith("\\_Project[[?COUNT(*)]]"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[[][index=test],null,Unknown index [test]]"));
assertThat(readLine(), startsWith("\\_Project[[?COUNT[?*]]]"));
assertThat(readLine(), startsWith(" \\_UnresolvedRelation[test]"));
assertEquals("", readLine());
assertThat(command("EXPLAIN " + (randomBoolean() ? "" : "(PLAN ANALYZED) ") + "SELECT COUNT(*) FROM test"),
containsString("plan"));
assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)#"));
assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1"));
assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#"));
assertEquals("", readLine());
assertThat(command("EXPLAIN (PLAN OPTIMIZED) SELECT COUNT(*) FROM test"), containsString("plan"));
assertThat(readLine(), startsWith("----------"));
assertThat(readLine(), startsWith("Aggregate[[],[COUNT(*)#"));
assertThat(readLine(), startsWith("Aggregate[[],[Count[*=1"));
assertThat(readLine(), startsWith("\\_EsRelation[test][i{f}#"));
assertEquals("", readLine());

View File

@ -73,7 +73,7 @@ public abstract class ErrorsTestCase extends CliIntegrationTestCase implements o
public void testSelectOrderByScoreInAggContext() throws Exception {
index("test", body -> body.field("foo", 1));
assertFoundOneProblem(command("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()"));
assertEquals("line 1:54: Cannot order by non-grouped column [SCORE()], expected [foo]" + END, readLine());
assertEquals("line 1:54: Cannot order by non-grouped column [SCORE()], expected [foo] or an aggregate function" + END, readLine());
}
@Override

View File

@ -81,7 +81,9 @@ public class ErrorsTestCase extends JdbcIntegrationTestCase implements org.elast
try (Connection c = esJdbc()) {
SQLException e = expectThrows(SQLException.class, () ->
c.prepareStatement("SELECT foo, COUNT(*) FROM test GROUP BY foo ORDER BY SCORE()").executeQuery());
assertEquals("Found 1 problem(s)\nline 1:54: Cannot order by non-grouped column [SCORE()], expected [foo]", e.getMessage());
assertEquals(
"Found 1 problem(s)\nline 1:54: Cannot order by non-grouped column [SCORE()], expected [foo] or an aggregate function",
e.getMessage());
}
}

View File

@ -38,6 +38,7 @@ public abstract class SqlSpecTestCase extends SpecBaseIntegrationTestCase {
tests.addAll(readScriptSpec("/datetime.sql-spec", parser));
tests.addAll(readScriptSpec("/math.sql-spec", parser));
tests.addAll(readScriptSpec("/agg.sql-spec", parser));
tests.addAll(readScriptSpec("/agg-ordering.sql-spec", parser));
tests.addAll(readScriptSpec("/arithmetic.sql-spec", parser));
tests.addAll(readScriptSpec("/string-functions.sql-spec", parser));
tests.addAll(readScriptSpec("/case-functions.sql-spec", parser));

View File

@ -0,0 +1,87 @@
//
// Custom sorting/ordering on aggregates
//
countWithImplicitGroupBy
SELECT MAX(salary) AS m FROM test_emp ORDER BY COUNT(*);
countWithImplicitGroupByWithHaving
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*);
countAndMaxWithImplicitGroupBy
SELECT MAX(salary) AS m FROM test_emp ORDER BY MAX(salary), COUNT(*);
maxWithAliasWithImplicitGroupBy
SELECT MAX(salary) AS m FROM test_emp ORDER BY m;
maxWithAliasWithImplicitGroupByAndHaving
SELECT MAX(salary) AS m FROM test_emp HAVING COUNT(*) > 1 ORDER BY m;
multipleOrderWithImplicitGroupByWithHaving
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*), m DESC;
multipleOrderWithImplicitGroupByWithoutAlias
SELECT MAX(salary) AS m FROM test_emp HAVING MIN(salary) > 1 ORDER BY COUNT(*), MIN(salary) DESC;
multipleOrderWithImplicitGroupByOfOrdinals
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp HAVING MIN(salary) > 1 ORDER BY 1, COUNT(*), 2 DESC;
aggWithoutAlias
SELECT MAX(salary) AS max FROM test_emp GROUP BY gender ORDER BY MAX(salary);
aggWithAlias
SELECT MAX(salary) AS m FROM test_emp GROUP BY gender ORDER BY m;
multipleAggsThatGetRewrittenWithoutAlias
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY MAX(salary);
multipleAggsThatGetRewrittenWithAliasDesc
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY 1 DESC;
multipleAggsThatGetRewrittenWithAlias
SELECT MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY gender ORDER BY max;
aggNotSpecifiedInTheAggregate
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender ORDER BY MAX(salary);
aggNotSpecifiedInTheAggregatePlusOrdinal
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender ORDER BY MAX(salary), 2 DESC;
aggNotSpecifiedInTheAggregateWithHaving
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary);
aggNotSpecifiedInTheAggregateWithHavingDesc
SELECT MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary) DESC;
aggNotSpecifiedInTheAggregateAndGroupWithHaving
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary), gender;
groupAndAggNotSpecifiedInTheAggregateWithHaving
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender, MAX(salary);
multipleAggsThatGetRewrittenWithAliasOnAMediumGroupBy
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages ORDER BY max;
multipleAggsThatGetRewrittenWithAliasOnALargeGroupBy
SELECT emp_no, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY emp_no ORDER BY max;
multipleAggsThatGetRewrittenWithAliasOnAMediumGroupByWithHaving
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages HAVING min BETWEEN 1000 AND 99999 ORDER BY max;
aggNotSpecifiedInTheAggregatemultipleAggsThatGetRewrittenWithAliasOnALargeGroupBy
SELECT emp_no, MIN(salary) AS min FROM test_emp GROUP BY emp_no ORDER BY MAX(salary);
aggNotSpecifiedWithHavingOnLargeGroupBy
SELECT MAX(salary) AS max FROM test_emp GROUP BY emp_no HAVING AVG(salary) > 1000 ORDER BY MIN(salary);
aggWithTieBreakerDescAsc
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MIN(languages) DESC NULLS FIRST, emp_no ASC;
aggWithTieBreakerDescDesc
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MIN(languages) DESC NULLS FIRST, emp_no DESC;
aggWithTieBreakerAscDesc
SELECT emp_no, MIN(languages) AS min FROM test_emp GROUP BY emp_no ORDER BY MAX(languages) ASC NULLS FIRST, emp_no DESC;
aggWithMixOfOrdinals
SELECT gender AS g, MAX(salary) AS m FROM test_emp GROUP BY gender ORDER BY 2 DESC LIMIT 3;

View File

@ -52,6 +52,8 @@ import org.elasticsearch.xpack.sql.type.DataTypeConversion;
import org.elasticsearch.xpack.sql.type.DataTypes;
import org.elasticsearch.xpack.sql.type.InvalidMappedField;
import org.elasticsearch.xpack.sql.type.UnsupportedEsField;
import org.elasticsearch.xpack.sql.util.CollectionUtils;
import org.elasticsearch.xpack.sql.util.Holder;
import java.util.ArrayList;
import java.util.Arrays;
@ -106,7 +108,8 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
new ResolveFunctions(),
new ResolveAliases(),
new ProjectedAggregations(),
new ResolveAggsInHaving()
new ResolveAggsInHaving(),
new ResolveAggsInOrderBy()
//new ImplicitCasting()
);
Batch finish = new Batch("Finish Analysis",
@ -926,7 +929,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
// Handle aggs in HAVING. To help folding any aggs not found in Aggregation
// will be pushed down to the Aggregate and then projected. This also simplifies the Verifier's job.
//
private class ResolveAggsInHaving extends AnalyzeRule<LogicalPlan> {
private class ResolveAggsInHaving extends AnalyzeRule<Filter> {
@Override
protected boolean skipResolved() {
@ -934,54 +937,49 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
}
@Override
protected LogicalPlan rule(LogicalPlan plan) {
protected LogicalPlan rule(Filter f) {
// HAVING = Filter followed by an Agg
if (plan instanceof Filter) {
Filter f = (Filter) plan;
if (f.child() instanceof Aggregate && f.child().resolved()) {
Aggregate agg = (Aggregate) f.child();
if (f.child() instanceof Aggregate && f.child().resolved()) {
Aggregate agg = (Aggregate) f.child();
Set<NamedExpression> missing = null;
Expression condition = f.condition();
Set<NamedExpression> missing = null;
Expression condition = f.condition();
// the condition might contain an agg (AVG(salary)) that could have been resolved
// (salary cannot be pushed down to Aggregate since there's no grouping and thus the function wasn't resolved either)
// the condition might contain an agg (AVG(salary)) that could have been resolved
// (salary cannot be pushed down to Aggregate since there's no grouping and thus the function wasn't resolved either)
// so try resolving the condition in one go through a 'dummy' aggregate
if (!condition.resolved()) {
// that's why try to resolve the condition
Aggregate tryResolvingCondition = new Aggregate(agg.source(), agg.child(), agg.groupings(),
combine(agg.aggregates(), new Alias(f.source(), ".having", condition)));
// so try resolving the condition in one go through a 'dummy' aggregate
if (!condition.resolved()) {
// that's why try to resolve the condition
Aggregate tryResolvingCondition = new Aggregate(agg.source(), agg.child(), agg.groupings(),
combine(agg.aggregates(), new Alias(f.source(), ".having", condition)));
tryResolvingCondition = (Aggregate) analyze(tryResolvingCondition, false);
tryResolvingCondition = (Aggregate) analyze(tryResolvingCondition, false);
// if it got resolved
if (tryResolvingCondition.resolved()) {
// replace the condition with the resolved one
condition = ((Alias) tryResolvingCondition.aggregates()
.get(tryResolvingCondition.aggregates().size() - 1)).child();
} else {
// else bail out
return plan;
}
// if it got resolved
if (tryResolvingCondition.resolved()) {
// replace the condition with the resolved one
condition = ((Alias) tryResolvingCondition.aggregates()
.get(tryResolvingCondition.aggregates().size() - 1)).child();
} else {
// else bail out
return f;
}
missing = findMissingAggregate(agg, condition);
if (!missing.isEmpty()) {
Aggregate newAgg = new Aggregate(agg.source(), agg.child(), agg.groupings(),
combine(agg.aggregates(), missing));
Filter newFilter = new Filter(f.source(), newAgg, condition);
// preserve old output
return new Project(f.source(), newFilter, f.output());
}
return new Filter(f.source(), f.child(), condition);
}
return plan;
}
return plan;
missing = findMissingAggregate(agg, condition);
if (!missing.isEmpty()) {
Aggregate newAgg = new Aggregate(agg.source(), agg.child(), agg.groupings(),
combine(agg.aggregates(), missing));
Filter newFilter = new Filter(f.source(), newAgg, condition);
// preserve old output
return new Project(f.source(), newFilter, f.output());
}
return new Filter(f.source(), f.child(), condition);
}
return f;
}
private Set<NamedExpression> findMissingAggregate(Aggregate target, Expression from) {
@ -1001,6 +999,66 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
}
}
//
// Handle aggs in ORDER BY. To help folding any aggs not found in Aggregation
// will be pushed down to the Aggregate and then projected. This also simplifies the Verifier's job.
// Similar to Having however using a different matching pattern since HAVING is always Filter with Agg,
// while an OrderBy can have multiple intermediate nodes (Filter,Project, etc...)
//
private static class ResolveAggsInOrderBy extends AnalyzeRule<OrderBy> {
@Override
protected boolean skipResolved() {
return false;
}
@Override
protected LogicalPlan rule(OrderBy ob) {
List<Order> orders = ob.order();
// 1. collect aggs inside an order by
List<NamedExpression> aggs = new ArrayList<>();
for (Order order : orders) {
if (Functions.isAggregate(order.child())) {
aggs.add(Expressions.wrapAsNamed(order.child()));
}
}
if (aggs.isEmpty()) {
return ob;
}
// 2. find first Aggregate child and update it
final Holder<Boolean> found = new Holder<>(Boolean.FALSE);
LogicalPlan plan = ob.transformDown(a -> {
if (found.get() == Boolean.FALSE) {
found.set(Boolean.TRUE);
List<NamedExpression> missing = new ArrayList<>();
for (NamedExpression orderedAgg : aggs) {
if (Expressions.anyMatch(a.aggregates(), e -> Expressions.equalsAsAttribute(e, orderedAgg)) == false) {
missing.add(orderedAgg);
}
}
// agg already contains all aggs
if (missing.isEmpty() == false) {
// save aggregates
return new Aggregate(a.source(), a.child(), a.groupings(), CollectionUtils.combine(a.aggregates(), missing));
}
}
return a;
}, Aggregate.class);
// if the plan was updated, project the initial aggregates
if (plan != ob) {
return new Project(ob.source(), plan, ob.output());
}
return ob;
}
}
private class PruneDuplicateFunctions extends AnalyzeRule<LogicalPlan> {
@Override

View File

@ -54,8 +54,8 @@ import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import static java.lang.String.format;
import static java.util.stream.Collectors.toMap;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.stats.FeatureMetric.COMMAND;
import static org.elasticsearch.xpack.sql.stats.FeatureMetric.GROUPBY;
import static org.elasticsearch.xpack.sql.stats.FeatureMetric.HAVING;
@ -70,7 +70,7 @@ import static org.elasticsearch.xpack.sql.stats.FeatureMetric.WHERE;
*/
public final class Verifier {
private final Metrics metrics;
public Verifier(Metrics metrics) {
this.metrics = metrics;
}
@ -118,7 +118,7 @@ public final class Verifier {
}
private static Failure fail(Node<?> source, String message, Object... args) {
return new Failure(source, format(Locale.ROOT, message, args));
return new Failure(source, format(message, args));
}
public Map<Node<?>, String> verifyFailures(LogicalPlan plan) {
@ -314,11 +314,12 @@ public final class Verifier {
Aggregate a = (Aggregate) child;
Map<Expression, Node<?>> missing = new LinkedHashMap<>();
o.order().forEach(oe -> {
Expression e = oe.child();
// cannot order by aggregates (not supported by composite)
if (Functions.isAggregate(e)) {
missing.put(e, oe);
// aggregates are allowed
if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) {
return;
}
@ -352,7 +353,8 @@ public final class Verifier {
String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY;
// get the location of the first missing expression as the order by might be on a different line
localFailures.add(
fail(missing.values().iterator().next(), "Cannot order by non-grouped column" + plural + " %s, expected %s",
fail(missing.values().iterator().next(),
"Cannot order by non-grouped column" + plural + " {}, expected {} or an aggregate function",
Expressions.names(missing.keySet()),
Expressions.names(a.groupings())));
groupingFailures.add(a);
@ -379,7 +381,7 @@ public final class Verifier {
if (!missing.isEmpty()) {
String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY;
localFailures.add(
fail(condition, "Cannot use HAVING filter on non-aggregate" + plural + " %s; use WHERE instead",
fail(condition, "Cannot use HAVING filter on non-aggregate" + plural + " {}; use WHERE instead",
Expressions.names(missing)));
groupingFailures.add(a);
return false;
@ -388,13 +390,13 @@ public final class Verifier {
if (!unsupported.isEmpty()) {
String plural = unsupported.size() > 1 ? "s" : StringUtils.EMPTY;
localFailures.add(
fail(condition, "HAVING filter is unsupported for function" + plural + " %s",
fail(condition, "HAVING filter is unsupported for function" + plural + " {}",
Expressions.names(unsupported)));
groupingFailures.add(a);
return false;
}
}
}
}
return true;
}
@ -438,7 +440,7 @@ public final class Verifier {
// Min & Max on a Keyword field will be translated to First & Last respectively
unsupported.add(e);
return true;
}
}
}
// skip literals / foldable
@ -480,7 +482,7 @@ public final class Verifier {
e.collectFirstChildren(c -> {
if (Functions.isGrouping(c)) {
localFailures.add(fail(c,
"Cannot combine [%s] grouping function inside GROUP BY, found [%s];"
"Cannot combine [{}] grouping function inside GROUP BY, found [{}];"
+ " consider moving the expression inside the histogram",
Expressions.name(c), Expressions.name(e)));
return true;
@ -509,7 +511,7 @@ public final class Verifier {
if (!missing.isEmpty()) {
String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY;
localFailures.add(fail(missing.values().iterator().next(), "Cannot use non-grouped column" + plural + " %s, expected %s",
localFailures.add(fail(missing.values().iterator().next(), "Cannot use non-grouped column" + plural + " {}, expected {}",
Expressions.names(missing.keySet()),
Expressions.names(a.groupings())));
return false;
@ -592,7 +594,7 @@ public final class Verifier {
filter.condition().forEachDown(e -> {
if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) {
localFailures.add(
fail(e, "Cannot use WHERE filtering on aggregate function [%s], use HAVING instead", Expressions.name(e)));
fail(e, "Cannot use WHERE filtering on aggregate function [{}], use HAVING instead", Expressions.name(e)));
}
}, Expression.class);
}
@ -606,7 +608,7 @@ public final class Verifier {
filter.condition().forEachDown(e -> {
if (Functions.isGrouping(e) || e instanceof GroupingFunctionAttribute) {
localFailures
.add(fail(e, "Cannot filter on grouping function [%s], use its argument instead", Expressions.name(e)));
.add(fail(e, "Cannot filter on grouping function [{}], use its argument instead", Expressions.name(e)));
}
}, Expression.class);
}
@ -659,7 +661,7 @@ public final class Verifier {
DataType dt = in.value().dataType();
for (Expression value : in.list()) {
if (areTypesCompatible(dt, value.dataType()) == false) {
localFailures.add(fail(value, "expected data type [%s], value provided is of type [%s]",
localFailures.add(fail(value, "expected data type [{}], value provided is of type [{}]",
dt.esType, value.dataType().esType));
return;
}
@ -680,7 +682,7 @@ public final class Verifier {
}
} else {
if (areTypesCompatible(dt, child.dataType()) == false) {
localFailures.add(fail(child, "expected data type [%s], value provided is of type [%s]",
localFailures.add(fail(child, "expected data type [{}], value provided is of type [{}]",
dt.esType, child.dataType().esType));
return;
}

View File

@ -60,7 +60,7 @@ public class PlanExecutor {
}
private SqlSession newSession(Configuration cfg) {
return new SqlSession(cfg, client, functionRegistry, indexResolver, preAnalyzer, verifier, optimizer, planner);
return new SqlSession(cfg, client, functionRegistry, indexResolver, preAnalyzer, verifier, optimizer, planner, this);
}
public void searchSource(Configuration cfg, String sql, List<SqlTypedParamValue> params, ActionListener<SearchSourceBuilder> listener) {
@ -68,15 +68,20 @@ public class PlanExecutor {
if (exec instanceof EsQueryExec) {
EsQueryExec e = (EsQueryExec) exec;
listener.onResponse(SourceGenerator.sourceBuilder(e.queryContainer(), cfg.filter(), cfg.pageSize()));
} else if (exec instanceof LocalExec) {
listener.onFailure(new PlanningException("Cannot generate a query DSL for an SQL query that either " +
"its WHERE clause evaluates to FALSE or doesn't operate on a table (missing a FROM clause), sql statement: [{}]",
sql));
} else if (exec instanceof CommandExec) {
listener.onFailure(new PlanningException("Cannot generate a query DSL for a special SQL command " +
"(e.g.: DESCRIBE, SHOW), sql statement: [{}]", sql));
} else {
listener.onFailure(new PlanningException("Cannot generate a query DSL, sql statement: [{}]", sql));
}
// try to provide a better resolution of what failed
else {
String message = null;
if (exec instanceof LocalExec) {
message = "Cannot generate a query DSL for an SQL query that either " +
"its WHERE clause evaluates to FALSE or doesn't operate on a table (missing a FROM clause)";
} else if (exec instanceof CommandExec) {
message = "Cannot generate a query DSL for a special SQL command " +
"(e.g.: DESCRIBE, SHOW)";
} else {
message = "Cannot generate a query DSL";
}
listener.onFailure(new PlanningException(message + ", sql statement: [{}]", sql));
}
}, listener::onFailure));
}

View File

@ -32,6 +32,7 @@ import org.elasticsearch.xpack.sql.util.StringUtils;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -49,12 +50,14 @@ public class CompositeAggregationCursor implements Cursor {
private final String[] indices;
private final byte[] nextQuery;
private final List<BucketExtractor> extractors;
private final BitSet mask;
private final int limit;
CompositeAggregationCursor(byte[] next, List<BucketExtractor> exts, int remainingLimit, String... indices) {
CompositeAggregationCursor(byte[] next, List<BucketExtractor> exts, BitSet mask, int remainingLimit, String... indices) {
this.indices = indices;
this.nextQuery = next;
this.extractors = exts;
this.mask = mask;
this.limit = remainingLimit;
}
@ -64,6 +67,7 @@ public class CompositeAggregationCursor implements Cursor {
limit = in.readVInt();
extractors = in.readNamedWriteableList(BucketExtractor.class);
mask = BitSet.valueOf(in.readByteArray());
}
@Override
@ -73,6 +77,7 @@ public class CompositeAggregationCursor implements Cursor {
out.writeVInt(limit);
out.writeNamedWriteableList(extractors);
out.writeByteArray(mask.toByteArray());
}
@Override
@ -88,6 +93,10 @@ public class CompositeAggregationCursor implements Cursor {
return nextQuery;
}
BitSet mask() {
return mask;
}
List<BucketExtractor> extractors() {
return extractors;
}
@ -125,7 +134,7 @@ public class CompositeAggregationCursor implements Cursor {
}
updateCompositeAfterKey(r, query);
CompositeAggsRowSet rowSet = new CompositeAggsRowSet(extractors, r, limit, serializeQuery(query), indices);
CompositeAggsRowSet rowSet = new CompositeAggsRowSet(extractors, mask, r, limit, serializeQuery(query), indices);
listener.onResponse(rowSet);
} catch (Exception ex) {
listener.onFailure(ex);

View File

@ -8,10 +8,10 @@ package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractor;
import org.elasticsearch.xpack.sql.session.AbstractRowSet;
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet;
import java.util.BitSet;
import java.util.List;
import static java.util.Collections.emptyList;
@ -19,8 +19,7 @@ import static java.util.Collections.emptyList;
/**
* {@link RowSet} specific to (GROUP BY) aggregation.
*/
class CompositeAggsRowSet extends AbstractRowSet {
private final List<BucketExtractor> exts;
class CompositeAggsRowSet extends ResultRowSet<BucketExtractor> {
private final List<? extends CompositeAggregation.Bucket> buckets;
@ -29,8 +28,8 @@ class CompositeAggsRowSet extends AbstractRowSet {
private final int size;
private int row = 0;
CompositeAggsRowSet(List<BucketExtractor> exts, SearchResponse response, int limit, byte[] next, String... indices) {
this.exts = exts;
CompositeAggsRowSet(List<BucketExtractor> exts, BitSet mask, SearchResponse response, int limit, byte[] next, String... indices) {
super(exts, mask);
CompositeAggregation composite = CompositeAggregationCursor.getComposite(response);
if (composite != null) {
@ -54,19 +53,14 @@ class CompositeAggsRowSet extends AbstractRowSet {
if (next == null || size == 0 || remainingLimit == 0) {
cursor = Cursor.EMPTY;
} else {
cursor = new CompositeAggregationCursor(next, exts, remainingLimit, indices);
cursor = new CompositeAggregationCursor(next, exts, mask, remainingLimit, indices);
}
}
}
@Override
protected Object getColumn(int column) {
return exts.get(column).extract(buckets.get(row));
}
@Override
public int columnCount() {
return exts.size();
protected Object extractValue(BucketExtractor e) {
return e.extract(buckets.get(row));
}
@Override

View File

@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.sql.session.Configuration;
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static java.util.Collections.emptyList;
public class PagingListCursor implements Cursor {
public static final String NAME = "p";
private final List<List<?>> data;
private final int columnCount;
private final int pageSize;
PagingListCursor(List<List<?>> data, int columnCount, int pageSize) {
this.data = data;
this.columnCount = columnCount;
this.pageSize = pageSize;
}
@SuppressWarnings("unchecked")
public PagingListCursor(StreamInput in) throws IOException {
data = (List<List<?>>) in.readGenericValue();
columnCount = in.readVInt();
pageSize = in.readVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeGenericValue(data);
out.writeVInt(columnCount);
out.writeVInt(pageSize);
}
@Override
public String getWriteableName() {
return NAME;
}
List<List<?>> data() {
return data;
}
int columnCount() {
return columnCount;
}
int pageSize() {
return pageSize;
}
@Override
public void nextPage(Configuration cfg, Client client, NamedWriteableRegistry registry, ActionListener<RowSet> listener) {
// the check is really a safety measure since the page initialization handles it already (by returning an empty cursor)
List<List<?>> nextData = data.size() > pageSize ? data.subList(pageSize, data.size()) : emptyList();
listener.onResponse(new PagingListRowSet(nextData, columnCount, pageSize));
}
@Override
public void clear(Configuration cfg, Client client, ActionListener<Boolean> listener) {
listener.onResponse(true);
}
@Override
public int hashCode() {
return Objects.hash(data, columnCount, pageSize);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PagingListCursor other = (PagingListCursor) obj;
return Objects.equals(pageSize, other.pageSize)
&& Objects.equals(columnCount, other.columnCount)
&& Objects.equals(data, other.data);
}
@Override
public String toString() {
return "cursor for paging list";
}
}

View File

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.ListRowSet;
import org.elasticsearch.xpack.sql.type.Schema;
import java.util.List;
class PagingListRowSet extends ListRowSet {
private final int pageSize;
private final int columnCount;
private final Cursor cursor;
PagingListRowSet(List<List<?>> list, int columnCount, int pageSize) {
this(Schema.EMPTY, list, columnCount, pageSize);
}
PagingListRowSet(Schema schema, List<List<?>> list, int columnCount, int pageSize) {
super(schema, list);
this.columnCount = columnCount;
this.pageSize = Math.min(pageSize, list.size());
this.cursor = list.size() > pageSize ? new PagingListCursor(list, columnCount, pageSize) : Cursor.EMPTY;
}
@Override
public int size() {
return pageSize;
}
@Override
public int columnCount() {
return columnCount;
}
@Override
public Cursor nextPageCursor() {
return cursor;
}
}

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.sql.execution.search;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.PriorityQueue;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
@ -14,6 +15,7 @@ import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -25,6 +27,7 @@ import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Buck
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.PlanExecutor;
import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.CompositeKeyExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.ComputingExtractor;
@ -33,11 +36,14 @@ import org.elasticsearch.xpack.sql.execution.search.extractor.FieldHitExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.MetricAggExtractor;
import org.elasticsearch.xpack.sql.execution.search.extractor.TopHitsAggExtractor;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.ExpressionId;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggExtractorInput;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.AggPathInput;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.HitExtractorInput;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.ReferenceInput;
import org.elasticsearch.xpack.sql.planner.PlanningException;
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
import org.elasticsearch.xpack.sql.querydsl.container.ComputedRef;
import org.elasticsearch.xpack.sql.querydsl.container.GlobalCountRef;
@ -48,16 +54,23 @@ import org.elasticsearch.xpack.sql.querydsl.container.ScriptFieldRef;
import org.elasticsearch.xpack.sql.querydsl.container.SearchHitFieldRef;
import org.elasticsearch.xpack.sql.querydsl.container.TopHitsAggRef;
import org.elasticsearch.xpack.sql.session.Configuration;
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet;
import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.session.SqlSession;
import org.elasticsearch.xpack.sql.type.Schema;
import org.elasticsearch.xpack.sql.util.StringUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import static java.util.Collections.singletonList;
// TODO: add retry/back-off
@ -65,25 +78,25 @@ public class Querier {
private final Logger log = LogManager.getLogger(getClass());
private final PlanExecutor planExecutor;
private final Configuration cfg;
private final TimeValue keepAlive, timeout;
private final int size;
private final Client client;
@Nullable
private final QueryBuilder filter;
public Querier(Client client, Configuration cfg) {
this(client, cfg.requestTimeout(), cfg.pageTimeout(), cfg.filter(), cfg.pageSize());
public Querier(SqlSession sqlSession) {
this.planExecutor = sqlSession.planExecutor();
this.client = sqlSession.client();
this.cfg = sqlSession.configuration();
this.keepAlive = cfg.requestTimeout();
this.timeout = cfg.pageTimeout();
this.filter = cfg.filter();
this.size = cfg.pageSize();
}
public Querier(Client client, TimeValue keepAlive, TimeValue timeout, QueryBuilder filter, int size) {
this.client = client;
this.keepAlive = keepAlive;
this.timeout = timeout;
this.filter = filter;
this.size = size;
}
public void query(Schema schema, QueryContainer query, String index, ActionListener<SchemaRowSet> listener) {
public void query(List<Attribute> output, QueryContainer query, String index, ActionListener<SchemaRowSet> listener) {
// prepare the request
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(query, filter, size);
// set query timeout
@ -97,16 +110,21 @@ public class Querier {
SearchRequest search = prepareRequest(client, sourceBuilder, timeout, Strings.commaDelimitedListToStringArray(index));
ActionListener<SearchResponse> l;
@SuppressWarnings("rawtypes")
List<Tuple<Integer, Comparator>> sortingColumns = query.sortingColumns();
listener = sortingColumns.isEmpty() ? listener : new LocalAggregationSorterListener(listener, sortingColumns, query.limit());
ActionListener<SearchResponse> l = null;
if (query.isAggsOnly()) {
if (query.aggs().useImplicitGroupBy()) {
l = new ImplicitGroupActionListener(listener, client, timeout, schema, query, search);
l = new ImplicitGroupActionListener(listener, client, timeout, output, query, search);
} else {
l = new CompositeActionListener(listener, client, timeout, schema, query, search);
l = new CompositeActionListener(listener, client, timeout, output, query, search);
}
} else {
search.scroll(keepAlive);
l = new ScrollActionListener(listener, client, timeout, schema, query);
l = new ScrollActionListener(listener, client, timeout, output, query);
}
client.search(search, l);
@ -114,13 +132,148 @@ public class Querier {
public static SearchRequest prepareRequest(Client client, SearchSourceBuilder source, TimeValue timeout, String... indices) {
SearchRequest search = client.prepareSearch(indices)
// always track total hits accurately
.setTrackTotalHits(true)
.setAllowPartialSearchResults(false)
.setSource(source)
.setTimeout(timeout)
.request();
return search;
// always track total hits accurately
.setTrackTotalHits(true)
.setAllowPartialSearchResults(false)
.setSource(source)
.setTimeout(timeout)
.request();
return search;
}
/**
* Listener used for local sorting (typically due to aggregations used inside `ORDER BY`).
*
* This listener consumes the whole result set, sorts it in memory then sends the paginated
* results back to the client.
*/
@SuppressWarnings("rawtypes")
class LocalAggregationSorterListener implements ActionListener<SchemaRowSet> {
private final ActionListener<SchemaRowSet> listener;
// keep the top N entries.
private final PriorityQueue<Tuple<List<?>, Integer>> data;
private final AtomicInteger counter = new AtomicInteger();
private volatile Schema schema;
private static final int MAXIMUM_SIZE = 512;
private final boolean noLimit;
LocalAggregationSorterListener(ActionListener<SchemaRowSet> listener, List<Tuple<Integer, Comparator>> sortingColumns, int limit) {
this.listener = listener;
int size = MAXIMUM_SIZE;
if (limit < 0) {
noLimit = true;
} else {
noLimit = false;
if (limit > MAXIMUM_SIZE) {
throw new PlanningException("The maximum LIMIT for aggregate sorting is [{}], received [{}]", limit, MAXIMUM_SIZE);
} else {
size = limit;
}
}
this.data = new PriorityQueue<Tuple<List<?>, Integer>>(size) {
// compare row based on the received attribute sort
// if a sort item is not in the list, it is assumed the sorting happened in ES
// and the results are left as is (by using the row ordering), otherwise it is sorted based on the given criteria.
//
// Take for example ORDER BY a, x, b, y
// a, b - are sorted in ES
// x, y - need to be sorted client-side
// sorting on x kicks in, only if the values for a are equal.
// thanks to @jpountz for the row ordering idea as a way to preserve ordering
@SuppressWarnings("unchecked")
@Override
protected boolean lessThan(Tuple<List<?>, Integer> l, Tuple<List<?>, Integer> r) {
for (Tuple<Integer, Comparator> tuple : sortingColumns) {
int i = tuple.v1().intValue();
Comparator comparator = tuple.v2();
Object vl = l.v1().get(i);
Object vr = r.v1().get(i);
if (comparator != null) {
int result = comparator.compare(vl, vr);
// if things are equals, move to the next comparator
if (result != 0) {
return result < 0;
}
}
// no comparator means the existing order needs to be preserved
else {
// check the values - if they are equal move to the next comparator
// otherwise return the row order
if (Objects.equals(vl, vr) == false) {
return l.v2().compareTo(r.v2()) < 0;
}
}
}
// everything is equal, fall-back to the row order
return l.v2().compareTo(r.v2()) < 0;
}
};
}
@Override
public void onResponse(SchemaRowSet schemaRowSet) {
schema = schemaRowSet.schema();
doResponse(schemaRowSet);
}
private void doResponse(RowSet rowSet) {
// 1. consume all pages received
if (consumeRowSet(rowSet) == false) {
return;
}
Cursor cursor = rowSet.nextPageCursor();
// 1a. trigger a next call if there's still data
if (cursor != Cursor.EMPTY) {
// trigger a next call
planExecutor.nextPage(cfg, cursor, ActionListener.wrap(this::doResponse, this::onFailure));
// make sure to bail out afterwards as we'll get called by a different thread
return;
}
// no more data available, the last thread sends the response
// 2. send the in-memory view to the client
sendResponse();
}
private boolean consumeRowSet(RowSet rowSet) {
// use a synchronized block for visibility purposes (there's no concurrency)
ResultRowSet<?> rrs = (ResultRowSet<?>) rowSet;
synchronized (data) {
for (boolean hasRows = rrs.hasCurrentRow(); hasRows; hasRows = rrs.advanceRow()) {
List<Object> row = new ArrayList<>(rrs.columnCount());
rrs.forEachResultColumn(row::add);
// if the queue overflows and no limit was specified, bail out
if (data.insertWithOverflow(new Tuple<>(row, counter.getAndIncrement())) != null && noLimit) {
onFailure(new SqlIllegalArgumentException(
"The default limit [{}] for aggregate sorting has been reached; please specify a LIMIT"));
return false;
}
}
}
return true;
}
private void sendResponse() {
List<List<?>> list = new ArrayList<>(data.size());
Tuple<List<?>, Integer> pop = null;
while ((pop = data.pop()) != null) {
list.add(pop.v1());
}
listener.onResponse(new PagingListRowSet(schema, list, schema.size(), cfg.pageSize()));
}
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
}
/**
@ -156,9 +309,9 @@ public class Querier {
}
});
ImplicitGroupActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, Schema schema,
ImplicitGroupActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, List<Attribute> output,
QueryContainer query, SearchRequest request) {
super(listener, client, keepAlive, schema, query, request);
super(listener, client, keepAlive, output, query, request);
}
@Override
@ -182,9 +335,12 @@ public class Querier {
if (buckets.size() == 1) {
Bucket implicitGroup = buckets.get(0);
List<BucketExtractor> extractors = initBucketExtractors(response);
Object[] values = new Object[extractors.size()];
for (int i = 0; i < values.length; i++) {
values[i] = extractors.get(i).extract(implicitGroup);
Object[] values = new Object[mask.cardinality()];
int index = 0;
for (int i = mask.nextSetBit(0); i >= 0; i = mask.nextSetBit(i + 1)) {
values[index++] = extractors.get(i).extract(implicitGroup);
}
listener.onResponse(Rows.singleton(schema, values));
@ -205,8 +361,8 @@ public class Querier {
static class CompositeActionListener extends BaseAggActionListener {
CompositeActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive,
Schema schema, QueryContainer query, SearchRequest request) {
super(listener, client, keepAlive, schema, query, request);
List<Attribute> output, QueryContainer query, SearchRequest request) {
super(listener, client, keepAlive, output, query, request);
}
@ -232,7 +388,7 @@ public class Querier {
}
listener.onResponse(
new SchemaCompositeAggsRowSet(schema, initBucketExtractors(response), response, query.limit(),
new SchemaCompositeAggsRowSet(schema, initBucketExtractors(response), mask, response, query.limit(),
nextSearch,
request.indices()));
}
@ -246,23 +402,25 @@ public class Querier {
abstract static class BaseAggActionListener extends BaseActionListener {
final QueryContainer query;
final SearchRequest request;
final BitSet mask;
BaseAggActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, Schema schema,
BaseAggActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, List<Attribute> output,
QueryContainer query, SearchRequest request) {
super(listener, client, keepAlive, schema);
super(listener, client, keepAlive, output);
this.query = query;
this.request = request;
this.mask = query.columnMask(output);
}
protected List<BucketExtractor> initBucketExtractors(SearchResponse response) {
// create response extractors for the first time
List<FieldExtraction> refs = query.columns();
List<Tuple<FieldExtraction, ExpressionId>> refs = query.fields();
List<BucketExtractor> exts = new ArrayList<>(refs.size());
ConstantExtractor totalCount = new ConstantExtractor(response.getHits().getTotalHits().value);
for (FieldExtraction ref : refs) {
exts.add(createExtractor(ref, totalCount));
for (Tuple<FieldExtraction, ExpressionId> ref : refs) {
exts.add(createExtractor(ref.v1(), totalCount));
}
return exts;
}
@ -308,11 +466,13 @@ public class Querier {
*/
static class ScrollActionListener extends BaseActionListener {
private final QueryContainer query;
private final BitSet mask;
ScrollActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive,
Schema schema, QueryContainer query) {
super(listener, client, keepAlive, schema);
List<Attribute> output, QueryContainer query) {
super(listener, client, keepAlive, output);
this.query = query;
this.mask = query.columnMask(output);
}
@Override
@ -320,27 +480,27 @@ public class Querier {
SearchHit[] hits = response.getHits().getHits();
// create response extractors for the first time
List<FieldExtraction> refs = query.columns();
List<Tuple<FieldExtraction, ExpressionId>> refs = query.fields();
List<HitExtractor> exts = new ArrayList<>(refs.size());
for (FieldExtraction ref : refs) {
exts.add(createExtractor(ref));
for (Tuple<FieldExtraction, ExpressionId> ref : refs) {
exts.add(createExtractor(ref.v1()));
}
// there are some results
if (hits.length > 0) {
String scrollId = response.getScrollId();
SchemaSearchHitRowSet hitRowSet = new SchemaSearchHitRowSet(schema, exts, hits, query.limit(), scrollId);
SchemaSearchHitRowSet hitRowSet = new SchemaSearchHitRowSet(schema, exts, mask, hits, query.limit(), scrollId);
// if there's an id, try to setup next scroll
if (scrollId != null &&
// is all the content already retrieved?
(Boolean.TRUE.equals(response.isTerminatedEarly())
(Boolean.TRUE.equals(response.isTerminatedEarly())
|| response.getHits().getTotalHits().value == hits.length
|| hitRowSet.isLimitReached())) {
// if so, clear the scroll
clear(response.getScrollId(), ActionListener.wrap(
succeeded -> listener.onResponse(new SchemaSearchHitRowSet(schema, exts, hits, query.limit(), null)),
succeeded -> listener.onResponse(new SchemaSearchHitRowSet(schema, exts, mask, hits, query.limit(), null)),
listener::onFailure));
} else {
listener.onResponse(hitRowSet);
@ -401,12 +561,12 @@ public class Querier {
final TimeValue keepAlive;
final Schema schema;
BaseActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, Schema schema) {
BaseActionListener(ActionListener<SchemaRowSet> listener, Client client, TimeValue keepAlive, List<Attribute> output) {
this.listener = listener;
this.client = client;
this.keepAlive = keepAlive;
this.schema = schema;
this.schema = Rows.schema(output);
}
// TODO: need to handle rejections plus check failures (shard size, etc...)

View File

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.session.AbstractRowSet;
import org.elasticsearch.xpack.sql.util.Check;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
abstract class ResultRowSet<E extends NamedWriteable> extends AbstractRowSet {
private final List<E> extractors;
private final BitSet mask;
ResultRowSet(List<E> extractors, BitSet mask) {
this.extractors = extractors;
this.mask = mask;
Check.isTrue(mask.length() <= extractors.size(), "Invalid number of extracted columns specified");
}
@Override
public final int columnCount() {
return mask.cardinality();
}
@Override
protected Object getColumn(int column) {
return extractValue(userExtractor(column));
}
List<E> extractors() {
return extractors;
}
BitSet mask() {
return mask;
}
E userExtractor(int column) {
int i = -1;
// find the nth set bit
for (i = mask.nextSetBit(0); i >= 0; i = mask.nextSetBit(i + 1)) {
if (column-- == 0) {
return extractors.get(i);
}
}
throw new SqlIllegalArgumentException("Cannot find column [{}]", column);
}
Object resultColumn(int column) {
return extractValue(extractors().get(column));
}
int resultColumnCount() {
return extractors.size();
}
void forEachResultColumn(Consumer<? super Object> action) {
Objects.requireNonNull(action);
int rowSize = resultColumnCount();
for (int i = 0; i < rowSize; i++) {
action.accept(resultColumn(i));
}
}
protected abstract Object extractValue(E e);
}

View File

@ -11,6 +11,7 @@ import org.elasticsearch.xpack.sql.session.RowSet;
import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.type.Schema;
import java.util.BitSet;
import java.util.List;
/**
@ -21,9 +22,10 @@ class SchemaCompositeAggsRowSet extends CompositeAggsRowSet implements SchemaRow
private final Schema schema;
SchemaCompositeAggsRowSet(Schema schema, List<BucketExtractor> exts, SearchResponse response, int limitAggs, byte[] next,
SchemaCompositeAggsRowSet(Schema schema, List<BucketExtractor> exts, BitSet mask, SearchResponse response, int limitAggs,
byte[] next,
String... indices) {
super(exts, response, limitAggs, next, indices);
super(exts, mask, response, limitAggs, next, indices);
this.schema = schema;
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor;
import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.type.Schema;
import java.util.BitSet;
import java.util.List;
/**
@ -20,8 +21,8 @@ import java.util.List;
class SchemaSearchHitRowSet extends SearchHitRowSet implements SchemaRowSet {
private final Schema schema;
SchemaSearchHitRowSet(Schema schema, List<HitExtractor> exts, SearchHit[] hits, int limitHits, String scrollId) {
super(exts, hits, limitHits, scrollId);
SchemaSearchHitRowSet(Schema schema, List<HitExtractor> exts, BitSet mask, SearchHit[] hits, int limitHits, String scrollId) {
super(exts, mask, hits, limitHits, scrollId);
this.schema = schema;
}

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.RowSet;
import java.io.IOException;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
@ -34,11 +35,13 @@ public class ScrollCursor implements Cursor {
private final String scrollId;
private final List<HitExtractor> extractors;
private final BitSet mask;
private final int limit;
public ScrollCursor(String scrollId, List<HitExtractor> extractors, int limit) {
public ScrollCursor(String scrollId, List<HitExtractor> extractors, BitSet mask, int limit) {
this.scrollId = scrollId;
this.extractors = extractors;
this.mask = mask;
this.limit = limit;
}
@ -47,6 +50,7 @@ public class ScrollCursor implements Cursor {
limit = in.readVInt();
extractors = in.readNamedWriteableList(HitExtractor.class);
mask = BitSet.valueOf(in.readByteArray());
}
@Override
@ -55,6 +59,7 @@ public class ScrollCursor implements Cursor {
out.writeVInt(limit);
out.writeNamedWriteableList(extractors);
out.writeByteArray(mask.toByteArray());
}
@Override
@ -66,6 +71,10 @@ public class ScrollCursor implements Cursor {
return scrollId;
}
BitSet mask() {
return mask;
}
List<HitExtractor> extractors() {
return extractors;
}
@ -79,7 +88,7 @@ public class ScrollCursor implements Cursor {
SearchScrollRequest request = new SearchScrollRequest(scrollId).scroll(cfg.pageTimeout());
client.searchScroll(request, ActionListener.wrap((SearchResponse response) -> {
SearchHitRowSet rowSet = new SearchHitRowSet(extractors, response.getHits().getHits(),
SearchHitRowSet rowSet = new SearchHitRowSet(extractors, mask, response.getHits().getHits(),
limit, response.getScrollId());
if (rowSet.nextPageCursor() == Cursor.EMPTY ) {
// we are finished with this cursor, let's clean it before continuing

View File

@ -9,10 +9,10 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractor;
import org.elasticsearch.xpack.sql.session.AbstractRowSet;
import org.elasticsearch.xpack.sql.session.Cursor;
import java.util.Arrays;
import java.util.BitSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
@ -20,10 +20,9 @@ import java.util.Set;
/**
* Extracts rows from an array of {@link SearchHit}.
*/
class SearchHitRowSet extends AbstractRowSet {
class SearchHitRowSet extends ResultRowSet<HitExtractor> {
private final SearchHit[] hits;
private final Cursor cursor;
private final List<HitExtractor> extractors;
private final Set<String> innerHits = new LinkedHashSet<>();
private final String innerHit;
@ -31,10 +30,10 @@ class SearchHitRowSet extends AbstractRowSet {
private final int[] indexPerLevel;
private int row = 0;
SearchHitRowSet(List<HitExtractor> exts, SearchHit[] hits, int limit, String scrollId) {
SearchHitRowSet(List<HitExtractor> exts, BitSet mask, SearchHit[] hits, int limit, String scrollId) {
super(exts, mask);
this.hits = hits;
this.extractors = exts;
// Since the results might contain nested docs, the iteration is similar to that of Aggregation
// namely it discovers the nested docs and then, for iteration, increments the deepest level first
@ -85,7 +84,7 @@ class SearchHitRowSet extends AbstractRowSet {
if (size == 0 || remainingLimit == 0) {
cursor = Cursor.EMPTY;
} else {
cursor = new ScrollCursor(scrollId, extractors, remainingLimit);
cursor = new ScrollCursor(scrollId, extractors(), mask, remainingLimit);
}
}
}
@ -95,13 +94,7 @@ class SearchHitRowSet extends AbstractRowSet {
}
@Override
public int columnCount() {
return extractors.size();
}
@Override
protected Object getColumn(int column) {
HitExtractor e = extractors.get(column);
protected Object extractValue(HitExtractor e) {
int extractorLevel = e.hitName() == null ? 0 : 1;
SearchHit hit = null;

View File

@ -58,7 +58,7 @@ public abstract class SourceGenerator {
// need to be retrieved from the result documents
// NB: the sortBuilder takes care of eliminating duplicates
container.columns().forEach(cr -> cr.collectFields(sortBuilder));
container.fields().forEach(f -> f.v1().collectFields(sortBuilder));
sortBuilder.build(source);
optimize(sortBuilder, source);

View File

@ -11,7 +11,6 @@ import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.Stream;
@ -21,7 +20,7 @@ import static java.util.Collections.singletonMap;
import static java.util.Collections.unmodifiableCollection;
import static java.util.Collections.unmodifiableSet;
public class AttributeMap<E> {
public class AttributeMap<E> implements Map<Attribute, E> {
static class AttributeWrapper {
@ -120,8 +119,9 @@ public class AttributeMap<E> {
@SuppressWarnings("unchecked")
public <A> A[] toArray(A[] a) {
// collection is immutable so use that to our advantage
if (a.length < size())
if (a.length < size()) {
a = (A[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size());
}
int i = 0;
Object[] result = a;
for (U u : this) {
@ -140,6 +140,14 @@ public class AttributeMap<E> {
}
}
@SuppressWarnings("rawtypes")
public static final AttributeMap EMPTY = new AttributeMap<>();
@SuppressWarnings("unchecked")
public static final <E> AttributeMap<E> emptyAttributeMap() {
return EMPTY;
}
private final Map<AttributeWrapper, E> delegate;
private Set<Attribute> keySet = null;
private Collection<E> values = null;
@ -175,6 +183,14 @@ public class AttributeMap<E> {
delegate.putAll(other.delegate);
}
public AttributeMap<E> combine(AttributeMap<E> other) {
AttributeMap<E> combine = new AttributeMap<>();
combine.addAll(this);
combine.addAll(other);
return combine;
}
public AttributeMap<E> subtract(AttributeMap<E> other) {
AttributeMap<E> diff = new AttributeMap<>();
for (Entry<AttributeWrapper, E> entry : this.delegate.entrySet()) {
@ -222,14 +238,17 @@ public class AttributeMap<E> {
return s;
}
@Override
public int size() {
return delegate.size();
}
@Override
public boolean isEmpty() {
return delegate.isEmpty();
}
@Override
public boolean containsKey(Object key) {
if (key instanceof NamedExpression) {
return delegate.keySet().contains(new AttributeWrapper(((NamedExpression) key).toAttribute()));
@ -237,10 +256,12 @@ public class AttributeMap<E> {
return false;
}
@Override
public boolean containsValue(Object value) {
return delegate.values().contains(value);
}
@Override
public E get(Object key) {
if (key instanceof NamedExpression) {
return delegate.get(new AttributeWrapper(((NamedExpression) key).toAttribute()));
@ -248,6 +269,7 @@ public class AttributeMap<E> {
return null;
}
@Override
public E getOrDefault(Object key, E defaultValue) {
E e;
return (((e = get(key)) != null) || containsKey(key))
@ -255,6 +277,27 @@ public class AttributeMap<E> {
: defaultValue;
}
@Override
public E put(Attribute key, E value) {
throw new UnsupportedOperationException();
}
@Override
public E remove(Object key) {
throw new UnsupportedOperationException();
}
@Override
public void putAll(Map<? extends Attribute, ? extends E> m) {
throw new UnsupportedOperationException();
}
@Override
public void clear() {
throw new UnsupportedOperationException();
}
@Override
public Set<Attribute> keySet() {
if (keySet == null) {
keySet = new UnwrappingSet<AttributeWrapper, Attribute>(delegate.keySet()) {
@ -267,6 +310,7 @@ public class AttributeMap<E> {
return keySet;
}
@Override
public Collection<E> values() {
if (values == null) {
values = unmodifiableCollection(delegate.values());
@ -274,6 +318,7 @@ public class AttributeMap<E> {
return values;
}
@Override
public Set<Entry<Attribute, E>> entrySet() {
if (entrySet == null) {
entrySet = new UnwrappingSet<Entry<AttributeWrapper, E>, Entry<Attribute, E>>(delegate.entrySet()) {
@ -301,6 +346,7 @@ public class AttributeMap<E> {
return entrySet;
}
@Override
public void forEach(BiConsumer<? super Attribute, ? super E> action) {
delegate.forEach((k, v) -> action.accept(k.attr, v));
}

View File

@ -57,6 +57,10 @@ public class AttributeSet implements Set<Attribute> {
delegate.addAll(other.delegate);
}
public AttributeSet combine(AttributeSet other) {
return new AttributeSet(delegate.combine(other.delegate));
}
public AttributeSet subtract(AttributeSet other) {
return new AttributeSet(delegate.subtract(other.delegate));
}

View File

@ -8,8 +8,8 @@ package org.elasticsearch.xpack.sql.expression;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.capabilities.Resolvable;
import org.elasticsearch.xpack.sql.capabilities.Resolvables;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.Node;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.util.StringUtils;
@ -64,6 +64,7 @@ public abstract class Expression extends Node<Expression> implements Resolvable
private TypeResolution lazyTypeResolution = null;
private Boolean lazyChildrenResolved = null;
private Expression lazyCanonical = null;
private AttributeSet lazyReferences = null;
public Expression(Source source, List<Expression> children) {
super(source, children);
@ -82,7 +83,10 @@ public abstract class Expression extends Node<Expression> implements Resolvable
// the references/inputs/leaves of the expression tree
public AttributeSet references() {
return Expressions.references(children());
if (lazyReferences == null) {
lazyReferences = Expressions.references(children());
}
return lazyReferences;
}
public boolean childrenResolved() {

View File

@ -36,7 +36,7 @@ public final class Expressions {
private Expressions() {}
public static NamedExpression wrapAsNamed(Expression exp) {
return exp instanceof NamedExpression ? (NamedExpression) exp : new Alias(exp.source(), exp.nodeName(), exp);
return exp instanceof NamedExpression ? (NamedExpression) exp : new Alias(exp.source(), exp.sourceText(), exp);
}
public static List<Attribute> asAttributes(List<? extends NamedExpression> named) {

View File

@ -91,4 +91,9 @@ public abstract class NamedExpression extends Expression {
&& Objects.equals(name, other.name)
&& Objects.equals(children(), other.children());
}
@Override
public String toString() {
return super.toString() + "#" + id();
}
}

View File

@ -49,11 +49,6 @@ public abstract class Function extends NamedExpression {
return Expressions.nullable(children());
}
@Override
public String toString() {
return sourceText() + "#" + id();
}
public String functionName() {
return functionName;
}

View File

@ -166,7 +166,7 @@ public class UnresolvedFunction extends Function implements Unresolvable {
@Override
public String toString() {
return UNRESOLVED_PREFIX + sourceText();
return UNRESOLVED_PREFIX + name + children();
}
@Override

View File

@ -52,7 +52,7 @@ public abstract class AggregateFunction extends Function {
public AggregateFunctionAttribute toAttribute() {
if (lazyAttribute == null) {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
lazyAttribute = new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), null);
lazyAttribute = new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId());
}
return lazyAttribute;
}

View File

@ -18,23 +18,36 @@ import java.util.Objects;
public class AggregateFunctionAttribute extends FunctionAttribute {
// used when dealing with a inner agg (avg -> stats) to keep track of
// packed id
// used since the functionId points to the compoundAgg
private final ExpressionId innerId;
private final String propertyPath;
AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id,
String functionId, String propertyPath) {
this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, propertyPath);
AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, String functionId) {
this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, null, null);
}
public AggregateFunctionAttribute(Source source, String name, DataType dataType, String qualifier,
Nullability nullability, ExpressionId id, boolean synthetic, String functionId, String propertyPath) {
AggregateFunctionAttribute(Source source, String name, DataType dataType, ExpressionId id, String functionId, ExpressionId innerId,
String propertyPath) {
this(source, name, dataType, null, Nullability.FALSE, id, false, functionId, innerId, propertyPath);
}
public AggregateFunctionAttribute(Source source, String name, DataType dataType, String qualifier, Nullability nullability,
ExpressionId id, boolean synthetic, String functionId, ExpressionId innerId, String propertyPath) {
super(source, name, dataType, qualifier, nullability, id, synthetic, functionId);
this.innerId = innerId;
this.propertyPath = propertyPath;
}
@Override
protected NodeInfo<AggregateFunctionAttribute> info() {
return NodeInfo.create(this, AggregateFunctionAttribute::new,
name(), dataType(), qualifier(), nullable(), id(), synthetic(), functionId(), propertyPath);
return NodeInfo.create(this, AggregateFunctionAttribute::new, name(), dataType(), qualifier(), nullable(), id(), synthetic(),
functionId(), innerId, propertyPath);
}
public ExpressionId innerId() {
return innerId != null ? innerId : id();
}
public String propertyPath() {
@ -43,33 +56,38 @@ public class AggregateFunctionAttribute extends FunctionAttribute {
@Override
protected Expression canonicalize() {
return new AggregateFunctionAttribute(source(), "<none>", dataType(), null, Nullability.TRUE, id(), false, "<none>", null);
return new AggregateFunctionAttribute(source(), "<none>", dataType(), null, Nullability.TRUE, id(), false, "<none>", null, null);
}
@Override
protected Attribute clone(Source source, String name, String qualifier, Nullability nullability, ExpressionId id, boolean synthetic) {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
// that is the functionId is actually derived from the expression id to easily track it across contexts
return new AggregateFunctionAttribute(source, name, dataType(), qualifier, nullability, id, synthetic, functionId(), propertyPath);
return new AggregateFunctionAttribute(source, name, dataType(), qualifier, nullability, id, synthetic, functionId(), innerId,
propertyPath);
}
public AggregateFunctionAttribute withFunctionId(String functionId, String propertyPath) {
return new AggregateFunctionAttribute(source(), name(), dataType(), qualifier(), nullable(),
id(), synthetic(), functionId, propertyPath);
return new AggregateFunctionAttribute(source(), name(), dataType(), qualifier(), nullable(), id(), synthetic(), functionId, innerId,
propertyPath);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), propertyPath);
return Objects.hash(super.hashCode(), innerId, propertyPath);
}
@Override
public boolean equals(Object obj) {
return super.equals(obj) && Objects.equals(propertyPath(), ((AggregateFunctionAttribute) obj).propertyPath());
if (super.equals(obj)) {
AggregateFunctionAttribute other = (AggregateFunctionAttribute) obj;
return Objects.equals(innerId, other.innerId) && Objects.equals(propertyPath, other.propertyPath);
}
return false;
}
@Override
protected String label() {
return "a->" + functionId();
return "a->" + innerId();
}
}
}

View File

@ -77,11 +77,11 @@ public class Count extends AggregateFunction {
public AggregateFunctionAttribute toAttribute() {
// COUNT(*) gets its value from the parent aggregation on which _count is called
if (field() instanceof Literal) {
return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), "_count");
return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), id(), "_count");
}
// COUNT(column) gets its value from a sibling aggregation (an exists filter agg) by calling its id and then _count on it
if (!distinct()) {
return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), functionId() + "._count");
return new AggregateFunctionAttribute(source(), name(), dataType(), id(), functionId(), id(), functionId() + "._count");
}
return super.toAttribute();
}

View File

@ -7,8 +7,8 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType;
import java.util.List;
@ -17,7 +17,7 @@ public class InnerAggregate extends AggregateFunction {
private final AggregateFunction inner;
private final CompoundNumericAggregate outer;
private final String innerId;
private final String innerName;
// used when the result needs to be extracted from a map (like in MatrixAggs or Percentiles)
private final Expression innerKey;
@ -29,7 +29,7 @@ public class InnerAggregate extends AggregateFunction {
super(source, outer.field(), outer.arguments());
this.inner = inner;
this.outer = outer;
this.innerId = ((EnclosedAgg) inner).innerName();
this.innerName = ((EnclosedAgg) inner).innerName();
this.innerKey = innerKey;
}
@ -55,8 +55,8 @@ public class InnerAggregate extends AggregateFunction {
return outer;
}
public String innerId() {
return innerId;
public String innerName() {
return innerName;
}
public Expression innerKey() {
@ -77,10 +77,10 @@ public class InnerAggregate extends AggregateFunction {
public AggregateFunctionAttribute toAttribute() {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
return new AggregateFunctionAttribute(source(), name(), dataType(), outer.id(), functionId(),
aggMetricValue(functionId(), innerId));
inner.id(), aggMetricValue(functionId(), innerName));
}
public static String aggMetricValue(String aggPath, String valueName) {
private static String aggMetricValue(String aggPath, String valueName) {
// handle aggPath inconsistency (for percentiles and percentileRanks) percentile[99.9] (valid) vs percentile.99.9 (invalid)
return aggPath + "[" + valueName + "]";
}
@ -98,4 +98,9 @@ public class InnerAggregate extends AggregateFunction {
public String name() {
return inner.name();
}
@Override
public String toString() {
return nodeName() + "[" + outer + ">" + inner.nodeName() + "#" + inner.id() + "]";
}
}

View File

@ -10,7 +10,6 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer.CleanAliases;
import org.elasticsearch.xpack.sql.expression.Alias;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.AttributeMap;
import org.elasticsearch.xpack.sql.expression.AttributeSet;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.ExpressionId;
import org.elasticsearch.xpack.sql.expression.ExpressionSet;
@ -78,22 +77,23 @@ import org.elasticsearch.xpack.sql.rule.Rule;
import org.elasticsearch.xpack.sql.rule.RuleExecutor;
import org.elasticsearch.xpack.sql.session.EmptyExecutable;
import org.elasticsearch.xpack.sql.session.SingletonExecutable;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.util.CollectionUtils;
import org.elasticsearch.xpack.sql.util.Holder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.xpack.sql.expression.Literal.FALSE;
import static org.elasticsearch.xpack.sql.expression.Literal.TRUE;
import static org.elasticsearch.xpack.sql.expression.predicate.Predicates.combineAnd;
@ -117,19 +117,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
@Override
protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
Batch aggregate = new Batch("Aggregation",
new PruneDuplicatesInGroupBy(),
new ReplaceDuplicateAggsWithReferences(),
new ReplaceAggsWithMatrixStats(),
new ReplaceAggsWithExtendedStats(),
new ReplaceAggsWithStats(),
new PromoteStatsToExtendedStats(),
new ReplaceAggsWithPercentiles(),
new ReplaceAggsWithPercentileRanks(),
new ReplaceMinMaxWithTopHits()
);
Batch operators = new Batch("Operator Optimization",
new PruneDuplicatesInGroupBy(),
// combining
new CombineProjections(),
// folding
@ -157,6 +146,16 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
//new PruneDuplicateFunctions()
);
Batch aggregate = new Batch("Aggregation Rewrite",
//new ReplaceDuplicateAggsWithReferences(),
new ReplaceAggsWithMatrixStats(),
new ReplaceAggsWithExtendedStats(),
new ReplaceAggsWithStats(),
new PromoteStatsToExtendedStats(),
new ReplaceAggsWithPercentiles(),
new ReplaceAggsWithPercentileRanks()
);
Batch local = new Batch("Skip Elasticsearch",
new SkipQueryOnLimitZero(),
new SkipQueryIfFoldingProjection()
@ -253,7 +252,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
seen.put(argument, matrixStats);
}
InnerAggregate ia = new InnerAggregate(f.source(), f, matrixStats, f.field());
InnerAggregate ia = new InnerAggregate(f.source(), f, matrixStats, argument);
promotedIds.putIfAbsent(f.functionId(), ia.toAttribute());
return ia;
}
@ -310,8 +309,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
private static class Match {
final Stats stats;
int count = 1;
final Set<Class<? extends AggregateFunction>> functionTypes = new LinkedHashSet<>();
private final Set<Class<? extends AggregateFunction>> functionTypes = new LinkedHashSet<>();
private Map<Class<? extends AggregateFunction>, InnerAggregate> innerAggs = null;
Match(Stats stats) {
this.stats = stats;
@ -321,6 +320,22 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
public String toString() {
return stats.toString();
}
public void add(Class<? extends AggregateFunction> aggType) {
functionTypes.add(aggType);
}
// if the stat has at least two different functions for it, promote it as stat
// also keep the promoted function around for reuse
public AggregateFunction maybePromote(AggregateFunction agg) {
if (functionTypes.size() > 1) {
if (innerAggs == null) {
innerAggs = new LinkedHashMap<>();
}
return innerAggs.computeIfAbsent(agg.getClass(), k -> new InnerAggregate(agg, stats));
}
return agg;
}
}
@Override
@ -359,15 +374,10 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
Match match = seen.get(argument);
if (match == null) {
match = new Match(new Stats(f.source(), argument));
match.functionTypes.add(f.getClass());
match = new Match(new Stats(new Source(f.sourceLocation(), "STATS(" + Expressions.name(argument) + ")"), argument));
seen.put(argument, match);
}
else {
if (match.functionTypes.add(f.getClass())) {
match.count++;
}
}
match.add(f.getClass());
}
return e;
@ -378,13 +388,14 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.field();
Match counter = seen.get(argument);
Match match = seen.get(argument);
// if the stat has at least two different functions for it, promote it as stat
if (counter != null && counter.count > 1) {
InnerAggregate innerAgg = new InnerAggregate(f, counter.stats);
attrs.putIfAbsent(f.functionId(), innerAgg.toAttribute());
return innerAgg;
if (match != null) {
AggregateFunction inner = match.maybePromote(f);
if (inner != f) {
attrs.putIfAbsent(f.functionId(), inner.toAttribute());
}
return inner;
}
}
return e;
@ -819,31 +830,23 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
@Override
protected LogicalPlan rule(OrderBy ob) {
List<Order> order = ob.order();
Holder<Boolean> foundAggregate = new Holder<>(Boolean.FALSE);
Holder<Boolean> foundImplicitGroupBy = new Holder<>(Boolean.FALSE);
// remove constants
List<Order> nonConstant = order.stream().filter(o -> !o.child().foldable()).collect(toList());
if (nonConstant.isEmpty()) {
return ob.child();
}
// if the sort points to an agg, consider it only if there's grouping
if (ob.child() instanceof Aggregate) {
Aggregate a = (Aggregate) ob.child();
if (a.groupings().isEmpty()) {
AttributeSet aggsAttr = new AttributeSet(Expressions.asAttributes(a.aggregates()));
List<Order> nonAgg = nonConstant.stream().filter(o -> {
if (o.child() instanceof NamedExpression) {
return !aggsAttr.contains(((NamedExpression) o.child()).toAttribute());
}
return true;
}).collect(toList());
return nonAgg.isEmpty() ? ob.child() : new OrderBy(ob.source(), ob.child(), nonAgg);
// if the first found aggregate has no grouping, there's no need to do ordering
ob.forEachDown(a -> {
// take into account
if (foundAggregate.get() == Boolean.TRUE) {
return;
}
foundAggregate.set(Boolean.TRUE);
if (a.groupings().isEmpty()) {
foundImplicitGroupBy.set(Boolean.TRUE);
}
}, Aggregate.class);
if (foundImplicitGroupBy.get() == Boolean.TRUE) {
return ob.child();
}
return ob;
}
@ -858,34 +861,43 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
protected LogicalPlan rule(OrderBy ob) {
List<Order> order = ob.order();
// remove constants
List<Order> nonConstant = order.stream().filter(o -> !o.child().foldable()).collect(toList());
// remove constants and put the items in reverse order so the iteration happens back to front
List<Order> nonConstant = new LinkedList<>();
for (Order o : order) {
if (o.child().foldable() == false) {
nonConstant.add(0, o);
}
}
// if the sort points to an agg, change the agg order based on the order
if (ob.child() instanceof Aggregate) {
Aggregate a = (Aggregate) ob.child();
List<Expression> groupings = new ArrayList<>(a.groupings());
boolean orderChanged = false;
Holder<Boolean> foundAggregate = new Holder<>(Boolean.FALSE);
for (int orderIndex = 0; orderIndex < nonConstant.size(); orderIndex++) {
Order o = nonConstant.get(orderIndex);
// if the first found aggregate has no grouping, there's no need to do ordering
return ob.transformDown(a -> {
// take into account
if (foundAggregate.get() == Boolean.TRUE) {
return a;
}
foundAggregate.set(Boolean.TRUE);
List<Expression> groupings = new LinkedList<>(a.groupings());
for (Order o : nonConstant) {
Expression fieldToOrder = o.child();
for (Expression group : a.groupings()) {
if (Expressions.equalsAsAttribute(fieldToOrder, group)) {
// move grouping in front
groupings.remove(group);
groupings.add(orderIndex, group);
orderChanged = true;
groupings.add(0, group);
}
}
}
if (orderChanged) {
Aggregate newAgg = new Aggregate(a.source(), a.child(), groupings, a.aggregates());
return new OrderBy(ob.source(), newAgg, ob.order());
if (groupings.equals(a.groupings()) == false) {
return new Aggregate(a.source(), a.child(), groupings, a.aggregates());
}
}
return ob;
return a;
}, Aggregate.class);
}
}
@ -1017,6 +1029,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// eliminate lower project but first replace the aliases in the upper one
return new Project(p.source(), p.child(), combineProjections(project.projections(), p.projections()));
}
if (child instanceof Aggregate) {
Aggregate a = (Aggregate) child;
return new Aggregate(a.source(), a.child(), a.groupings(), combineProjections(project.projections(), a.aggregates()));
@ -1029,23 +1042,25 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// that might be reused by the upper one, these need to be replaced.
// for example an alias defined in the lower list might be referred in the upper - without replacing it the alias becomes invalid
private List<NamedExpression> combineProjections(List<? extends NamedExpression> upper, List<? extends NamedExpression> lower) {
//TODO: this need rewriting when moving functions of NamedExpression
// collect aliases in the lower list
Map<Attribute, Alias> map = new LinkedHashMap<>();
Map<Attribute, NamedExpression> map = new LinkedHashMap<>();
for (NamedExpression ne : lower) {
if (ne instanceof Alias) {
Alias a = (Alias) ne;
map.put(a.toAttribute(), a);
if ((ne instanceof Attribute) == false) {
map.put(ne.toAttribute(), ne);
}
}
AttributeMap<Alias> aliases = new AttributeMap<>(map);
AttributeMap<NamedExpression> aliases = new AttributeMap<>(map);
List<NamedExpression> replaced = new ArrayList<>();
// replace any matching attribute with a lower alias (if there's a match)
// but clean-up non-top aliases at the end
for (NamedExpression ne : upper) {
NamedExpression replacedExp = (NamedExpression) ne.transformUp(a -> {
Alias as = aliases.get(a);
NamedExpression as = aliases.get(a);
return as != null ? as : a;
}, Attribute.class);
@ -1088,12 +1103,12 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
return plan;
}
AtomicBoolean stop = new AtomicBoolean(false);
Holder<Boolean> stop = new Holder<>(Boolean.FALSE);
// propagate folding up to unary nodes
// anything higher and the propagate stops
plan = plan.transformUp(p -> {
if (stop.get() == false && canPropagateFoldable(p)) {
if (stop.get() == Boolean.FALSE && canPropagateFoldable(p)) {
return p.transformExpressionsDown(e -> {
if (e instanceof Attribute && attrs.contains(e)) {
Alias as = aliases.get(e);
@ -1108,7 +1123,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
if (p.children().size() > 1) {
stop.set(true);
stop.set(Boolean.TRUE);
}
return p;

View File

@ -60,13 +60,11 @@ public class TableIdentifier {
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append("[");
if (cluster != null) {
builder.append(cluster);
builder.append(":");
}
builder.append("][index=");
builder.append(index);
builder.append("]");
return builder.toString();
}
}

View File

@ -8,13 +8,15 @@ package org.elasticsearch.xpack.sql.plan.logical;
import org.elasticsearch.xpack.sql.capabilities.Unresolvable;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.plan.TableIdentifier;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static java.util.Collections.singletonList;
public class UnresolvedRelation extends LeafPlan implements Unresolvable {
private final TableIdentifier table;
@ -86,4 +88,14 @@ public class UnresolvedRelation extends LeafPlan implements Unresolvable {
&& Objects.equals(alias, other.alias)
&& unresolvedMsg.equals(other.unresolvedMsg);
}
@Override
public List<Object> nodeProperties() {
return singletonList(table);
}
@Override
public String toString() {
return UNRESOLVED_PREFIX + table.index();
}
}

View File

@ -9,11 +9,10 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.sql.execution.search.Querier;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.querydsl.container.QueryContainer;
import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.session.SchemaRowSet;
import org.elasticsearch.xpack.sql.session.SqlSession;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import java.util.List;
import java.util.Objects;
@ -22,7 +21,6 @@ public class EsQueryExec extends LeafExec {
private final String index;
private final List<Attribute> output;
private final QueryContainer queryContainer;
public EsQueryExec(Source source, String index, List<Attribute> output, QueryContainer queryContainer) {
@ -56,8 +54,9 @@ public class EsQueryExec extends LeafExec {
@Override
public void execute(SqlSession session, ActionListener<SchemaRowSet> listener) {
Querier scroller = new Querier(session.client(), session.configuration());
scroller.query(Rows.schema(output), queryContainer, index, listener);
Querier scroller = new Querier(session);
scroller.query(output, queryContainer, index, listener);
}
@Override
@ -85,4 +84,4 @@ public class EsQueryExec extends LeafExec {
public String nodeString() {
return nodeName() + "[" + index + "," + queryContainer + "]";
}
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.AggRef;
import org.elasticsearch.xpack.sql.expression.Alias;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.AttributeMap;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
@ -146,8 +147,12 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
}
}
QueryContainer clone = new QueryContainer(queryC.query(), queryC.aggs(), queryC.columns(), aliases,
queryC.pseudoFunctions(), processors, queryC.sort(), queryC.limit());
QueryContainer clone = new QueryContainer(queryC.query(), queryC.aggs(), queryC.fields(),
new AttributeMap<>(aliases),
queryC.pseudoFunctions(),
new AttributeMap<>(processors),
queryC.sort(),
queryC.limit());
return new EsQueryExec(exec.source(), exec.index(), project.output(), clone);
}
return project;
@ -170,7 +175,8 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
}
Aggs aggs = addPipelineAggs(qContainer, qt, plan);
qContainer = new QueryContainer(query, aggs, qContainer.columns(), qContainer.aliases(),
qContainer = new QueryContainer(query, aggs, qContainer.fields(),
qContainer.aliases(),
qContainer.pseudoFunctions(),
qContainer.scalarFunctions(),
qContainer.sort(),
@ -315,7 +321,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
}
// add the computed column
queryC = qC.get().addColumn(new ComputedRef(proc));
queryC = qC.get().addColumn(new ComputedRef(proc), f.toAttribute());
// TODO: is this needed?
// redirect the alias to the scalar group id (changing the id altogether doesn't work it is
@ -337,20 +343,22 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
// UTC is used since that's what the server uses and there's no conversion applied
// (like for date histograms)
ZoneId zi = child.dataType().isDateBased() ? DateUtils.UTC : null;
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi));
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi), ((Attribute) child));
}
// handle histogram
else if (child instanceof GroupingFunction) {
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, null));
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, null),
((GroupingFunction) child).toAttribute());
}
// fallback to regular agg functions
else {
// the only thing left is agg function
Check.isTrue(Functions.isAggregate(child),
"Expected aggregate function inside alias; got [{}]", child.nodeString());
Tuple<QueryContainer, AggPathInput> withAgg = addAggFunction(matchingGroup,
(AggregateFunction) child, compoundAggMap, queryC);
queryC = withAgg.v1().addColumn(withAgg.v2().context());
AggregateFunction af = (AggregateFunction) child;
Tuple<QueryContainer, AggPathInput> withAgg = addAggFunction(matchingGroup, af, compoundAggMap, queryC);
// make sure to add the inner id (to handle compound aggs)
queryC = withAgg.v1().addColumn(withAgg.v2().context(), af.toAttribute());
}
}
// not an Alias or Function means it's an Attribute so apply the same logic as above
@ -361,7 +369,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(ne));
ZoneId zi = ne.dataType().isDateBased() ? DateUtils.UTC : null;
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi));
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, zi), ne.toAttribute());
}
}
}
@ -369,7 +377,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
if (!aliases.isEmpty()) {
Map<Attribute, Attribute> newAliases = new LinkedHashMap<>(queryC.aliases());
newAliases.putAll(aliases);
queryC = queryC.withAliases(newAliases);
queryC = queryC.withAliases(new AttributeMap<>(newAliases));
}
return new EsQueryExec(exec.source(), exec.index(), a.output(), queryC);
}
@ -420,7 +428,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
// FIXME: concern leak - hack around MatrixAgg which is not
// generalized (afaik)
aggInput = new AggPathInput(f,
new MetricAggRef(cAggPath, ia.innerId(), ia.innerKey() != null ? QueryTranslator.nameOf(ia.innerKey()) : null));
new MetricAggRef(cAggPath, ia.innerName(), ia.innerKey() != null ? QueryTranslator.nameOf(ia.innerKey()) : null));
}
else {
LeafAgg leafAgg = toAgg(functionId, f);
@ -474,19 +482,19 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
if (sfa.orderBy() instanceof NamedExpression) {
Attribute at = ((NamedExpression) sfa.orderBy()).toAttribute();
at = qContainer.aliases().getOrDefault(at, at);
qContainer = qContainer.sort(new AttributeSort(at, direction, missing));
qContainer = qContainer.addSort(new AttributeSort(at, direction, missing));
} else if (!sfa.orderBy().foldable()) {
// ignore constant
throw new PlanningException("does not know how to order by expression {}", sfa.orderBy());
}
} else {
// nope, use scripted sorting
qContainer = qContainer.sort(new ScriptSort(sfa.script(), direction, missing));
qContainer = qContainer.addSort(new ScriptSort(sfa.script(), direction, missing));
}
} else if (attr instanceof ScoreAttribute) {
qContainer = qContainer.sort(new ScoreSort(direction, missing));
qContainer = qContainer.addSort(new ScoreSort(direction, missing));
} else {
qContainer = qContainer.sort(new AttributeSort(attr, direction, missing));
qContainer = qContainer.addSort(new AttributeSort(attr, direction, missing));
}
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBui
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.querydsl.container.Sort.Direction;
import org.elasticsearch.xpack.sql.util.StringUtils;
import java.util.ArrayList;
import java.util.Collection;
@ -21,7 +22,6 @@ import java.util.Objects;
import static java.util.Collections.emptyList;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine;
import static org.elasticsearch.xpack.sql.util.StringUtils.EMPTY;
/**
* SQL Aggregations associated with a query.
@ -40,7 +40,7 @@ public class Aggs {
public static final String ROOT_GROUP_NAME = "groupby";
public static final GroupByKey IMPLICIT_GROUP_KEY = new GroupByKey(ROOT_GROUP_NAME, EMPTY, null, null) {
public static final GroupByKey IMPLICIT_GROUP_KEY = new GroupByKey(ROOT_GROUP_NAME, StringUtils.EMPTY, null, null) {
@Override
public CompositeValuesSourceBuilder<?> createSourceBuilder() {
@ -53,14 +53,12 @@ public class Aggs {
}
};
public static final Aggs EMPTY = new Aggs(emptyList(), emptyList(), emptyList());
private final List<GroupByKey> groups;
private final List<LeafAgg> simpleAggs;
private final List<PipelineAgg> pipelineAggs;
public Aggs() {
this(emptyList(), emptyList(), emptyList());
}
public Aggs(List<GroupByKey> groups, List<LeafAgg> simpleAggs, List<PipelineAgg> pipelineAggs) {
this.groups = groups;

View File

@ -15,9 +15,12 @@ import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.FieldExtraction;
import org.elasticsearch.xpack.sql.execution.search.SourceGenerator;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.AttributeMap;
import org.elasticsearch.xpack.sql.expression.ExpressionId;
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.LiteralAttribute;
import org.elasticsearch.xpack.sql.expression.function.ScoreAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
@ -33,7 +36,9 @@ import org.elasticsearch.xpack.sql.tree.Source;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
@ -47,48 +52,142 @@ import static java.util.Collections.emptySet;
import static java.util.Collections.singletonMap;
import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine;
/**
* Container for various references of the built ES query.
* Useful to understanding how to interpret and navigate the
* returned result.
*/
public class QueryContainer {
private final Aggs aggs;
private final Query query;
// final output seen by the client (hence the list or ordering)
// gets converted by the Scroller into Extractors for hits or actual results in case of aggregations
private final List<FieldExtraction> columns;
// fields extracted from the response - not necessarily what the client sees
// for example in case of grouping or custom sorting, the response has extra columns
// that is filtered before getting to the client
// the list contains both the field extraction and the id of its associated attribute (for custom sorting)
private final List<Tuple<FieldExtraction, ExpressionId>> fields;
// aliases (maps an alias to its actual resolved attribute)
private final Map<Attribute, Attribute> aliases;
private final AttributeMap<Attribute> aliases;
// pseudo functions (like count) - that are 'extracted' from other aggs
private final Map<String, GroupByKey> pseudoFunctions;
// scalar function processors - recorded as functions get folded;
// at scrolling, their inputs (leaves) get updated
private final Map<Attribute, Pipe> scalarFunctions;
private final AttributeMap<Pipe> scalarFunctions;
private final Set<Sort> sort;
private final int limit;
// computed
private final boolean aggsOnly;
private Boolean aggsOnly;
private Boolean customSort;
public QueryContainer() {
this(null, null, null, null, null, null, null, -1);
}
public QueryContainer(Query query, Aggs aggs, List<FieldExtraction> refs, Map<Attribute, Attribute> aliases,
public QueryContainer(Query query,
Aggs aggs,
List<Tuple<FieldExtraction, ExpressionId>> fields,
AttributeMap<Attribute> aliases,
Map<String, GroupByKey> pseudoFunctions,
Map<Attribute, Pipe> scalarFunctions,
Set<Sort> sort, int limit) {
AttributeMap<Pipe> scalarFunctions,
Set<Sort> sort,
int limit) {
this.query = query;
this.aggs = aggs == null ? new Aggs() : aggs;
this.aliases = aliases == null || aliases.isEmpty() ? emptyMap() : aliases;
this.aggs = aggs == null ? Aggs.EMPTY : aggs;
this.fields = fields == null || fields.isEmpty() ? emptyList() : fields;
this.aliases = aliases == null || aliases.isEmpty() ? AttributeMap.emptyAttributeMap() : aliases;
this.pseudoFunctions = pseudoFunctions == null || pseudoFunctions.isEmpty() ? emptyMap() : pseudoFunctions;
this.scalarFunctions = scalarFunctions == null || scalarFunctions.isEmpty() ? emptyMap() : scalarFunctions;
this.columns = refs == null || refs.isEmpty() ? emptyList() : refs;
this.scalarFunctions = scalarFunctions == null || scalarFunctions.isEmpty() ? AttributeMap.emptyAttributeMap() : scalarFunctions;
this.sort = sort == null || sort.isEmpty() ? emptySet() : sort;
this.limit = limit;
aggsOnly = columns.stream().allMatch(FieldExtraction::supportedByAggsOnlyQuery);
}
/**
* If needed, create a comparator for each indicated column (which is indicated by an index pointing to the column number from the
* result set).
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
public List<Tuple<Integer, Comparator>> sortingColumns() {
if (customSort == Boolean.FALSE) {
return emptyList();
}
List<Tuple<Integer, Comparator>> sortingColumns = new ArrayList<>(sort.size());
boolean aggSort = false;
for (Sort s : sort) {
Tuple<Integer, Comparator> tuple = new Tuple<>(Integer.valueOf(-1), null);
if (s instanceof AttributeSort) {
AttributeSort as = (AttributeSort) s;
// find the relevant column of each aggregate function
if (as.attribute() instanceof AggregateFunctionAttribute) {
aggSort = true;
AggregateFunctionAttribute afa = (AggregateFunctionAttribute) as.attribute();
afa = (AggregateFunctionAttribute) aliases.getOrDefault(afa, afa);
int atIndex = -1;
for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, ExpressionId> field = fields.get(i);
if (field.v2().equals(afa.innerId())) {
atIndex = i;
break;
}
}
if (atIndex == -1) {
throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", afa.name());
}
// assemble a comparator for it
Comparator comp = s.direction() == Sort.Direction.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder();
comp = s.missing() == Sort.Missing.FIRST ? Comparator.nullsFirst(comp) : Comparator.nullsLast(comp);
tuple = new Tuple<>(Integer.valueOf(atIndex), comp);
}
}
sortingColumns.add(tuple);
}
if (customSort == null) {
customSort = Boolean.valueOf(aggSort);
}
return aggSort ? sortingColumns : emptyList();
}
/**
* Since the container contains both the field extractors and the visible columns,
* compact the information in the listener through a bitset that acts as a mask
* on what extractors are used for the visible columns.
*/
public BitSet columnMask(List<Attribute> columns) {
BitSet mask = new BitSet(fields.size());
for (Attribute column : columns) {
Attribute alias = aliases.get(column);
// find the column index
int index = -1;
ExpressionId id = column instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) column).innerId() : column.id();
ExpressionId aliasId = alias != null ? (alias instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) alias)
.innerId() : alias.id()) : null;
for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, ExpressionId> tuple = fields.get(i);
if (tuple.v2().equals(id) || (aliasId != null && tuple.v2().equals(aliasId))) {
index = i;
break;
}
}
if (index > -1) {
mask.set(index);
} else {
throw new SqlIllegalArgumentException("Cannot resolve field extractor index for column [{}]", column);
}
}
return mask;
}
public Query query() {
@ -99,11 +198,11 @@ public class QueryContainer {
return aggs;
}
public List<FieldExtraction> columns() {
return columns;
public List<Tuple<FieldExtraction, ExpressionId>> fields() {
return fields;
}
public Map<Attribute, Attribute> aliases() {
public AttributeMap<Attribute> aliases() {
return aliases;
}
@ -120,11 +219,15 @@ public class QueryContainer {
}
public boolean isAggsOnly() {
return aggsOnly;
if (aggsOnly == null) {
aggsOnly = Boolean.valueOf(this.fields.stream().allMatch(t -> t.v1().supportedByAggsOnlyQuery()));
}
return aggsOnly.booleanValue();
}
public boolean hasColumns() {
return !columns.isEmpty();
return fields.size() > 0;
}
//
@ -132,37 +235,33 @@ public class QueryContainer {
//
public QueryContainer with(Query q) {
return new QueryContainer(q, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
}
public QueryContainer with(List<FieldExtraction> r) {
return new QueryContainer(query, aggs, r, aliases, pseudoFunctions, scalarFunctions, sort, limit);
}
public QueryContainer withAliases(Map<Attribute, Attribute> a) {
return new QueryContainer(query, aggs, columns, a, pseudoFunctions, scalarFunctions, sort, limit);
public QueryContainer withAliases(AttributeMap<Attribute> a) {
return new QueryContainer(query, aggs, fields, a, pseudoFunctions, scalarFunctions, sort, limit);
}
public QueryContainer withPseudoFunctions(Map<String, GroupByKey> p) {
return new QueryContainer(query, aggs, columns, aliases, p, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, p, scalarFunctions, sort, limit);
}
public QueryContainer with(Aggs a) {
return new QueryContainer(query, a, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, a, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
}
public QueryContainer withLimit(int l) {
return l == limit ? this : new QueryContainer(query, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, l);
return l == limit ? this : new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, l);
}
public QueryContainer withScalarProcessors(Map<Attribute, Pipe> procs) {
return new QueryContainer(query, aggs, columns, aliases, pseudoFunctions, procs, sort, limit);
public QueryContainer withScalarProcessors(AttributeMap<Pipe> procs) {
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, procs, sort, limit);
}
public QueryContainer sort(Sort sortable) {
public QueryContainer addSort(Sort sortable) {
Set<Sort> sort = new LinkedHashSet<>(this.sort);
sort.add(sortable);
return new QueryContainer(query, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
}
private String aliasName(Attribute attr) {
@ -188,7 +287,8 @@ public class QueryContainer {
attr.field().isAggregatable(), attr.parent().name());
nestedRefs.add(nestedFieldRef);
return new Tuple<>(new QueryContainer(q, aggs, columns, aliases, pseudoFunctions, scalarFunctions, sort, limit), nestedFieldRef);
return new Tuple<>(new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit),
nestedFieldRef);
}
static Query rewriteToContainNestedField(@Nullable Query query, Source source, String path, String name, String format,
@ -255,13 +355,13 @@ public class QueryContainer {
// update proc
Map<Attribute, Pipe> procs = new LinkedHashMap<>(qContainer.scalarFunctions());
procs.put(attribute, proc);
qContainer = qContainer.withScalarProcessors(procs);
qContainer = qContainer.withScalarProcessors(new AttributeMap<>(procs));
return new Tuple<>(qContainer, new ComputedRef(proc));
}
public QueryContainer addColumn(Attribute attr) {
Tuple<QueryContainer, FieldExtraction> tuple = toReference(attr);
return tuple.v1().addColumn(tuple.v2());
return tuple.v1().addColumn(tuple.v2(), attr);
}
private Tuple<QueryContainer, FieldExtraction> toReference(Attribute attr) {
@ -286,11 +386,14 @@ public class QueryContainer {
throw new SqlIllegalArgumentException("Unknown output attribute {}", attr);
}
public QueryContainer addColumn(FieldExtraction ref) {
return with(combine(columns, ref));
public QueryContainer addColumn(FieldExtraction ref, Attribute attr) {
ExpressionId id = attr instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) attr).innerId() : attr.id();
return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, id)), aliases, pseudoFunctions,
scalarFunctions,
sort, limit);
}
public Map<Attribute, Pipe> scalarFunctions() {
public AttributeMap<Pipe> scalarFunctions() {
return scalarFunctions;
}
@ -298,11 +401,14 @@ public class QueryContainer {
// agg methods
//
public QueryContainer addAggCount(GroupByKey group, String functionId) {
public QueryContainer addAggCount(GroupByKey group, ExpressionId functionId) {
FieldExtraction ref = group == null ? GlobalCountRef.INSTANCE : new GroupByRef(group.id(), Property.COUNT, null);
Map<String, GroupByKey> pseudoFunctions = new LinkedHashMap<>(this.pseudoFunctions);
pseudoFunctions.put(functionId, group);
return new QueryContainer(query, aggs, combine(columns, ref), aliases, pseudoFunctions, scalarFunctions, sort, limit);
pseudoFunctions.put(functionId.toString(), group);
return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, functionId)),
aliases,
pseudoFunctions,
scalarFunctions, sort, limit);
}
public QueryContainer addAgg(String groupId, LeafAgg agg) {
@ -327,7 +433,7 @@ public class QueryContainer {
@Override
public int hashCode() {
return Objects.hash(query, aggs, columns, aliases);
return Objects.hash(query, aggs, fields, aliases, sort, limit);
}
@Override
@ -343,7 +449,7 @@ public class QueryContainer {
QueryContainer other = (QueryContainer) obj;
return Objects.equals(query, other.query)
&& Objects.equals(aggs, other.aggs)
&& Objects.equals(columns, other.columns)
&& Objects.equals(fields, other.fields)
&& Objects.equals(aliases, other.aliases)
&& Objects.equals(sort, other.sort)
&& Objects.equals(limit, other.limit);

View File

@ -9,7 +9,7 @@ import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.sql.expression.Order.NullsPosition;
import org.elasticsearch.xpack.sql.expression.Order.OrderDirection;
public class Sort {
public abstract class Sort {
public enum Direction {
ASC, DESC;

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.search.CompositeAggregationCursor;
import org.elasticsearch.xpack.sql.execution.search.PagingListCursor;
import org.elasticsearch.xpack.sql.execution.search.ScrollCursor;
import org.elasticsearch.xpack.sql.execution.search.extractor.BucketExtractors;
import org.elasticsearch.xpack.sql.execution.search.extractor.HitExtractors;
@ -48,6 +49,7 @@ public final class Cursors {
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, ScrollCursor.NAME, ScrollCursor::new));
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, CompositeAggregationCursor.NAME, CompositeAggregationCursor::new));
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, TextFormatterCursor.NAME, TextFormatterCursor::new));
entries.add(new NamedWriteableRegistry.Entry(Cursor.class, PagingListCursor.NAME, PagingListCursor::new));
// plus all their dependencies
entries.addAll(Processors.getNamedWriteables());

View File

@ -7,10 +7,10 @@ package org.elasticsearch.xpack.sql.session;
import org.elasticsearch.xpack.sql.type.Schema;
class EmptyRowSetCursor extends AbstractRowSet implements SchemaRowSet {
class EmptyRowSet extends AbstractRowSet implements SchemaRowSet {
private final Schema schema;
EmptyRowSetCursor(Schema schema) {
EmptyRowSet(Schema schema) {
this.schema = schema;
}

View File

@ -9,25 +9,25 @@ import org.elasticsearch.xpack.sql.type.Schema;
import java.util.List;
class ListRowSetCursor extends AbstractRowSet implements SchemaRowSet {
public class ListRowSet extends AbstractRowSet implements SchemaRowSet {
private final Schema schema;
private final List<List<?>> list;
private int pos = 0;
ListRowSetCursor(Schema schema, List<List<?>> list) {
protected ListRowSet(Schema schema, List<List<?>> list) {
this.schema = schema;
this.list = list;
}
@Override
protected boolean doHasCurrent() {
return pos < list.size();
return pos < size();
}
@Override
protected boolean doNext() {
if (pos + 1 < list.size()) {
if (pos + 1 < size()) {
pos++;
return true;
}

View File

@ -36,7 +36,7 @@ public abstract class Rows {
}
Schema schema = schema(attrs);
return new ListRowSetCursor(schema, values);
return new ListRowSet(schema, values);
}
public static SchemaRowSet singleton(List<Attribute> attrs, Object... values) {
@ -49,10 +49,10 @@ public abstract class Rows {
}
public static SchemaRowSet empty(Schema schema) {
return new EmptyRowSetCursor(schema);
return new EmptyRowSet(schema);
}
public static SchemaRowSet empty(List<Attribute> attrs) {
return new EmptyRowSetCursor(schema(attrs));
return new EmptyRowSet(schema(attrs));
}
}

View File

@ -19,6 +19,6 @@ public interface SchemaRowSet extends RowSet {
@Override
default int columnCount() {
return schema().names().size();
return schema().size();
}
}

View File

@ -15,6 +15,7 @@ import org.elasticsearch.xpack.sql.analysis.analyzer.Verifier;
import org.elasticsearch.xpack.sql.analysis.index.IndexResolution;
import org.elasticsearch.xpack.sql.analysis.index.IndexResolver;
import org.elasticsearch.xpack.sql.analysis.index.MappingException;
import org.elasticsearch.xpack.sql.execution.PlanExecutor;
import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry;
import org.elasticsearch.xpack.sql.optimizer.Optimizer;
import org.elasticsearch.xpack.sql.parser.SqlParser;
@ -40,20 +41,17 @@ public class SqlSession {
private final Verifier verifier;
private final Optimizer optimizer;
private final Planner planner;
private final PlanExecutor planExecutor;
private final Configuration configuration;
public SqlSession(SqlSession other) {
this(other.configuration, other.client, other.functionRegistry, other.indexResolver,
other.preAnalyzer, other.verifier, other.optimizer, other.planner);
}
public SqlSession(Configuration configuration, Client client, FunctionRegistry functionRegistry,
IndexResolver indexResolver,
PreAnalyzer preAnalyzer,
Verifier verifier,
Optimizer optimizer,
Planner planner) {
Planner planner,
PlanExecutor planExecutor) {
this.client = client;
this.functionRegistry = functionRegistry;
@ -64,6 +62,7 @@ public class SqlSession {
this.verifier = verifier;
this.configuration = configuration;
this.planExecutor = planExecutor;
}
public FunctionRegistry functionRegistry() {
@ -90,6 +89,10 @@ public class SqlSession {
return verifier;
}
public PlanExecutor planExecutor() {
return planExecutor;
}
private LogicalPlan doParse(String sql, List<SqlTypedParamValue> params) {
return new SqlParser().createStatement(sql, params);
}

View File

@ -281,6 +281,14 @@ public abstract class Node<T extends Node<T>> {
return getClass().getSimpleName();
}
/**
* The values of all the properties that are important
* to this {@link Node}.
*/
public List<Object> nodeProperties() {
return info().properties();
}
public String nodeString() {
StringBuilder sb = new StringBuilder();
sb.append(nodeName());
@ -349,7 +357,6 @@ public abstract class Node<T extends Node<T>> {
* {@code [} and {@code ]} of the output of {@link #treeString}.
*/
public String propertiesToString(boolean skipIfChild) {
NodeInfo<? extends Node<T>> info = info();
StringBuilder sb = new StringBuilder();
List<?> children = children();
@ -358,7 +365,7 @@ public abstract class Node<T extends Node<T>> {
int maxWidth = 0;
boolean needsComma = false;
List<Object> props = info.properties();
List<Object> props = nodeProperties();
for (Object prop : props) {
// consider a property if it is not ignored AND
// it's not a child (optional)
@ -388,12 +395,4 @@ public abstract class Node<T extends Node<T>> {
return sb.toString();
}
/**
* The values of all the properties that are important
* to this {@link Node}.
*/
public List<Object> properties() {
return info().properties();
}
}
}

View File

@ -349,6 +349,54 @@ public abstract class NodeInfo<T extends Node<?>> {
T apply(Source l, P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8);
}
public static <T extends Node<?>, P1, P2, P3, P4, P5, P6, P7, P8, P9> NodeInfo<T> create(
T n, NodeCtor9<P1, P2, P3, P4, P5, P6, P7, P8, P9, T> ctor,
P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9) {
return new NodeInfo<T>(n) {
@Override
protected List<Object> innerProperties() {
return Arrays.asList(p1, p2, p3, p4, p5, p6, p7, p8, p9);
}
protected T innerTransform(Function<Object, Object> rule) {
boolean same = true;
@SuppressWarnings("unchecked")
P1 newP1 = (P1) rule.apply(p1);
same &= Objects.equals(p1, newP1);
@SuppressWarnings("unchecked")
P2 newP2 = (P2) rule.apply(p2);
same &= Objects.equals(p2, newP2);
@SuppressWarnings("unchecked")
P3 newP3 = (P3) rule.apply(p3);
same &= Objects.equals(p3, newP3);
@SuppressWarnings("unchecked")
P4 newP4 = (P4) rule.apply(p4);
same &= Objects.equals(p4, newP4);
@SuppressWarnings("unchecked")
P5 newP5 = (P5) rule.apply(p5);
same &= Objects.equals(p5, newP5);
@SuppressWarnings("unchecked")
P6 newP6 = (P6) rule.apply(p6);
same &= Objects.equals(p6, newP6);
@SuppressWarnings("unchecked")
P7 newP7 = (P7) rule.apply(p7);
same &= Objects.equals(p7, newP7);
@SuppressWarnings("unchecked")
P8 newP8 = (P8) rule.apply(p8);
same &= Objects.equals(p8, newP8);
@SuppressWarnings("unchecked")
P9 newP9 = (P9) rule.apply(p9);
same &= Objects.equals(p9, newP9);
return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8, newP9);
}
};
}
public interface NodeCtor9<P1, P2, P3, P4, P5, P6, P7, P8, P9, T> {
T apply(Source l, P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9);
}
public static <T extends Node<?>, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10> NodeInfo<T> create(
T n, NodeCtor10<P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, T> ctor,
P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9, P10 p10) {

View File

@ -132,7 +132,7 @@ public abstract class Graphviz {
+ "</b></td></th>\n");
indent(nodeInfo, currentIndent + NODE_LABEL_INDENT);
List<Object> props = n.properties();
List<Object> props = n.nodeProperties();
List<String> parsed = new ArrayList<>(props.size());
List<Node<?>> subTrees = new ArrayList<>();

View File

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.util;
/**
* Simply utility class used for setting a state, typically
* for closures (which require outside variables to be final).
*/
public class Holder<T> {
private T value = null;
public Holder() {
}
public Holder(T value) {
this.value = value;
}
public void set(T value) {
this.value = value;
}
public T get() {
return value;
}
}

View File

@ -175,12 +175,10 @@ public class VerifierErrorMessagesTests extends ESTestCase {
}
public void testMissingColumnInOrderBy() {
// xxx offset is that of the order by field
assertEquals("1:29: Unknown column [xxx]", error("SELECT * FROM test ORDER BY xxx"));
}
public void testMissingColumnFunctionInOrderBy() {
// xxx offset is that of the order by field
assertEquals("1:41: Unknown column [xxx]", error("SELECT * FROM test ORDER BY DAY_oF_YEAR(xxx)"));
}
@ -208,7 +206,6 @@ public class VerifierErrorMessagesTests extends ESTestCase {
}
public void testMultipleColumns() {
// xxx offset is that of the order by field
assertEquals("1:43: Unknown column [xxx]\nline 1:8: Unknown column [xxx]",
error("SELECT xxx FROM test GROUP BY DAY_oF_YEAR(xxx)"));
}
@ -248,7 +245,7 @@ public class VerifierErrorMessagesTests extends ESTestCase {
}
public void testGroupByOrderByScalarOverNonGrouped() {
assertEquals("1:50: Cannot order by non-grouped column [YEAR(date)], expected [text]",
assertEquals("1:50: Cannot order by non-grouped column [YEAR(date)], expected [text] or an aggregate function",
error("SELECT MAX(int) FROM test GROUP BY text ORDER BY YEAR(date)"));
}
@ -258,7 +255,7 @@ public class VerifierErrorMessagesTests extends ESTestCase {
}
public void testGroupByOrderByScalarOverNonGrouped_WithHaving() {
assertEquals("1:71: Cannot order by non-grouped column [YEAR(date)], expected [text]",
assertEquals("1:71: Cannot order by non-grouped column [YEAR(date)], expected [text] or an aggregate function",
error("SELECT MAX(int) FROM test GROUP BY text HAVING MAX(int) > 10 ORDER BY YEAR(date)"));
}
@ -316,18 +313,25 @@ public class VerifierErrorMessagesTests extends ESTestCase {
error("SELECT * FROM test ORDER BY unsupported"));
}
public void testGroupByOrderByNonKey() {
assertEquals("1:52: Cannot order by non-grouped column [a], expected [bool]",
error("SELECT AVG(int) a FROM test GROUP BY bool ORDER BY a"));
public void testGroupByOrderByAggregate() {
accept("SELECT AVG(int) a FROM test GROUP BY bool ORDER BY a");
}
public void testGroupByOrderByFunctionOverKey() {
assertEquals("1:44: Cannot order by non-grouped column [MAX(int)], expected [int]",
error("SELECT int FROM test GROUP BY int ORDER BY MAX(int)"));
public void testGroupByOrderByAggs() {
accept("SELECT int FROM test GROUP BY int ORDER BY COUNT(*)");
}
public void testGroupByOrderByAggAndGroupedColumn() {
accept("SELECT int FROM test GROUP BY int ORDER BY int, MAX(int)");
}
public void testGroupByOrderByNonAggAndNonGroupedColumn() {
assertEquals("1:44: Cannot order by non-grouped column [bool], expected [int]",
error("SELECT int FROM test GROUP BY int ORDER BY bool"));
}
public void testGroupByOrderByScore() {
assertEquals("1:44: Cannot order by non-grouped column [SCORE()], expected [int]",
assertEquals("1:44: Cannot order by non-grouped column [SCORE()], expected [int] or an aggregate function",
error("SELECT int FROM test GROUP BY int ORDER BY SCORE()"));
}

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.sql.session.Cursors;
import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.function.Supplier;
@ -27,7 +28,9 @@ public class CompositeAggregationCursorTests extends AbstractWireSerializingTest
for (int i = 0; i < extractorsSize; i++) {
extractors.add(randomBucketExtractor());
}
return new CompositeAggregationCursor(new byte[randomInt(256)], extractors, randomIntBetween(10, 1024), randomAlphaOfLength(5));
return new CompositeAggregationCursor(new byte[randomInt(256)], extractors, randomBitSet(extractorsSize),
randomIntBetween(10, 1024), randomAlphaOfLength(5));
}
static BucketExtractor randomBucketExtractor() {
@ -41,7 +44,9 @@ public class CompositeAggregationCursorTests extends AbstractWireSerializingTest
@Override
protected CompositeAggregationCursor mutateInstance(CompositeAggregationCursor instance) throws IOException {
return new CompositeAggregationCursor(instance.next(), instance.extractors(),
randomValueOtherThan(instance.limit(), () -> randomIntBetween(1, 512)), instance.indices());
randomValueOtherThan(instance.mask(), () -> randomBitSet(instance.extractors().size())),
randomValueOtherThan(instance.limit(), () -> randomIntBetween(1, 512)),
instance.indices());
}
@Override
@ -68,4 +73,12 @@ public class CompositeAggregationCursorTests extends AbstractWireSerializingTest
}
return (CompositeAggregationCursor) Cursors.decodeFromString(Cursors.encodeToString(version, instance));
}
static BitSet randomBitSet(int size) {
BitSet mask = new BitSet(size);
for (int i = 0; i < size; i++) {
mask.set(i, randomBoolean());
}
return mask;
}
}

View File

@ -23,17 +23,18 @@ import org.elasticsearch.xpack.sql.session.Cursors;
import org.mockito.ArgumentCaptor;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import static org.elasticsearch.action.support.PlainActionFuture.newFuture;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.CLI;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.TEXT;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.CLI;
import static org.elasticsearch.xpack.sql.action.BasicFormatter.FormatOption.TEXT;
public class CursorTests extends ESTestCase {
@ -51,7 +52,7 @@ public class CursorTests extends ESTestCase {
Client clientMock = mock(Client.class);
ActionListener<Boolean> listenerMock = mock(ActionListener.class);
String cursorString = randomAlphaOfLength(10);
Cursor cursor = new ScrollCursor(cursorString, Collections.emptyList(), randomInt());
Cursor cursor = new ScrollCursor(cursorString, Collections.emptyList(), new BitSet(0), randomInt());
cursor.clear(TestUtils.TEST_CFG, clientMock, listenerMock);

View File

@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.sql.session.Cursors;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class PagingListCursorTests extends AbstractWireSerializingTestCase<PagingListCursor> {
public static PagingListCursor randomPagingListCursor() {
int size = between(1, 20);
int depth = between(1, 20);
List<List<?>> values = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
values.add(Arrays.asList(randomArray(depth, s -> new Object[depth], () -> randomByte())));
}
return new PagingListCursor(values, depth, between(1, 20));
}
@Override
protected PagingListCursor mutateInstance(PagingListCursor instance) throws IOException {
return new PagingListCursor(instance.data(),
instance.columnCount(),
randomValueOtherThan(instance.pageSize(), () -> between(1, 20)));
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(Cursors.getNamedWriteables());
}
@Override
protected PagingListCursor createTestInstance() {
return randomPagingListCursor();
}
@Override
protected Reader<PagingListCursor> instanceReader() {
return PagingListCursor::new;
}
@Override
protected PagingListCursor copyInstance(PagingListCursor instance, Version version) throws IOException {
/* Randomly choose between internal protocol round trip and String based
* round trips used to toXContent. */
if (randomBoolean()) {
return super.copyInstance(instance, version);
}
return (PagingListCursor) Cursors.decodeFromString(Cursors.encodeToString(version, instance));
}
}

View File

@ -27,7 +27,8 @@ public class ScrollCursorTests extends AbstractWireSerializingTestCase<ScrollCur
for (int i = 0; i < extractorsSize; i++) {
extractors.add(randomHitExtractor(0));
}
return new ScrollCursor(randomAlphaOfLength(5), extractors, randomIntBetween(10, 1024));
return new ScrollCursor(randomAlphaOfLength(5), extractors, CompositeAggregationCursorTests.randomBitSet(extractorsSize),
randomIntBetween(10, 1024));
}
static HitExtractor randomHitExtractor(int depth) {
@ -43,6 +44,7 @@ public class ScrollCursorTests extends AbstractWireSerializingTestCase<ScrollCur
@Override
protected ScrollCursor mutateInstance(ScrollCursor instance) throws IOException {
return new ScrollCursor(instance.scrollId(), instance.extractors(),
randomValueOtherThan(instance.mask(), () -> CompositeAggregationCursorTests.randomBitSet(instance.extractors().size())),
randomValueOtherThan(instance.limit(), () -> randomIntBetween(1, 1024)));
}

View File

@ -85,7 +85,7 @@ public class SourceGeneratorTests extends ESTestCase {
public void testSortScoreSpecified() {
QueryContainer container = new QueryContainer()
.sort(new ScoreSort(Direction.DESC, null));
.addSort(new ScoreSort(Direction.DESC, null));
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(scoreSort()), sourceBuilder.sorts());
}
@ -94,13 +94,13 @@ public class SourceGeneratorTests extends ESTestCase {
FieldSortBuilder sortField = fieldSort("test").unmappedType("keyword");
QueryContainer container = new QueryContainer()
.sort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.ASC,
.addSort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.ASC,
Missing.LAST));
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(sortField.order(SortOrder.ASC).missing("_last")), sourceBuilder.sorts());
container = new QueryContainer()
.sort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.DESC,
.addSort(new AttributeSort(new FieldAttribute(Source.EMPTY, "test", new KeywordEsField("test")), Direction.DESC,
Missing.FIRST));
sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(sortField.order(SortOrder.DESC).missing("_first")), sourceBuilder.sorts());

View File

@ -306,7 +306,7 @@ public class SqlParserTests extends ESTestCase {
In in = (In) filter.condition();
assertEquals("?a", in.value().toString());
assertEquals(noChildren, in.list().size());
assertThat(in.list().get(0).toString(), startsWith("a + b#"));
assertThat(in.list().get(0).toString(), startsWith("Add[?a,?b]"));
}
public void testDecrementOfDepthCounter() {

View File

@ -53,7 +53,7 @@ public class SysParserTests extends ESTestCase {
return Void.TYPE;
}).when(resolver).resolveAsSeparateMappings(any(), any(), any());
SqlSession session = new SqlSession(TestUtils.TEST_CFG, null, null, resolver, null, null, null, null);
SqlSession session = new SqlSession(TestUtils.TEST_CFG, null, null, resolver, null, null, null, null, null);
return new Tuple<>(cmd, session);
}

View File

@ -243,7 +243,7 @@ public class SysTablesTests extends ESTestCase {
IndexResolver resolver = mock(IndexResolver.class);
when(resolver.clusterName()).thenReturn(CLUSTER_NAME);
SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null);
SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null, null);
return new Tuple<>(cmd, session);
}

View File

@ -36,7 +36,7 @@ public class SysTypesTests extends ESTestCase {
Command cmd = (Command) analyzer.analyze(parser.createStatement(sql), false);
IndexResolver resolver = mock(IndexResolver.class);
SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null);
SqlSession session = new SqlSession(null, null, null, resolver, null, null, null, null, null);
return new Tuple<>(cmd, session);
}

View File

@ -253,8 +253,8 @@ public class NodeSubclassTests<T extends B, B extends Node<B>> extends ESTestCas
* the one property of the node that we intended to transform.
*/
assertEquals(node.source(), transformed.source());
List<Object> op = node.properties();
List<Object> tp = transformed.properties();
List<Object> op = node.nodeProperties();
List<Object> tp = transformed.nodeProperties();
for (int p = 0; p < op.size(); p++) {
if (p == changedArgOffset - 1) { // -1 because location isn't in the list
assertEquals(changedArgValue, tp.get(p));