SQL: Forbid multi field groups (elastic/x-pack-elasticsearch#3199)

* SQL: GROUP BY with multiple fields is forbidden

The check is performed in the planner's (folder) Verifier, after optimization, because the optimizer can eliminate some grouping fields (such as constants).

Original commit: elastic/x-pack-elasticsearch@8d49f4ab02
This commit is contained in:
Costin Leau 2017-12-05 18:41:19 +02:00 committed by GitHub
parent 4e49769efb
commit 88b8794801
9 changed files with 132 additions and 38 deletions

View File

@ -15,7 +15,7 @@ import java.sql.SQLException;
import java.util.List; import java.util.List;
@TestLogging(JdbcTestUtils.SQL_TRACE) @TestLogging(JdbcTestUtils.SQL_TRACE)
public abstract class DebugCsvSpec extends CsvSpecTestCase { public class DebugCsvSpec extends CsvSpecTestCase {
@ParametersFactory(shuffle = false, argumentFormatting = SqlSpecTestCase.PARAM_FORMATTING) @ParametersFactory(shuffle = false, argumentFormatting = SqlSpecTestCase.PARAM_FORMATTING)
public static List<Object[]> readScriptSpec() throws Exception { public static List<Object[]> readScriptSpec() throws Exception {
@ -27,6 +27,7 @@ public abstract class DebugCsvSpec extends CsvSpecTestCase {
super(fileName, groupName, testName, lineNumber, testCase); super(fileName, groupName, testName, lineNumber, testCase);
} }
@Override
protected void assertResults(ResultSet expected, ResultSet elastic) throws SQLException { protected void assertResults(ResultSet expected, ResultSet elastic) throws SQLException {
Logger log = logEsResultSet() ? logger : null; Logger log = logEsResultSet() ? logger : null;

View File

@ -12,7 +12,7 @@ import org.elasticsearch.test.junit.annotations.TestLogging;
import java.util.List; import java.util.List;
@TestLogging(JdbcTestUtils.SQL_TRACE) @TestLogging(JdbcTestUtils.SQL_TRACE)
public abstract class DebugSqlSpec extends SqlSpecTestCase { public class DebugSqlSpec extends SqlSpecTestCase {
@ParametersFactory(shuffle = false, argumentFormatting = PARAM_FORMATTING) @ParametersFactory(shuffle = false, argumentFormatting = PARAM_FORMATTING)
public static List<Object[]> readScriptSpec() throws Exception { public static List<Object[]> readScriptSpec() throws Exception {
Parser parser = specParser(); Parser parser = specParser();

View File

@ -3,7 +3,7 @@
// //
debug debug
SELECT DAY_OF_YEAR(birth_date) d, last_name l FROM "test_emp" WHERE emp_no < 10010 ORDER BY emp_no; SELECT int FROM test GROUP BY AVG(int) + 2;
table:s table:s
test_emp test_emp

View File

@ -3,4 +3,4 @@
// //
debug debug
SELECT MONTH(birth_date) AS d, COUNT(*) AS c, CAST(SUM(emp_no) AS INT) s FROM "test_emp" GROUP BY MONTH(birth_date) ORDER BY MONTH(birth_date) DESC; SELECT int FROM test GROUP BY AVG(int) + 2;

View File

@ -150,9 +150,8 @@ abstract class Verifier {
// Concrete verifications // Concrete verifications
// //
// if there are no (major) unresolved failures, do more in-depth analysis // if there are no (major) unresolved failures, do more in-depth analysis
//
if (failures.isEmpty()) { if (failures.isEmpty()) {
Map<String, Function> resolvedFunctions = new LinkedHashMap<>(); Map<String, Function> resolvedFunctions = new LinkedHashMap<>();
@ -183,14 +182,14 @@ abstract class Verifier {
if (!groupingFailures.contains(p)) { if (!groupingFailures.contains(p)) {
checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures); checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures);
} }
// everything checks out // everything checks out
// mark the plan as analyzed // mark the plan as analyzed
if (localFailures.isEmpty()) { if (localFailures.isEmpty()) {
p.setAnalyzed(); p.setAnalyzed();
} }
failures.addAll(localFailures); failures.addAll(localFailures);
}); });
} }
return failures; return failures;
@ -252,7 +251,7 @@ abstract class Verifier {
Expressions.names(a.groupings()))); Expressions.names(a.groupings())));
groupingFailures.add(a); groupingFailures.add(a);
return false; return false;
} }
} }
} }
return true; return true;
@ -286,7 +285,7 @@ abstract class Verifier {
a.aggregates().forEach(ne -> a.aggregates().forEach(ne ->
ne.collectFirstChildren(c -> checkGroupMatch(c, ne, a.groupings(), missing, functions))); ne.collectFirstChildren(c -> checkGroupMatch(c, ne, a.groupings(), missing, functions)));
if (!missing.isEmpty()) { if (!missing.isEmpty()) {
String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY; String plural = missing.size() > 1 ? "s" : StringUtils.EMPTY;
localFailures.add(fail(missing.values().iterator().next(), "Cannot use non-grouped column" + plural + " %s, expected %s", localFailures.add(fail(missing.values().iterator().next(), "Cannot use non-grouped column" + plural + " %s, expected %s",
Expressions.names(missing.keySet()), Expressions.names(missing.keySet()),
@ -307,9 +306,9 @@ abstract class Verifier {
// TODO: this should be handled by a different rule // TODO: this should be handled by a different rule
if (function == null) { if (function == null) {
return false; return false;
} }
e = function; e = function;
} }
// scalar functions can be a binary tree // scalar functions can be a binary tree
// first test the function against the grouping // first test the function against the grouping
@ -319,11 +318,11 @@ abstract class Verifier {
// found group for the expression // found group for the expression
if (Expressions.anyMatch(groupings, e::semanticEquals)) { if (Expressions.anyMatch(groupings, e::semanticEquals)) {
return true; return true;
} }
// unwrap function to find the base // unwrap function to find the base
for (Expression arg : sf.arguments()) { for (Expression arg : sf.arguments()) {
arg.collectFirstChildren(c -> checkGroupMatch(c, source, groupings, missing, functions)); arg.collectFirstChildren(c -> checkGroupMatch(c, source, groupings, missing, functions));
} }
return true; return true;
} }

View File

@ -17,7 +17,7 @@ import static java.util.stream.Collectors.toList;
public abstract class Expressions { public abstract class Expressions {
public static List<NamedExpression> asNamed(List<Expression> exp) { public static List<NamedExpression> asNamed(List<? extends Expression> exp) {
return exp.stream() return exp.stream()
.map(NamedExpression.class::cast) .map(NamedExpression.class::cast)
.collect(toList()); .collect(toList());
@ -72,25 +72,25 @@ public abstract class Expressions {
return e instanceof NamedExpression ? ((NamedExpression) e).name() : e.nodeName(); return e instanceof NamedExpression ? ((NamedExpression) e).name() : e.nodeName();
} }
public static List<String> names(Collection<Expression> e) { public static List<String> names(Collection<? extends Expression> e) {
List<String> names = new ArrayList<>(e.size()); List<String> names = new ArrayList<>(e.size());
for (Expression ex : e) { for (Expression ex : e) {
names.add(name(ex)); names.add(name(ex));
} }
return names; return names;
} }
public static Attribute attribute(Expression e) { public static Attribute attribute(Expression e) {
return e instanceof NamedExpression ? ((NamedExpression) e).toAttribute() : null; return e instanceof NamedExpression ? ((NamedExpression) e).toAttribute() : null;
} }
public static TypeResolution typeMustBe(Expression e, Predicate<Expression> predicate, String message) { public static TypeResolution typeMustBe(Expression e, Predicate<Expression> predicate, String message) {
return predicate.test(e) ? TypeResolution.TYPE_RESOLVED : new TypeResolution(message); return predicate.test(e) ? TypeResolution.TYPE_RESOLVED : new TypeResolution(message);
} }
public static TypeResolution typeMustBeNumeric(Expression e) { public static TypeResolution typeMustBeNumeric(Expression e) {
return e.dataType().isNumeric()? TypeResolution.TYPE_RESOLVED : new TypeResolution( return e.dataType().isNumeric()? TypeResolution.TYPE_RESOLVED : new TypeResolution(
"Argument required to be numeric ('%s' of type '%s')", Expressions.name(e), e.dataType().esName()); "Argument required to be numeric ('%s' of type '%s')", Expressions.name(e), e.dataType().esName());
} }
} }

View File

@ -7,9 +7,13 @@ package org.elasticsearch.xpack.sql.planner;
import org.elasticsearch.xpack.sql.ClientSqlException; import org.elasticsearch.xpack.sql.ClientSqlException;
import org.elasticsearch.xpack.sql.planner.Verifier.Failure; import org.elasticsearch.xpack.sql.planner.Verifier.Failure;
import org.elasticsearch.xpack.sql.util.StringUtils;
import java.util.Collection; import java.util.Collection;
import java.util.StringJoiner; import java.util.Locale;
import java.util.stream.Collectors;
import static java.lang.String.format;
public class PlanningException extends ClientSqlException { public class PlanningException extends ClientSqlException {
@ -21,11 +25,9 @@ public class PlanningException extends ClientSqlException {
super(extractMessage(sources)); super(extractMessage(sources));
} }
private static String extractMessage(Collection<Failure> sources) { private static String extractMessage(Collection<Failure> failures) {
StringJoiner sj = new StringJoiner(",", "{", "}"); return failures.stream()
sources.forEach(s -> { .map(f -> format(Locale.ROOT, "line %s:%s: %s", f.source().location().getLineNumber(), f.source().location().getColumnNumber(), f.message()))
sj.add(s.source().nodeString() + s.source().location()); .collect(Collectors.joining(StringUtils.NEW_LINE, "Found " + failures.size() + " problem(s)\n", StringUtils.EMPTY));
});
return "Fail to plan items " + sj.toString();
} }
} }

View File

@ -5,24 +5,26 @@
*/ */
package org.elasticsearch.xpack.sql.planner; package org.elasticsearch.xpack.sql.planner;
import java.util.ArrayList; import org.elasticsearch.xpack.sql.expression.Expressions;
import java.util.List; import org.elasticsearch.xpack.sql.plan.physical.AggregateExec;
import java.util.Objects;
import org.elasticsearch.xpack.sql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.sql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.sql.plan.physical.Unexecutable; import org.elasticsearch.xpack.sql.plan.physical.Unexecutable;
import org.elasticsearch.xpack.sql.plan.physical.UnplannedExec; import org.elasticsearch.xpack.sql.plan.physical.UnplannedExec;
import org.elasticsearch.xpack.sql.tree.Node; import org.elasticsearch.xpack.sql.tree.Node;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
abstract class Verifier { abstract class Verifier {
static class Failure { static class Failure {
private final Node<?> source; private final Node<?> source;
private final String message; private final String message;
Failure(Node<?> source, String message) { Failure(Node<?> source, String message) {
this.source = source; this.source = source;
this.message = message + " " + source.nodeString(); this.message = message;
} }
Node<?> source() { Node<?> source() {
@ -69,11 +71,23 @@ abstract class Verifier {
failures.add(fail(e, "Unresolved expression")); failures.add(fail(e, "Unresolved expression"));
} }
}); });
if (p instanceof AggregateExec) {
forbidMultiFieldGroupBy((AggregateExec) p, failures);
}
}); });
return failures; return failures;
} }
/**
 * Reports a failure when an aggregation groups by more than one expression.
 * <p>
 * The failure is attached to the first grouping and its message lists the names
 * of all groupings so the user can pick one of them.
 *
 * @param a        the aggregation whose groupings are inspected
 * @param failures the list the failure, if any, is appended to
 */
private static void forbidMultiFieldGroupBy(AggregateExec a, List<Failure> failures) {
    // zero or one grouping is supported - nothing to report
    if (a.groupings().size() <= 1) {
        return;
    }

    failures.add(fail(a.groupings().get(0),
            "Currently, only a single expression can be used with GROUP BY; please select one of "
                    + Expressions.names(a.groupings())));
}
static List<Failure> verifyExecutingPlan(PhysicalPlan plan) { static List<Failure> verifyExecutingPlan(PhysicalPlan plan) {
List<Failure> failures = new ArrayList<>(); List<Failure> failures = new ArrayList<>();

View File

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.planner;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer;
import org.elasticsearch.xpack.sql.analysis.catalog.Catalog;
import org.elasticsearch.xpack.sql.analysis.catalog.EsIndex;
import org.elasticsearch.xpack.sql.analysis.catalog.InMemoryCatalog;
import org.elasticsearch.xpack.sql.expression.function.DefaultFunctionRegistry;
import org.elasticsearch.xpack.sql.expression.function.FunctionRegistry;
import org.elasticsearch.xpack.sql.optimizer.Optimizer;
import org.elasticsearch.xpack.sql.parser.SqlParser;
import org.elasticsearch.xpack.sql.session.TestingSqlSession;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
import org.junit.After;
import org.junit.Before;
import java.util.LinkedHashMap;
import java.util.Map;
import static java.util.Collections.singletonList;
/**
 * Tests for the error messages produced when planning fails verification.
 * <p>
 * Each statement is parsed, analyzed, optimized and mapped against an in-memory
 * catalog holding a single {@code test} index; the resulting
 * {@link PlanningException} message is then inspected.
 */
public class VerifierErrorMessagesTests extends ESTestCase {

    private final SqlParser parser;
    private final FunctionRegistry functionRegistry;
    private final Catalog catalog;
    private final Analyzer analyzer;
    private final Optimizer optimizer;
    private final Planner planner;

    public VerifierErrorMessagesTests() {
        parser = new SqlParser();
        functionRegistry = new DefaultFunctionRegistry();

        catalog = new InMemoryCatalog(singletonList(new EsIndex("test", testMapping())));
        analyzer = new Analyzer(functionRegistry);
        optimizer = new Optimizer();
        planner = new Planner();
    }

    /**
     * Builds the field mapping of the {@code test} index, one field per data type used by the tests.
     */
    private static Map<String, DataType> testMapping() {
        Map<String, DataType> fields = new LinkedHashMap<>();
        fields.put("bool", DataTypes.BOOLEAN);
        fields.put("int", DataTypes.INTEGER);
        fields.put("text", DataTypes.TEXT);
        fields.put("keyword", DataTypes.KEYWORD);
        return fields;
    }

    @Before
    public void setupContext() {
        TestingSqlSession.setCurrentContext(TestingSqlSession.ctx(catalog));
    }

    @After
    public void disposeContext() {
        TestingSqlSession.removeCurrentContext();
    }

    /**
     * Runs the given statement through the pipeline, expecting planning to fail.
     *
     * @param sql the statement expected to be rejected
     * @return the exception message with a leading prefix of the same length as
     *         {@code "Found 1 problem(s)\nline "} removed
     */
    private String verify(String sql) {
        PlanningException failure = expectThrows(PlanningException.class, () -> {
            planner.mapPlan(optimizer.optimize(analyzer.analyze(parser.createStatement(sql), true)), true);
        });

        String message = failure.getMessage();
        assertTrue(message.startsWith("Found "));

        String prefix = "Found 1 problem(s)\nline ";
        return message.substring(prefix.length());
    }

    public void testMultiGroupBy() {
        // TODO: location needs to be updated after merging extend-having
        String expected = "1:32: Currently, only a single expression can be used with GROUP BY; please select one of [bool, keyword]";
        String actual = verify("SELECT bool FROM test GROUP BY bool, keyword");
        assertEquals(expected, actual);
    }
}