SQL: Prevent StackOverflowError when parsing large statements (#33902)

Implement circuit-breaker logic in the parser to catch expressions that can
blow up the parse tree and result in a StackOverflowError being thrown.

Co-authored-by: Costin Leau <costin.leau@gmail.com>
Marios Trivyzas 2018-09-25 19:20:25 +02:00 committed by GitHub
parent cc70352b3f
commit 5840be6a6b
3 changed files with 125 additions and 7 deletions
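The gist of the change, as a minimal sketch (hypothetical class name; a plain JDK HashMap and IllegalStateException stand in for the hppc ObjectShortHashMap and the ParsingException used by the actual listener in the diff below): an ANTLR parse listener counts the current nesting depth per rule and aborts as soon as any single rule nests deeper than a fixed threshold.

import java.util.HashMap;
import java.util.Map;

import org.antlr.v4.runtime.ParserRuleContext;

// Illustrative stand-in for the CircuitBreakerListener added in this commit.
class DepthLimitingListener extends SqlBaseBaseListener {

    private static final int MAX_RULE_DEPTH = 100; // same threshold as the commit

    // Depth is tracked per rule: as the comment in the real listener notes, a chain
    // like "e1 OR e2 OR e3 ..." is processed as e1 OR (e2 OR (e3 OR ...)), where an
    // overall depth count does not grow while the call-stack depth does.
    private final Map<String, Integer> depthCounts = new HashMap<>();

    @Override
    public void enterEveryRule(ParserRuleContext ctx) {
        int depth = depthCounts.merge(ctx.getClass().getSimpleName(), 1, Integer::sum);
        if (depth > MAX_RULE_DEPTH) {
            throw new IllegalStateException(
                "expression is too large to parse (depth exceeds " + MAX_RULE_DEPTH + ")");
        }
    }

    @Override
    public void exitEveryRule(ParserRuleContext ctx) {
        depthCounts.merge(ctx.getClass().getSimpleName(), -1, Integer::sum);
    }
}

Registering such a listener on the parser (addParseListener, see the first hunk below) means the check runs while rules are being entered, so a pathological query fails with a regular exception before the recursion can overflow the stack.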

SqlParser.java

@@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.sql.parser;
import com.carrotsearch.hppc.ObjectShortHashMap;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CommonToken;
@@ -22,8 +23,8 @@ import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.dfa.DFA;
import org.antlr.v4.runtime.misc.Pair;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
@@ -41,7 +42,8 @@ import java.util.function.Function;
import static java.lang.String.format;
public class SqlParser {
private static final Logger log = Loggers.getLogger(SqlParser.class);
private static final Logger log = LogManager.getLogger();
private final boolean DEBUG = false;
@@ -83,7 +85,9 @@ public class SqlParser {
return invokeParser(expression, params, SqlBaseParser::singleExpression, AstBuilder::expression);
}
private <T> T invokeParser(String sql, List<SqlTypedParamValue> params, Function<SqlBaseParser, ParserRuleContext> parseFunction,
private <T> T invokeParser(String sql,
List<SqlTypedParamValue> params, Function<SqlBaseParser,
ParserRuleContext> parseFunction,
BiFunction<AstBuilder, ParserRuleContext, T> visitor) {
SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));
@@ -96,6 +100,7 @@ public class SqlParser {
CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
SqlBaseParser parser = new SqlBaseParser(tokenStream);
parser.addParseListener(new CircuitBreakerListener());
parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));
parser.removeErrorListeners();
@@ -125,7 +130,7 @@ public class SqlParser {
return visitor.apply(new AstBuilder(paramTokens), tree);
}
private void debug(SqlBaseParser parser) {
private static void debug(SqlBaseParser parser) {
// when debugging, use the exact prediction mode (needed for diagnostics as well)
parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
@@ -154,7 +159,7 @@
public void exitBackQuotedIdentifier(SqlBaseParser.BackQuotedIdentifierContext context) {
Token token = context.BACKQUOTED_IDENTIFIER().getSymbol();
throw new ParsingException(
"backquoted indetifiers not supported; please use double quotes instead",
"backquoted identifiers not supported; please use double quotes instead",
null,
token.getLine(),
token.getCharPositionInLine());
@@ -205,6 +210,35 @@
}
}
/**
* Used to catch large expressions that can lead to stack overflows
*/
private class CircuitBreakerListener extends SqlBaseBaseListener {
private static final short MAX_RULE_DEPTH = 100;
// Keep current depth for every rule visited.
// The totalDepth alone cannot be used as expressions like: e1 OR e2 OR e3 OR ...
// are processed as e1 OR (e2 OR (e3 OR (... and this results in the totalDepth not growing
// while the stack call depth is, leading to a StackOverflowError.
private ObjectShortHashMap<String> depthCounts = new ObjectShortHashMap<>();
@Override
public void enterEveryRule(ParserRuleContext ctx) {
short currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1);
if (currentDepth > MAX_RULE_DEPTH) {
throw new ParsingException("expression is too large to parse, (tree's depth exceeds {})", MAX_RULE_DEPTH);
}
super.enterEveryRule(ctx);
}
@Override
public void exitEveryRule(ParserRuleContext ctx) {
depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 0, (short) -1);
super.exitEveryRule(ctx);
}
}
private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
@Override
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,

QuotingTests.java

@@ -70,7 +70,7 @@ public class QuotingTests extends ESTestCase {
String name = "@timestamp";
ParsingException ex = expectThrows(ParsingException.class, () ->
new SqlParser().createExpression(quote + name + quote));
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted indetifiers not supported; please use double quotes instead"));
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted identifiers not supported; please use double quotes instead"));
}
public void testQuotedAttributeAndQualifier() {
@@ -92,7 +92,7 @@ public class QuotingTests extends ESTestCase {
String name = "@timestamp";
ParsingException ex = expectThrows(ParsingException.class, () ->
new SqlParser().createExpression(quote + qualifier + quote + "." + quote + name + quote));
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted indetifiers not supported; please use double quotes instead"));
assertThat(ex.getMessage(), equalTo("line 1:1: backquoted identifiers not supported; please use double quotes instead"));
}
public void testGreedyQuoting() {

SqlParserTests.java

@@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.sql.parser;
import com.google.common.base.Joiner;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.Order;
@@ -22,6 +23,7 @@ import org.elasticsearch.xpack.sql.plan.logical.Project;
import java.util.ArrayList;
import java.util.List;
import static java.util.Collections.nCopies;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
@@ -136,6 +138,88 @@ public class SqlParserTests extends ESTestCase {
assertThat(mmqp.optionMap(), hasEntry("fuzzy_rewrite", "scoring_boolean"));
}
public void testLimitToPreventStackOverflowFromLargeUnaryBooleanExpression() {
// Create expression in the form of NOT(NOT(NOT ... (b) ...)
// 40 elements is ok
new SqlParser().createExpression(
Joiner.on("").join(nCopies(40, "NOT(")).concat("b").concat(Joiner.on("").join(nCopies(40, ")"))));
// 100 elements parser's "circuit breaker" is triggered
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression(
Joiner.on("").join(nCopies(100, "NOT(")).concat("b").concat(Joiner.on("").join(nCopies(100, ")")))));
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
}
public void testLimitToPreventStackOverflowFromLargeBinaryBooleanExpression() {
// Create expression in the form of a = b OR a = b OR ... a = b
// 50 elements is ok
new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(50, "a = b")));
// 100 elements parser's "circuit breaker" is triggered
ParsingException e = expectThrows(ParsingException.class, () ->
new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(100, "a = b"))));
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
}
public void testLimitToPreventStackOverflowFromLargeUnaryArithmeticExpression() {
// Create expression in the form of abs(abs(abs ... (i) ...)
// 50 elements is ok
new SqlParser().createExpression(
Joiner.on("").join(nCopies(50, "abs(")).concat("i").concat(Joiner.on("").join(nCopies(50, ")"))));
// 101 elements parser's "circuit breaker" is triggered
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression(
Joiner.on("").join(nCopies(101, "abs(")).concat("i").concat(Joiner.on("").join(nCopies(101, ")")))));
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
}
public void testLimitToPreventStackOverflowFromLargeBinaryArithmeticExpression() {
// Create expression in the form of a + a + a + ... + a
// 100 elements is ok
new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(100, "a")));
// 101 elements parser's "circuit breaker" is triggered
ParsingException e = expectThrows(ParsingException.class, () ->
new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(101, "a"))));
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
}
public void testLimitToPreventStackOverflowFromLargeSubselectTree() {
// Test with queries in the form of `SELECT * FROM (SELECT * FROM (... t) ...)
// 100 elements is ok
new SqlParser().createStatement(
Joiner.on(" (").join(nCopies(100, "SELECT * FROM"))
.concat("t")
.concat(Joiner.on("").join(nCopies(99, ")"))));
// 101 elements parser's "circuit breaker" is triggered
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createStatement(
Joiner.on(" (").join(nCopies(101, "SELECT * FROM"))
.concat("t")
.concat(Joiner.on("").join(nCopies(100, ")")))));
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
}
public void testLimitToPreventStackOverflowFromLargeComplexSubselectTree() {
// Test with queries in the form of `SELECT true OR true OR .. FROM (SELECT true OR true OR... FROM (... t) ...)
new SqlParser().createStatement(
Joiner.on(" (").join(nCopies(20, "SELECT ")).
concat(Joiner.on(" OR ").join(nCopies(50, "true"))).concat(" FROM")
.concat("t").concat(Joiner.on("").join(nCopies(19, ")"))));
ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createStatement(
Joiner.on(" (").join(nCopies(20, "SELECT ")).
concat(Joiner.on(" OR ").join(nCopies(100, "true"))).concat(" FROM")
.concat("t").concat(Joiner.on("").join(nCopies(19, ")")))));
assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
}
private LogicalPlan parseStatement(String sql) {
return new SqlParser().createStatement(sql);
}