diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java index 2700459fce..4c3b0f148e 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java @@ -2507,7 +2507,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem break; } } - ( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( left, right ); + SqmCriteriaNodeBuilder.assertComparable( left, right ); return new SqmComparisonPredicate( left, comparisonOperator, @@ -2708,7 +2708,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem else if ( inListContext instanceof HqlParser.SubqueryInListContext ) { final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext; final SqmSubQuery subquery = visitSubquery(subQueryOrParamInListContext.subquery()); - ( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( testExpression, subquery ); + SqmCriteriaNodeBuilder.assertComparable( testExpression, subquery ); return new SqmInSubQueryPredicate( testExpression, subquery, @@ -2979,10 +2979,16 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem throw new ParsingException( "Expecting two operands to the additive operator" ); } + final SqmExpression left = (SqmExpression) ctx.expression(0).accept(this); + final SqmExpression right = (SqmExpression) ctx.expression(1).accept(this); + final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.additiveOperator().accept(this); + SqmCriteriaNodeBuilder.assertNumeric( left, operator ); + SqmCriteriaNodeBuilder.assertNumeric( right, operator ); + return new SqmBinaryArithmetic<>( - (BinaryArithmeticOperator) ctx.getChild( 1 ).accept( this ), - (SqmExpression) ctx.getChild( 0 ).accept( this ), - (SqmExpression) ctx.getChild( 2 ).accept( this ), + operator, + left, + right, creationContext.getJpaMetamodel(), creationContext.getNodeBuilder() ); @@ -2994,9 +3000,11 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem throw new ParsingException( "Expecting two operands to the multiplicative operator" ); } - final SqmExpression left = (SqmExpression) ctx.getChild( 0 ).accept( this ); - final SqmExpression right = (SqmExpression) ctx.getChild( 2 ).accept( this ); - final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.getChild( 1 ).accept( this ); + final SqmExpression left = (SqmExpression) ctx.expression(0).accept( this ); + final SqmExpression right = (SqmExpression) ctx.expression(1).accept( this ); + final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.multiplicativeOperator().accept( this ); + SqmCriteriaNodeBuilder.assertNumeric( left, operator ); + SqmCriteriaNodeBuilder.assertNumeric( right, operator ); if ( operator == BinaryArithmeticOperator.MODULO ) { return getFunctionDescriptor("mod").generateSqmExpression( @@ -3036,9 +3044,11 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem @Override public Object visitFromDurationExpression(HqlParser.FromDurationExpressionContext ctx) { + SqmExpression expression = (SqmExpression) ctx.expression().accept( this ); + return new SqmByUnit( - toDurationUnit( (SqmExtractUnit) ctx.getChild( 2 ).accept( this ) ), - (SqmExpression) ctx.getChild( 0 ).accept( this ), + toDurationUnit( (SqmExtractUnit) ctx.datetimeField().accept( this ) ), + expression, resolveExpressibleTypeBasic( Long.class ), creationContext.getNodeBuilder() ); @@ -3046,9 +3056,12 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem @Override public SqmUnaryOperation visitUnaryExpression(HqlParser.UnaryExpressionContext ctx) { + final SqmExpression expression = (SqmExpression) ctx.expression().accept(this); + final UnaryArithmeticOperator operator = (UnaryArithmeticOperator) ctx.signOperator().accept(this); + SqmCriteriaNodeBuilder.assertNumeric( expression, operator ); return new SqmUnaryOperation<>( - (UnaryArithmeticOperator) ctx.getChild( 0 ).accept( this ), - (SqmExpression) ctx.getChild( 1 ).accept( this ) + operator, + expression ); } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java index de2425c692..6b05c165df 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java @@ -17,7 +17,9 @@ import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; +import java.time.temporal.Temporal; import java.time.temporal.TemporalAccessor; +import java.time.temporal.TemporalAmount; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -242,7 +244,7 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext, // Allow comparing an embeddable against a tuple literal || lhsType instanceof EmbeddedSqmPathSource && rhsType instanceof TupleType || rhsType instanceof EmbeddedSqmPathSource && lhsType instanceof TupleType - // Since we don't know any better, we just allow any comparison with multi-valued parameters + // Since we don't know any better, we just allow any comparison with multivalued parameters || lhsType instanceof MultiValueParameterType || rhsType instanceof MultiValueParameterType) { return true; @@ -2029,7 +2031,7 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext, ); } - public void assertComparable(Expression x, Expression y) { + public static void assertComparable(Expression x, Expression y) { SqmExpression left = (SqmExpression) x; SqmExpression right = (SqmExpression) y; if ( left.getTupleLength() != null && right.getTupleLength() != null @@ -2041,14 +2043,52 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext, if ( !areTypesComparable( leftType, rightType ) ) { throw new SemanticException( String.format( - "Cannot compare left expression of type [%s] with right expression of type [%s]", - leftType, - rightType + "Cannot compare left expression of type '%s' with right expression of type '%s'", + leftType.getTypeName(), + rightType.getTypeName() ) ); } } + public static void assertNumeric(SqmExpression expression, BinaryArithmeticOperator op) { + final SqmExpressible nodeType = expression.getNodeType(); + if ( nodeType != null ) { + final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); + if ( !Number.class.isAssignableFrom( javaType ) + && !Temporal.class.isAssignableFrom( javaType ) + && !TemporalAmount.class.isAssignableFrom( javaType ) + && !java.util.Date.class.isAssignableFrom( javaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + nodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" ); + } + } + } + + public static void assertDuration(SqmExpression expression) { + final SqmExpressible nodeType = expression.getNodeType(); + if ( nodeType != null ) { + final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); + if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) { + throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration" + + " (it is not an instance of 'java.time.TemporalAmount')" ); + } + } + } + + public static void assertNumeric(SqmExpression expression, UnaryArithmeticOperator op) { + final SqmExpressible nodeType = expression.getNodeType(); + if ( nodeType != null ) { + final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); + if ( !Number.class.isAssignableFrom( javaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorChar() + + " is of type '" + nodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number')" ); + } + } + } + @Override public SqmPredicate equal(Expression x, Expression y) { assertComparable( x, y ); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/hql/SubqueryPredicateTypingTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlOperatorTypesafetyTest.java similarity index 63% rename from hibernate-core/src/test/java/org/hibernate/orm/test/hql/SubqueryPredicateTypingTest.java rename to hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlOperatorTypesafetyTest.java index 96e9071f1d..d5d4ab4313 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/hql/SubqueryPredicateTypingTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlOperatorTypesafetyTest.java @@ -11,9 +11,34 @@ import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.fail; @SessionFactory -@DomainModel(annotatedClasses = SubqueryPredicateTypingTest.Book.class) -public class SubqueryPredicateTypingTest { - @Test void test(SessionFactoryScope scope) { +@DomainModel(annotatedClasses = HqlOperatorTypesafetyTest.Book.class) +public class HqlOperatorTypesafetyTest { + @Test void testOperatorTyping(SessionFactoryScope scope) { + scope.inSession( s -> { + // these should succeed + s.createSelectionQuery("from Book where title = 'Hibernate'").getResultList(); + s.createSelectionQuery("from Book where title > ''").getResultList(); + s.createSelectionQuery("select edition + 1 from Book").getResultList(); + s.createSelectionQuery("select title || '!' from Book").getResultList(); + try { + s.createSelectionQuery("from Book where title = 1").getResultList(); + fail(); + } + catch (SemanticException se) {} + try { + s.createSelectionQuery("from Book where title > 1").getResultList(); + fail(); + } + catch (SemanticException se) {} + try { + s.createSelectionQuery("select title + 1 from Book").getResultList(); + fail(); + } + catch (SemanticException se) {} + }); + } + + @Test void testSubselectTyping(SessionFactoryScope scope) { scope.inSession( s -> { // these should succeed s.createSelectionQuery("from Book where title in (select title from Book)").getResultList(); @@ -59,5 +84,6 @@ public class SubqueryPredicateTypingTest { static class Book { @Id String isbn; String title; + int edition; } }