SQL: Prevent StackOverflowError when parsing large statements (#33902)
Implement circuit breaker logic in the parser that catches expressions which would make the parse tree deep enough to cause a StackOverflowError. Co-authored-by: Costin Leau <costin.leau@gmail.com>
This commit is contained in:
parent
cc70352b3f
commit
5840be6a6b
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.sql.parser;
|
package org.elasticsearch.xpack.sql.parser;
|
||||||
|
|
||||||
|
import com.carrotsearch.hppc.ObjectShortHashMap;
|
||||||
import org.antlr.v4.runtime.BaseErrorListener;
|
import org.antlr.v4.runtime.BaseErrorListener;
|
||||||
import org.antlr.v4.runtime.CharStream;
|
import org.antlr.v4.runtime.CharStream;
|
||||||
import org.antlr.v4.runtime.CommonToken;
|
import org.antlr.v4.runtime.CommonToken;
|
||||||
|
@ -22,8 +23,8 @@ import org.antlr.v4.runtime.atn.PredictionMode;
|
||||||
import org.antlr.v4.runtime.dfa.DFA;
|
import org.antlr.v4.runtime.dfa.DFA;
|
||||||
import org.antlr.v4.runtime.misc.Pair;
|
import org.antlr.v4.runtime.misc.Pair;
|
||||||
import org.antlr.v4.runtime.tree.TerminalNode;
|
import org.antlr.v4.runtime.tree.TerminalNode;
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.elasticsearch.common.logging.Loggers;
|
|
||||||
import org.elasticsearch.xpack.sql.expression.Expression;
|
import org.elasticsearch.xpack.sql.expression.Expression;
|
||||||
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
|
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
|
||||||
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
|
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
|
||||||
|
@ -41,7 +42,8 @@ import java.util.function.Function;
|
||||||
import static java.lang.String.format;
|
import static java.lang.String.format;
|
||||||
|
|
||||||
public class SqlParser {
|
public class SqlParser {
|
||||||
private static final Logger log = Loggers.getLogger(SqlParser.class);
|
|
||||||
|
private static final Logger log = LogManager.getLogger();
|
||||||
|
|
||||||
private final boolean DEBUG = false;
|
private final boolean DEBUG = false;
|
||||||
|
|
||||||
|
@ -83,7 +85,9 @@ public class SqlParser {
|
||||||
return invokeParser(expression, params, SqlBaseParser::singleExpression, AstBuilder::expression);
|
return invokeParser(expression, params, SqlBaseParser::singleExpression, AstBuilder::expression);
|
||||||
}
|
}
|
||||||
|
|
||||||
private <T> T invokeParser(String sql, List<SqlTypedParamValue> params, Function<SqlBaseParser, ParserRuleContext> parseFunction,
|
private <T> T invokeParser(String sql,
|
||||||
|
List<SqlTypedParamValue> params, Function<SqlBaseParser,
|
||||||
|
ParserRuleContext> parseFunction,
|
||||||
BiFunction<AstBuilder, ParserRuleContext, T> visitor) {
|
BiFunction<AstBuilder, ParserRuleContext, T> visitor) {
|
||||||
SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));
|
SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));
|
||||||
|
|
||||||
|
@ -96,6 +100,7 @@ public class SqlParser {
|
||||||
CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
|
CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
|
||||||
SqlBaseParser parser = new SqlBaseParser(tokenStream);
|
SqlBaseParser parser = new SqlBaseParser(tokenStream);
|
||||||
|
|
||||||
|
parser.addParseListener(new CircuitBreakerListener());
|
||||||
parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));
|
parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));
|
||||||
|
|
||||||
parser.removeErrorListeners();
|
parser.removeErrorListeners();
|
||||||
|
@ -125,7 +130,7 @@ public class SqlParser {
|
||||||
return visitor.apply(new AstBuilder(paramTokens), tree);
|
return visitor.apply(new AstBuilder(paramTokens), tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void debug(SqlBaseParser parser) {
|
private static void debug(SqlBaseParser parser) {
|
||||||
|
|
||||||
// when debugging, use the exact prediction mode (needed for diagnostics as well)
|
// when debugging, use the exact prediction mode (needed for diagnostics as well)
|
||||||
parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
|
parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
|
||||||
|
@ -154,7 +159,7 @@ public class SqlParser {
|
||||||
public void exitBackQuotedIdentifier(SqlBaseParser.BackQuotedIdentifierContext context) {
|
public void exitBackQuotedIdentifier(SqlBaseParser.BackQuotedIdentifierContext context) {
|
||||||
Token token = context.BACKQUOTED_IDENTIFIER().getSymbol();
|
Token token = context.BACKQUOTED_IDENTIFIER().getSymbol();
|
||||||
throw new ParsingException(
|
throw new ParsingException(
|
||||||
"backquoted indetifiers not supported; please use double quotes instead",
|
"backquoted identifiers not supported; please use double quotes instead",
|
||||||
null,
|
null,
|
||||||
token.getLine(),
|
token.getLine(),
|
||||||
token.getCharPositionInLine());
|
token.getCharPositionInLine());
|
||||||
|
@ -205,6 +210,35 @@ public class SqlParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to catch large expressions that can lead to stack overflows
|
||||||
|
*/
|
||||||
|
private class CircuitBreakerListener extends SqlBaseBaseListener {
|
||||||
|
|
||||||
|
private static final short MAX_RULE_DEPTH = 100;
|
||||||
|
|
||||||
|
// Keep current depth for every rule visited.
|
||||||
|
// The totalDepth alone cannot be used, because expressions like: e1 OR e2 OR e3 OR ...
|
||||||
|
// are processed as e1 OR (e2 OR (e3 OR (... so the totalDepth does not grow
|
||||||
|
// while the call stack depth does, which leads to a StackOverflowError.
|
||||||
|
private ObjectShortHashMap<String> depthCounts = new ObjectShortHashMap<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void enterEveryRule(ParserRuleContext ctx) {
|
||||||
|
short currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1);
|
||||||
|
if (currentDepth > MAX_RULE_DEPTH) {
|
||||||
|
throw new ParsingException("expression is too large to parse, (tree's depth exceeds {})", MAX_RULE_DEPTH);
|
||||||
|
}
|
||||||
|
super.enterEveryRule(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void exitEveryRule(ParserRuleContext ctx) {
|
||||||
|
depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 0, (short) -1);
|
||||||
|
super.exitEveryRule(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
|
private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
|
||||||
@Override
|
@Override
|
||||||
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,
|
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,
|
||||||
|
|
|
@ -70,7 +70,7 @@ public class QuotingTests extends ESTestCase {
|
||||||
String name = "@timestamp";
|
String name = "@timestamp";
|
||||||
ParsingException ex = expectThrows(ParsingException.class, () ->
|
ParsingException ex = expectThrows(ParsingException.class, () ->
|
||||||
new SqlParser().createExpression(quote + name + quote));
|
new SqlParser().createExpression(quote + name + quote));
|
||||||
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted indetifiers not supported; please use double quotes instead"));
|
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted identifiers not supported; please use double quotes instead"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testQuotedAttributeAndQualifier() {
|
public void testQuotedAttributeAndQualifier() {
|
||||||
|
@ -92,7 +92,7 @@ public class QuotingTests extends ESTestCase {
|
||||||
String name = "@timestamp";
|
String name = "@timestamp";
|
||||||
ParsingException ex = expectThrows(ParsingException.class, () ->
|
ParsingException ex = expectThrows(ParsingException.class, () ->
|
||||||
new SqlParser().createExpression(quote + qualifier + quote + "." + quote + name + quote));
|
new SqlParser().createExpression(quote + qualifier + quote + "." + quote + name + quote));
|
||||||
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted indetifiers not supported; please use double quotes instead"));
|
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted identifiers not supported; please use double quotes instead"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGreedyQuoting() {
|
public void testGreedyQuoting() {
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.sql.parser;
|
package org.elasticsearch.xpack.sql.parser;
|
||||||
|
|
||||||
|
import com.google.common.base.Joiner;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xpack.sql.expression.NamedExpression;
|
import org.elasticsearch.xpack.sql.expression.NamedExpression;
|
||||||
import org.elasticsearch.xpack.sql.expression.Order;
|
import org.elasticsearch.xpack.sql.expression.Order;
|
||||||
|
@ -22,6 +23,7 @@ import org.elasticsearch.xpack.sql.plan.logical.Project;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static java.util.Collections.nCopies;
|
||||||
import static java.util.stream.Collectors.toList;
|
import static java.util.stream.Collectors.toList;
|
||||||
import static org.hamcrest.Matchers.hasEntry;
|
import static org.hamcrest.Matchers.hasEntry;
|
||||||
import static org.hamcrest.Matchers.hasSize;
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
|
@ -136,6 +138,88 @@ public class SqlParserTests extends ESTestCase {
|
||||||
assertThat(mmqp.optionMap(), hasEntry("fuzzy_rewrite", "scoring_boolean"));
|
assertThat(mmqp.optionMap(), hasEntry("fuzzy_rewrite", "scoring_boolean"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testLimitToPreventStackOverflowFromLargeUnaryBooleanExpression() {
|
||||||
|
// Create expression in the form of NOT(NOT(NOT ... (b) ...)
|
||||||
|
|
||||||
|
// 40 elements is ok
|
||||||
|
new SqlParser().createExpression(
|
||||||
|
Joiner.on("").join(nCopies(40, "NOT(")).concat("b").concat(Joiner.on("").join(nCopies(40, ")"))));
|
||||||
|
|
||||||
|
// with 100 elements the parser's "circuit breaker" is triggered
|
||||||
|
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression(
|
||||||
|
Joiner.on("").join(nCopies(100, "NOT(")).concat("b").concat(Joiner.on("").join(nCopies(100, ")")))));
|
||||||
|
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testLimitToPreventStackOverflowFromLargeBinaryBooleanExpression() {
|
||||||
|
// Create expression in the form of a = b OR a = b OR ... a = b
|
||||||
|
|
||||||
|
// 50 elements is ok
|
||||||
|
new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(50, "a = b")));
|
||||||
|
|
||||||
|
// with 100 elements the parser's "circuit breaker" is triggered
|
||||||
|
ParsingException e = expectThrows(ParsingException.class, () ->
|
||||||
|
new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(100, "a = b"))));
|
||||||
|
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testLimitToPreventStackOverflowFromLargeUnaryArithmeticExpression() {
|
||||||
|
// Create expression in the form of abs(abs(abs ... (i) ...)
|
||||||
|
|
||||||
|
// 50 elements is ok
|
||||||
|
new SqlParser().createExpression(
|
||||||
|
Joiner.on("").join(nCopies(50, "abs(")).concat("i").concat(Joiner.on("").join(nCopies(50, ")"))));
|
||||||
|
|
||||||
|
// with 101 elements the parser's "circuit breaker" is triggered
|
||||||
|
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression(
|
||||||
|
Joiner.on("").join(nCopies(101, "abs(")).concat("i").concat(Joiner.on("").join(nCopies(101, ")")))));
|
||||||
|
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testLimitToPreventStackOverflowFromLargeBinaryArithmeticExpression() {
|
||||||
|
// Create expression in the form of a + a + a + ... + a
|
||||||
|
|
||||||
|
// 100 elements is ok
|
||||||
|
new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(100, "a")));
|
||||||
|
|
||||||
|
// with 101 elements the parser's "circuit breaker" is triggered
|
||||||
|
ParsingException e = expectThrows(ParsingException.class, () ->
|
||||||
|
new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(101, "a"))));
|
||||||
|
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testLimitToPreventStackOverflowFromLargeSubselectTree() {
|
||||||
|
// Test with queries in the form of `SELECT * FROM (SELECT * FROM (... t) ...)`
|
||||||
|
|
||||||
|
// 100 elements is ok
|
||||||
|
new SqlParser().createStatement(
|
||||||
|
Joiner.on(" (").join(nCopies(100, "SELECT * FROM"))
|
||||||
|
.concat("t")
|
||||||
|
.concat(Joiner.on("").join(nCopies(99, ")"))));
|
||||||
|
|
||||||
|
// with 101 elements the parser's "circuit breaker" is triggered
|
||||||
|
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createStatement(
|
||||||
|
Joiner.on(" (").join(nCopies(101, "SELECT * FROM"))
|
||||||
|
.concat("t")
|
||||||
|
.concat(Joiner.on("").join(nCopies(100, ")")))));
|
||||||
|
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testLimitToPreventStackOverflowFromLargeComplexSubselectTree() {
|
||||||
|
// Test with queries in the form of `SELECT true OR true OR ... FROM (SELECT true OR true OR ... FROM (... t) ...)`
|
||||||
|
|
||||||
|
new SqlParser().createStatement(
|
||||||
|
Joiner.on(" (").join(nCopies(20, "SELECT ")).
|
||||||
|
concat(Joiner.on(" OR ").join(nCopies(50, "true"))).concat(" FROM")
|
||||||
|
.concat("t").concat(Joiner.on("").join(nCopies(19, ")"))));
|
||||||
|
|
||||||
|
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createStatement(
|
||||||
|
Joiner.on(" (").join(nCopies(20, "SELECT ")).
|
||||||
|
concat(Joiner.on(" OR ").join(nCopies(100, "true"))).concat(" FROM")
|
||||||
|
.concat("t").concat(Joiner.on("").join(nCopies(19, ")")))));
|
||||||
|
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
|
||||||
|
}
|
||||||
|
|
||||||
private LogicalPlan parseStatement(String sql) {
|
private LogicalPlan parseStatement(String sql) {
|
||||||
return new SqlParser().createStatement(sql);
|
return new SqlParser().createStatement(sql);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue