EQL: Disallow chained comparisons (#62567) (#62601)

Expressions like `1 = 2 = 3 = 4` or `1 < 2 = 3 >= 4` were treated with
leftmost priority: ((1 = 2) = 3) = 4 which can lead to confusing
results. Since such expressions don't make so much change for EQL
filters we disallow them in the parser to prevent unexpected results
from their bad usage.

Major DBs like PostgreSQL and Oracle also disallow them in their SQL
syntax. (counter example would be MySQL which interprets them as we did
before with leftmost priority).

Fixes: #61654
(cherry picked from commit 8f94981bb093f104228d267b532e0a3d5b7f6a38)
This commit is contained in:
Marios Trivyzas 2020-09-18 10:48:14 +02:00 committed by GitHub
parent 81f2f84177
commit b072de4ce0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 500 additions and 359 deletions

View File

@ -84,11 +84,15 @@ booleanExpression
valueExpression valueExpression
: primaryExpression predicate? #valueExpressionDefault : operatorExpression #valueExpressionDefault
| operator=(MINUS | PLUS) valueExpression #arithmeticUnary | left=operatorExpression comparisonOperator right=operatorExpression #comparison
| left=valueExpression operator=(ASTERISK | SLASH | PERCENT) right=valueExpression #arithmeticBinary ;
| left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary
| left=valueExpression comparisonOperator right=valueExpression #comparison operatorExpression
: primaryExpression predicate? #operatorExpressionDefault
| operator=(MINUS | PLUS) operatorExpression #arithmeticUnary
| left=operatorExpression operator=(ASTERISK | SLASH | PERCENT) right=operatorExpression #arithmeticBinary
| left=operatorExpression operator=(PLUS | MINUS) right=operatorExpression #arithmeticBinary
; ;
// workaround for // workaround for

View File

@ -263,6 +263,18 @@ class EqlBaseBaseListener implements EqlBaseListener {
* <p>The default implementation does nothing.</p> * <p>The default implementation does nothing.</p>
*/ */
@Override public void exitComparison(EqlBaseParser.ComparisonContext ctx) { } @Override public void exitComparison(EqlBaseParser.ComparisonContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterOperatorExpressionDefault(EqlBaseParser.OperatorExpressionDefaultContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitOperatorExpressionDefault(EqlBaseParser.OperatorExpressionDefaultContext ctx) { }
/** /**
* {@inheritDoc} * {@inheritDoc}
* *

View File

@ -158,6 +158,13 @@ class EqlBaseBaseVisitor<T> extends AbstractParseTreeVisitor<T> implements EqlBa
* {@link #visitChildren} on {@code ctx}.</p> * {@link #visitChildren} on {@code ctx}.</p>
*/ */
@Override public T visitComparison(EqlBaseParser.ComparisonContext ctx) { return visitChildren(ctx); } @Override public T visitComparison(EqlBaseParser.ComparisonContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
* <p>The default implementation returns the result of calling
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitOperatorExpressionDefault(EqlBaseParser.OperatorExpressionDefaultContext ctx) { return visitChildren(ctx); }
/** /**
* {@inheritDoc} * {@inheritDoc}
* *

View File

@ -229,27 +229,39 @@ interface EqlBaseListener extends ParseTreeListener {
* @param ctx the parse tree * @param ctx the parse tree
*/ */
void exitComparison(EqlBaseParser.ComparisonContext ctx); void exitComparison(EqlBaseParser.ComparisonContext ctx);
/**
* Enter a parse tree produced by the {@code operatorExpressionDefault}
* labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree
*/
void enterOperatorExpressionDefault(EqlBaseParser.OperatorExpressionDefaultContext ctx);
/**
* Exit a parse tree produced by the {@code operatorExpressionDefault}
* labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree
*/
void exitOperatorExpressionDefault(EqlBaseParser.OperatorExpressionDefaultContext ctx);
/** /**
* Enter a parse tree produced by the {@code arithmeticBinary} * Enter a parse tree produced by the {@code arithmeticBinary}
* labeled alternative in {@link EqlBaseParser#valueExpression}. * labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree * @param ctx the parse tree
*/ */
void enterArithmeticBinary(EqlBaseParser.ArithmeticBinaryContext ctx); void enterArithmeticBinary(EqlBaseParser.ArithmeticBinaryContext ctx);
/** /**
* Exit a parse tree produced by the {@code arithmeticBinary} * Exit a parse tree produced by the {@code arithmeticBinary}
* labeled alternative in {@link EqlBaseParser#valueExpression}. * labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree * @param ctx the parse tree
*/ */
void exitArithmeticBinary(EqlBaseParser.ArithmeticBinaryContext ctx); void exitArithmeticBinary(EqlBaseParser.ArithmeticBinaryContext ctx);
/** /**
* Enter a parse tree produced by the {@code arithmeticUnary} * Enter a parse tree produced by the {@code arithmeticUnary}
* labeled alternative in {@link EqlBaseParser#valueExpression}. * labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree * @param ctx the parse tree
*/ */
void enterArithmeticUnary(EqlBaseParser.ArithmeticUnaryContext ctx); void enterArithmeticUnary(EqlBaseParser.ArithmeticUnaryContext ctx);
/** /**
* Exit a parse tree produced by the {@code arithmeticUnary} * Exit a parse tree produced by the {@code arithmeticUnary}
* labeled alternative in {@link EqlBaseParser#valueExpression}. * labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree * @param ctx the parse tree
*/ */
void exitArithmeticUnary(EqlBaseParser.ArithmeticUnaryContext ctx); void exitArithmeticUnary(EqlBaseParser.ArithmeticUnaryContext ctx);

View File

@ -142,16 +142,23 @@ interface EqlBaseVisitor<T> extends ParseTreeVisitor<T> {
* @return the visitor result * @return the visitor result
*/ */
T visitComparison(EqlBaseParser.ComparisonContext ctx); T visitComparison(EqlBaseParser.ComparisonContext ctx);
/**
* Visit a parse tree produced by the {@code operatorExpressionDefault}
* labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree
* @return the visitor result
*/
T visitOperatorExpressionDefault(EqlBaseParser.OperatorExpressionDefaultContext ctx);
/** /**
* Visit a parse tree produced by the {@code arithmeticBinary} * Visit a parse tree produced by the {@code arithmeticBinary}
* labeled alternative in {@link EqlBaseParser#valueExpression}. * labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree * @param ctx the parse tree
* @return the visitor result * @return the visitor result
*/ */
T visitArithmeticBinary(EqlBaseParser.ArithmeticBinaryContext ctx); T visitArithmeticBinary(EqlBaseParser.ArithmeticBinaryContext ctx);
/** /**
* Visit a parse tree produced by the {@code arithmeticUnary} * Visit a parse tree produced by the {@code arithmeticUnary}
* labeled alternative in {@link EqlBaseParser#valueExpression}. * labeled alternative in {@link EqlBaseParser#operatorExpression}.
* @param ctx the parse tree * @param ctx the parse tree
* @return the visitor result * @return the visitor result
*/ */

View File

@ -17,7 +17,6 @@ import org.elasticsearch.xpack.eql.parser.EqlBaseParser.JoinKeysContext;
import org.elasticsearch.xpack.eql.parser.EqlBaseParser.LogicalBinaryContext; import org.elasticsearch.xpack.eql.parser.EqlBaseParser.LogicalBinaryContext;
import org.elasticsearch.xpack.eql.parser.EqlBaseParser.LogicalNotContext; import org.elasticsearch.xpack.eql.parser.EqlBaseParser.LogicalNotContext;
import org.elasticsearch.xpack.eql.parser.EqlBaseParser.PredicateContext; import org.elasticsearch.xpack.eql.parser.EqlBaseParser.PredicateContext;
import org.elasticsearch.xpack.eql.parser.EqlBaseParser.ValueExpressionDefaultContext;
import org.elasticsearch.xpack.ql.QlIllegalArgumentException; import org.elasticsearch.xpack.ql.QlIllegalArgumentException;
import org.elasticsearch.xpack.ql.expression.Attribute; import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.Expression;
@ -85,7 +84,7 @@ public class ExpressionBuilder extends IdentifierBuilder {
@Override @Override
public Expression visitArithmeticUnary(ArithmeticUnaryContext ctx) { public Expression visitArithmeticUnary(ArithmeticUnaryContext ctx) {
Expression expr = expression(ctx.valueExpression()); Expression expr = expression(ctx.operatorExpression());
Source source = source(ctx); Source source = source(ctx);
int type = ctx.operator.getType(); int type = ctx.operator.getType();
@ -149,7 +148,7 @@ public class ExpressionBuilder extends IdentifierBuilder {
} }
@Override @Override
public Expression visitValueExpressionDefault(ValueExpressionDefaultContext ctx) { public Object visitOperatorExpressionDefault(EqlBaseParser.OperatorExpressionDefaultContext ctx) {
Expression expr = expression(ctx.primaryExpression()); Expression expr = expression(ctx.primaryExpression());
Source source = source(ctx); Source source = source(ctx);

View File

@ -15,6 +15,7 @@ import org.elasticsearch.xpack.ql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.ql.expression.predicate.logical.And; import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan;
@ -96,7 +97,7 @@ public class ExpressionTests extends ESTestCase {
assertEquals(expected, parsed); assertEquals(expected, parsed);
} }
public void testSingleQuotedUnescapedStringForbidden() { public void testSingleQuotedUnescapedStringDisallowed() {
ParsingException e = expectThrows(ParsingException.class, () -> expr("?'hello world'")); ParsingException e = expectThrows(ParsingException.class, () -> expr("?'hello world'"));
assertEquals("line 1:2: Use double quotes [\"] to define string literals, not single quotes [']", assertEquals("line 1:2: Use double quotes [\"] to define string literals, not single quotes [']",
e.getMessage()); e.getMessage());
@ -221,4 +222,43 @@ public class ExpressionTests extends ESTestCase {
expectThrows(ParsingException.class, "Expected syntax error", expectThrows(ParsingException.class, "Expected syntax error",
() -> expr("name in ()")); () -> expr("name in ()"));
} }
public void testComplexComparison() {
String comparison;
if (randomBoolean()) {
comparison = "1 * -2 <= -3 * 4";
} else {
comparison = "(1 * -2) <= (-3 * 4)";
}
Mul left = new Mul(null,
new Literal(null, 1, DataTypes.INTEGER),
new Neg(null, new Literal(null, 2, DataTypes.INTEGER)));
Mul right = new Mul(null,
new Neg(null, new Literal(null, 3, DataTypes.INTEGER)),
new Literal(null, 4, DataTypes.INTEGER));
assertEquals(new LessThanOrEqual(null, left, right, UTC), expr(comparison));
}
public void testChainedComparisonsDisallowed() {
int noComparisions = randomIntBetween(2, 20);
String firstComparator = "";
String secondComparator = "";
StringBuilder sb = new StringBuilder("a ");
for (int i = 0 ; i < noComparisions; i++) {
String comparator = randomFrom("=", "==", "!=", "<", "<=", ">", ">=");
sb.append(comparator).append(" a ");
if (i == 0) {
firstComparator = comparator;
} else if (i == 1) {
secondComparator = comparator;
}
}
ParsingException e = expectThrows(ParsingException.class, () -> expr(sb.toString()));
assertEquals("line 1:" + (6 + firstComparator.length()) + ": mismatched input '" + secondComparator +
"' expecting {<EOF>, 'and', 'in', 'not', 'or', '+', '-', '*', '/', '%', '.', '['}",
e.getMessage());
}
} }