SQL: Prevent grouping over grouping functions (#38649)
Improve verifier to disallow grouping over grouping functions (e.g. HISTOGRAM over HISTOGRAM). Close #38308 (cherry picked from commit 4e9b1cfd4df38c652bba36b4b4b538ce7c714b6e)
This commit is contained in:
parent
871036bd21
commit
794ee4fb10
|
@ -593,20 +593,36 @@ public final class Verifier {
|
||||||
// check if the query has a grouping function (Histogram) but no GROUP BY
|
// check if the query has a grouping function (Histogram) but no GROUP BY
|
||||||
if (p instanceof Project) {
|
if (p instanceof Project) {
|
||||||
Project proj = (Project) p;
|
Project proj = (Project) p;
|
||||||
proj.projections().forEach(e -> e.forEachDown(f ->
|
proj.projections().forEach(e -> e.forEachDown(f ->
|
||||||
localFailures.add(fail(f, "[{}] needs to be part of the grouping", Expressions.name(f))), GroupingFunction.class));
|
localFailures.add(fail(f, "[{}] needs to be part of the grouping", Expressions.name(f))), GroupingFunction.class));
|
||||||
} else if (p instanceof Aggregate) {
|
} else if (p instanceof Aggregate) {
|
||||||
// if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms)
|
// if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms)
|
||||||
Aggregate a = (Aggregate) p;
|
Aggregate a = (Aggregate) p;
|
||||||
a.aggregates().forEach(agg -> agg.forEachDown(e -> {
|
a.aggregates().forEach(agg -> agg.forEachDown(e -> {
|
||||||
if (a.groupings().size() == 0
|
if (a.groupings().size() == 0
|
||||||
|| Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.functionEquals((Function) g)) == false) {
|
|| Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.functionEquals((Function) g)) == false) {
|
||||||
localFailures.add(fail(e, "[{}] needs to be part of the grouping", Expressions.name(e)));
|
localFailures.add(fail(e, "[{}] needs to be part of the grouping", Expressions.name(e)));
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
checkGroupingFunctionTarget(e, localFailures);
|
||||||
|
}
|
||||||
|
}, GroupingFunction.class));
|
||||||
|
|
||||||
|
a.groupings().forEach(g -> g.forEachDown(e -> {
|
||||||
|
checkGroupingFunctionTarget(e, localFailures);
|
||||||
}, GroupingFunction.class));
|
}, GroupingFunction.class));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void checkGroupingFunctionTarget(GroupingFunction f, Set<Failure> localFailures) {
|
||||||
|
f.field().forEachDown(e -> {
|
||||||
|
if (e instanceof GroupingFunction) {
|
||||||
|
localFailures.add(fail(f.field(), "Cannot embed grouping functions within each other, found [{}] in [{}]",
|
||||||
|
Expressions.name(f.field()), Expressions.name(f)));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures) {
|
private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures) {
|
||||||
if (p instanceof Filter) {
|
if (p instanceof Filter) {
|
||||||
Filter filter = (Filter) p;
|
Filter filter = (Filter) p;
|
||||||
|
|
|
@ -14,9 +14,6 @@ import org.elasticsearch.xpack.sql.type.DataType;
|
||||||
import org.elasticsearch.xpack.sql.util.StringUtils;
|
import org.elasticsearch.xpack.sql.util.StringUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Locale;
|
|
||||||
|
|
||||||
import static java.lang.String.format;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In a SQL statement, an Expression is whatever a user specifies inside an
|
* In a SQL statement, an Expression is whatever a user specifies inside an
|
||||||
|
@ -39,10 +36,6 @@ public abstract class Expression extends Node<Expression> implements Resolvable
|
||||||
this(true, message);
|
this(true, message);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeResolution(String message, Object... args) {
|
|
||||||
this(true, format(Locale.ROOT, message, args));
|
|
||||||
}
|
|
||||||
|
|
||||||
private TypeResolution(boolean unresolved, String message) {
|
private TypeResolution(boolean unresolved, String message) {
|
||||||
this.failed = unresolved;
|
this.failed = unresolved;
|
||||||
this.message = message;
|
this.message = message;
|
||||||
|
|
|
@ -18,9 +18,9 @@ import java.util.Locale;
|
||||||
import java.util.StringJoiner;
|
import java.util.StringJoiner;
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
import static java.lang.String.format;
|
|
||||||
import static java.util.Collections.emptyList;
|
import static java.util.Collections.emptyList;
|
||||||
import static java.util.Collections.emptyMap;
|
import static java.util.Collections.emptyMap;
|
||||||
|
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
|
||||||
import static org.elasticsearch.xpack.sql.type.DataType.BOOLEAN;
|
import static org.elasticsearch.xpack.sql.type.DataType.BOOLEAN;
|
||||||
|
|
||||||
public final class Expressions {
|
public final class Expressions {
|
||||||
|
@ -186,7 +186,7 @@ public final class Expressions {
|
||||||
String... acceptedTypes) {
|
String... acceptedTypes) {
|
||||||
return predicate.test(e.dataType()) || DataTypes.isNull(e.dataType())?
|
return predicate.test(e.dataType()) || DataTypes.isNull(e.dataType())?
|
||||||
TypeResolution.TYPE_RESOLVED :
|
TypeResolution.TYPE_RESOLVED :
|
||||||
new TypeResolution(format(Locale.ROOT, "[%s]%s argument must be [%s], found value [%s] type [%s]",
|
new TypeResolution(format(null, "[{}]{} argument must be [{}], found value [{}] type [{}]",
|
||||||
operationName,
|
operationName,
|
||||||
paramOrd == null || paramOrd == ParamOrdinal.DEFAULT ? "" : " " + paramOrd.name().toLowerCase(Locale.ROOT),
|
paramOrd == null || paramOrd == ParamOrdinal.DEFAULT ? "" : " " + paramOrd.name().toLowerCase(Locale.ROOT),
|
||||||
acceptedTypesForErrorMsg(acceptedTypes),
|
acceptedTypesForErrorMsg(acceptedTypes),
|
||||||
|
|
|
@ -566,10 +566,20 @@ public class VerifierErrorMessagesTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testAggsInHistogram() {
|
public void testAggsInHistogram() {
|
||||||
assertEquals("1:47: Cannot use an aggregate [MAX] for grouping",
|
assertEquals("1:37: Cannot use an aggregate [MAX] for grouping",
|
||||||
error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(MAX(int), 1)"));
|
error("SELECT MAX(date) FROM test GROUP BY MAX(int)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testGroupingsInHistogram() {
|
||||||
|
assertEquals(
|
||||||
|
"1:47: Cannot embed grouping functions within each other, found [HISTOGRAM(int, 1)] in [HISTOGRAM(HISTOGRAM(int, 1), 1)]",
|
||||||
|
error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(HISTOGRAM(int, 1), 1)"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCastInHistogram() {
|
||||||
|
accept("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(CAST(int AS LONG), 1)");
|
||||||
|
}
|
||||||
|
|
||||||
public void testHistogramNotInGrouping() {
|
public void testHistogramNotInGrouping() {
|
||||||
assertEquals("1:8: [HISTOGRAM(date, INTERVAL 1 MONTH)] needs to be part of the grouping",
|
assertEquals("1:8: [HISTOGRAM(date, INTERVAL 1 MONTH)] needs to be part of the grouping",
|
||||||
error("SELECT HISTOGRAM(date, INTERVAL 1 MONTH) AS h FROM test"));
|
error("SELECT HISTOGRAM(date, INTERVAL 1 MONTH) AS h FROM test"));
|
||||||
|
|
Loading…
Reference in New Issue