HHH-16805 typecheck arguments of HQL arithmetic operators (#6804)

This commit is contained in:
Gavin King 2023-06-15 17:26:59 +02:00 committed by GitHub
parent aff3c105b6
commit 126207bbfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 20 deletions

View File

@ -2507,7 +2507,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
break; break;
} }
} }
( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( left, right ); SqmCriteriaNodeBuilder.assertComparable( left, right );
return new SqmComparisonPredicate( return new SqmComparisonPredicate(
left, left,
comparisonOperator, comparisonOperator,
@ -2708,7 +2708,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
else if ( inListContext instanceof HqlParser.SubqueryInListContext ) { else if ( inListContext instanceof HqlParser.SubqueryInListContext ) {
final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext; final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext;
final SqmSubQuery<?> subquery = visitSubquery(subQueryOrParamInListContext.subquery()); final SqmSubQuery<?> subquery = visitSubquery(subQueryOrParamInListContext.subquery());
( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( testExpression, subquery ); SqmCriteriaNodeBuilder.assertComparable( testExpression, subquery );
return new SqmInSubQueryPredicate( return new SqmInSubQueryPredicate(
testExpression, testExpression,
subquery, subquery,
@ -2979,10 +2979,16 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
throw new ParsingException( "Expecting two operands to the additive operator" ); throw new ParsingException( "Expecting two operands to the additive operator" );
} }
final SqmExpression<?> left = (SqmExpression<?>) ctx.expression(0).accept(this);
final SqmExpression<?> right = (SqmExpression<?>) ctx.expression(1).accept(this);
final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.additiveOperator().accept(this);
SqmCriteriaNodeBuilder.assertNumeric( left, operator );
SqmCriteriaNodeBuilder.assertNumeric( right, operator );
return new SqmBinaryArithmetic<>( return new SqmBinaryArithmetic<>(
(BinaryArithmeticOperator) ctx.getChild( 1 ).accept( this ), operator,
(SqmExpression<?>) ctx.getChild( 0 ).accept( this ), left,
(SqmExpression<?>) ctx.getChild( 2 ).accept( this ), right,
creationContext.getJpaMetamodel(), creationContext.getJpaMetamodel(),
creationContext.getNodeBuilder() creationContext.getNodeBuilder()
); );
@ -2994,9 +3000,11 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
throw new ParsingException( "Expecting two operands to the multiplicative operator" ); throw new ParsingException( "Expecting two operands to the multiplicative operator" );
} }
final SqmExpression<?> left = (SqmExpression<?>) ctx.getChild( 0 ).accept( this ); final SqmExpression<?> left = (SqmExpression<?>) ctx.expression(0).accept( this );
final SqmExpression<?> right = (SqmExpression<?>) ctx.getChild( 2 ).accept( this ); final SqmExpression<?> right = (SqmExpression<?>) ctx.expression(1).accept( this );
final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.getChild( 1 ).accept( this ); final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.multiplicativeOperator().accept( this );
SqmCriteriaNodeBuilder.assertNumeric( left, operator );
SqmCriteriaNodeBuilder.assertNumeric( right, operator );
if ( operator == BinaryArithmeticOperator.MODULO ) { if ( operator == BinaryArithmeticOperator.MODULO ) {
return getFunctionDescriptor("mod").generateSqmExpression( return getFunctionDescriptor("mod").generateSqmExpression(
@ -3036,9 +3044,11 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override @Override
public Object visitFromDurationExpression(HqlParser.FromDurationExpressionContext ctx) { public Object visitFromDurationExpression(HqlParser.FromDurationExpressionContext ctx) {
SqmExpression<?> expression = (SqmExpression<?>) ctx.expression().accept( this );
return new SqmByUnit( return new SqmByUnit(
toDurationUnit( (SqmExtractUnit<?>) ctx.getChild( 2 ).accept( this ) ), toDurationUnit( (SqmExtractUnit<?>) ctx.datetimeField().accept( this ) ),
(SqmExpression<?>) ctx.getChild( 0 ).accept( this ), expression,
resolveExpressibleTypeBasic( Long.class ), resolveExpressibleTypeBasic( Long.class ),
creationContext.getNodeBuilder() creationContext.getNodeBuilder()
); );
@ -3046,9 +3056,12 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override @Override
public SqmUnaryOperation<?> visitUnaryExpression(HqlParser.UnaryExpressionContext ctx) { public SqmUnaryOperation<?> visitUnaryExpression(HqlParser.UnaryExpressionContext ctx) {
final SqmExpression<?> expression = (SqmExpression<?>) ctx.expression().accept(this);
final UnaryArithmeticOperator operator = (UnaryArithmeticOperator) ctx.signOperator().accept(this);
SqmCriteriaNodeBuilder.assertNumeric( expression, operator );
return new SqmUnaryOperation<>( return new SqmUnaryOperation<>(
(UnaryArithmeticOperator) ctx.getChild( 0 ).accept( this ), operator,
(SqmExpression<?>) ctx.getChild( 1 ).accept( this ) expression
); );
} }

View File

@ -17,7 +17,9 @@ import java.time.Instant;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.LocalTime; import java.time.LocalTime;
import java.time.temporal.Temporal;
import java.time.temporal.TemporalAccessor; import java.time.temporal.TemporalAccessor;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@ -242,7 +244,7 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext,
// Allow comparing an embeddable against a tuple literal // Allow comparing an embeddable against a tuple literal
|| lhsType instanceof EmbeddedSqmPathSource<?> && rhsType instanceof TupleType || lhsType instanceof EmbeddedSqmPathSource<?> && rhsType instanceof TupleType
|| rhsType instanceof EmbeddedSqmPathSource<?> && lhsType instanceof TupleType || rhsType instanceof EmbeddedSqmPathSource<?> && lhsType instanceof TupleType
// Since we don't know any better, we just allow any comparison with multi-valued parameters // Since we don't know any better, we just allow any comparison with multivalued parameters
|| lhsType instanceof MultiValueParameterType<?> || lhsType instanceof MultiValueParameterType<?>
|| rhsType instanceof MultiValueParameterType<?>) { || rhsType instanceof MultiValueParameterType<?>) {
return true; return true;
@ -2029,7 +2031,7 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext,
); );
} }
public void assertComparable(Expression<?> x, Expression<?> y) { public static void assertComparable(Expression<?> x, Expression<?> y) {
SqmExpression<?> left = (SqmExpression<?>) x; SqmExpression<?> left = (SqmExpression<?>) x;
SqmExpression<?> right = (SqmExpression<?>) y; SqmExpression<?> right = (SqmExpression<?>) y;
if ( left.getTupleLength() != null && right.getTupleLength() != null if ( left.getTupleLength() != null && right.getTupleLength() != null
@ -2041,14 +2043,52 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext,
if ( !areTypesComparable( leftType, rightType ) ) { if ( !areTypesComparable( leftType, rightType ) ) {
throw new SemanticException( throw new SemanticException(
String.format( String.format(
"Cannot compare left expression of type [%s] with right expression of type [%s]", "Cannot compare left expression of type '%s' with right expression of type '%s'",
leftType, leftType.getTypeName(),
rightType rightType.getTypeName()
) )
); );
} }
} }
public static void assertNumeric(SqmExpression<?> expression, BinaryArithmeticOperator op) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !Number.class.isAssignableFrom( javaType )
&& !Temporal.class.isAssignableFrom( javaType )
&& !TemporalAmount.class.isAssignableFrom( javaType )
&& !java.util.Date.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" );
}
}
}
public static void assertDuration(SqmExpression<?> expression) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
}
}
}
public static void assertNumeric(SqmExpression<?> expression, UnaryArithmeticOperator op) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !Number.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorChar()
+ " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
}
}
}
@Override @Override
public SqmPredicate equal(Expression<?> x, Expression<?> y) { public SqmPredicate equal(Expression<?> x, Expression<?> y) {
assertComparable( x, y ); assertComparable( x, y );

View File

@ -11,9 +11,34 @@ import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@SessionFactory @SessionFactory
@DomainModel(annotatedClasses = SubqueryPredicateTypingTest.Book.class) @DomainModel(annotatedClasses = HqlOperatorTypesafetyTest.Book.class)
public class SubqueryPredicateTypingTest { public class HqlOperatorTypesafetyTest {
@Test void test(SessionFactoryScope scope) { @Test void testOperatorTyping(SessionFactoryScope scope) {
scope.inSession( s -> {
// these should succeed
s.createSelectionQuery("from Book where title = 'Hibernate'").getResultList();
s.createSelectionQuery("from Book where title > ''").getResultList();
s.createSelectionQuery("select edition + 1 from Book").getResultList();
s.createSelectionQuery("select title || '!' from Book").getResultList();
try {
s.createSelectionQuery("from Book where title = 1").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where title > 1").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("select title + 1 from Book").getResultList();
fail();
}
catch (SemanticException se) {}
});
}
@Test void testSubselectTyping(SessionFactoryScope scope) {
scope.inSession( s -> { scope.inSession( s -> {
// these should succeed // these should succeed
s.createSelectionQuery("from Book where title in (select title from Book)").getResultList(); s.createSelectionQuery("from Book where title in (select title from Book)").getResultList();
@ -59,5 +84,6 @@ public class SubqueryPredicateTypingTest {
static class Book { static class Book {
@Id String isbn; @Id String isbn;
String title; String title;
int edition;
} }
} }