From bf297e0e877e8d558e5330f6d1c7975c4e1506d0 Mon Sep 17 00:00:00 2001 From: Gavin King Date: Sat, 19 Aug 2023 22:32:59 +0200 Subject: [PATCH] HHH-16891 typechecking for arithmetic expressions --- .../hql/internal/SemanticQueryBuilder.java | 17 +-- .../sqm/internal/SqmCriteriaNodeBuilder.java | 41 ------ .../query/sqm/internal/TypecheckUtil.java | 125 ++++++++++++++++++ 3 files changed, 131 insertions(+), 52 deletions(-) diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java index 7a6425bcf3..f452e66903 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java @@ -98,9 +98,9 @@ import org.hibernate.query.sqm.function.NamedSqmFunctionDescriptor; import org.hibernate.query.sqm.function.SqmFunctionDescriptor; import org.hibernate.query.sqm.internal.ParameterCollector; import org.hibernate.query.sqm.internal.SqmCreationProcessingStateImpl; -import org.hibernate.query.sqm.internal.SqmCriteriaNodeBuilder; import org.hibernate.query.sqm.internal.SqmDmlCreationProcessingState; import org.hibernate.query.sqm.internal.SqmQueryPartCreationProcessingStateStandardImpl; +import org.hibernate.query.sqm.internal.TypecheckUtil; import org.hibernate.query.sqm.produce.function.FunctionArgumentException; import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; import org.hibernate.query.sqm.spi.ParameterDeclarationContext; @@ -2946,8 +2946,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem final SqmExpression left = (SqmExpression) ctx.expression(0).accept(this); final SqmExpression right = (SqmExpression) ctx.expression(1).accept(this); final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.additiveOperator().accept(this); - SqmCriteriaNodeBuilder.assertNumeric( left, operator ); - SqmCriteriaNodeBuilder.assertNumeric( right, operator ); + TypecheckUtil.assertOperable( left, right, operator ); return new SqmBinaryArithmetic<>( operator, @@ -2967,8 +2966,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem final SqmExpression left = (SqmExpression) ctx.expression(0).accept( this ); final SqmExpression right = (SqmExpression) ctx.expression(1).accept( this ); final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.multiplicativeOperator().accept( this ); - SqmCriteriaNodeBuilder.assertNumeric( left, operator ); - SqmCriteriaNodeBuilder.assertNumeric( right, operator ); + TypecheckUtil.assertOperable( left, right, operator ); if ( operator == BinaryArithmeticOperator.MODULO ) { return getFunctionDescriptor("mod").generateSqmExpression( @@ -3009,7 +3007,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem @Override public Object visitFromDurationExpression(HqlParser.FromDurationExpressionContext ctx) { SqmExpression expression = (SqmExpression) ctx.expression().accept( this ); - + TypecheckUtil.assertDuration( expression ); return new SqmByUnit( toDurationUnit( (SqmExtractUnit) ctx.datetimeField().accept( this ) ), expression, @@ -3022,11 +3020,8 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem public SqmUnaryOperation visitUnaryExpression(HqlParser.UnaryExpressionContext ctx) { final SqmExpression expression = (SqmExpression) ctx.expression().accept(this); final UnaryArithmeticOperator operator = (UnaryArithmeticOperator) ctx.signOperator().accept(this); - SqmCriteriaNodeBuilder.assertNumeric( expression, operator ); - return new SqmUnaryOperation<>( - operator, - expression - ); + TypecheckUtil.assertNumeric( expression, operator ); + return new SqmUnaryOperation<>( operator, expression ); } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java index 668f28a291..833b2ea611 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java @@ -17,9 +17,7 @@ import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; -import java.time.temporal.Temporal; import java.time.temporal.TemporalAccessor; -import java.time.temporal.TemporalAmount; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -49,7 +47,6 @@ import org.hibernate.metamodel.spi.MappingMetamodelImplementor; import org.hibernate.query.BindableType; import org.hibernate.query.NullPrecedence; import org.hibernate.query.ReturnableType; -import org.hibernate.query.SemanticException; import org.hibernate.query.SortDirection; import org.hibernate.query.criteria.HibernateCriteriaBuilder; import org.hibernate.query.criteria.JpaCoalesce; @@ -1932,44 +1929,6 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext, ); } - public static void assertNumeric(SqmExpression expression, BinaryArithmeticOperator op) { - final SqmExpressible nodeType = expression.getNodeType(); - if ( nodeType != null ) { - final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); - if ( !Number.class.isAssignableFrom( javaType ) - && !Temporal.class.isAssignableFrom( javaType ) - && !TemporalAmount.class.isAssignableFrom( javaType ) - && !java.util.Date.class.isAssignableFrom( javaType ) ) { - throw new SemanticException( "Operand of " + op.getOperatorSqlText() - + " is of type '" + nodeType.getTypeName() + "' which is not a numeric type" - + " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" ); - } - } - } - - public static void assertDuration(SqmExpression expression) { - final SqmExpressible nodeType = expression.getNodeType(); - if ( nodeType != null ) { - final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); - if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) { - throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration" - + " (it is not an instance of 'java.time.TemporalAmount')" ); - } - } - } - - public static void assertNumeric(SqmExpression expression, UnaryArithmeticOperator op) { - final SqmExpressible nodeType = expression.getNodeType(); - if ( nodeType != null ) { - final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); - if ( !Number.class.isAssignableFrom( javaType ) ) { - throw new SemanticException( "Operand of " + op.getOperatorChar() - + " is of type '" + nodeType.getTypeName() + "' which is not a numeric type" - + " (it is not an instance of 'java.lang.Number')" ); - } - } - } - @Override public SqmPredicate equal(Expression x, Expression y) { assertComparable( x, y, getNodeBuilder().getSessionFactory() ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/TypecheckUtil.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/TypecheckUtil.java index 999b009fac..573efc4c7c 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/TypecheckUtil.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/TypecheckUtil.java @@ -16,8 +16,10 @@ import org.hibernate.metamodel.model.domain.internal.DiscriminatorSqmPathSource; import org.hibernate.metamodel.model.domain.internal.EmbeddedSqmPathSource; import org.hibernate.persister.entity.EntityPersister; import org.hibernate.query.SemanticException; +import org.hibernate.query.sqm.BinaryArithmeticOperator; import org.hibernate.query.sqm.SqmExpressible; import org.hibernate.query.sqm.SqmPathSource; +import org.hibernate.query.sqm.UnaryArithmeticOperator; import org.hibernate.query.sqm.tree.SqmTypedNode; import org.hibernate.query.sqm.tree.domain.SqmPath; import org.hibernate.query.sqm.tree.expression.SqmExpression; @@ -25,6 +27,9 @@ import org.hibernate.query.sqm.tree.expression.SqmLiteralNull; import org.hibernate.type.BasicType; import org.hibernate.type.descriptor.jdbc.JdbcType; +import java.time.temporal.Temporal; +import java.time.temporal.TemporalAmount; + import static org.hibernate.type.descriptor.java.JavaTypeHelper.isUnknown; /** @@ -360,4 +365,124 @@ public class TypecheckUtil { } } } + + public static void assertOperable(SqmExpression left, SqmExpression right, BinaryArithmeticOperator op) { + final SqmExpressible leftNodeType = left.getNodeType(); + final SqmExpressible rightNodeType = right.getNodeType(); + if ( leftNodeType != null && rightNodeType != null ) { + final Class leftJavaType = leftNodeType.getExpressibleJavaType().getJavaTypeClass(); + final Class rightJavaType = rightNodeType.getExpressibleJavaType().getJavaTypeClass(); + if ( Number.class.isAssignableFrom( leftJavaType ) ) { + // left operand is a number + switch (op) { + case MULTIPLY: + if ( !Number.class.isAssignableFrom( rightJavaType ) + // we can scale a duration by a number + && !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number' or 'java.time.TemporalAmount')" ); + } + break; + default: + if ( !Number.class.isAssignableFrom( rightJavaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number')" ); + } + break; + } + } + else if ( TemporalAmount.class.isAssignableFrom( leftJavaType ) ) { + // left operand is a duration + switch (op) { + case ADD: + case SUBTRACT: + // we can add/subtract durations + if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount" + + " (it is not an instance of 'java.time.TemporalAmount')" ); + } + break; + default: + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number')" ); + } + } + else if ( Temporal.class.isAssignableFrom( leftJavaType ) + || java.util.Date.class.isAssignableFrom( leftJavaType ) ) { + // left operand is a date, time, or datetime + switch (op) { + case ADD: + // we can add a duration to date, time, or datetime + if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount" + + " (it is not an instance of 'java.time.TemporalAmount')" ); + } + break; + case SUBTRACT: + // we can subtract dates, times, or datetimes + if ( !Temporal.class.isAssignableFrom( rightJavaType ) + && !java.util.Date.class.isAssignableFrom( rightJavaType ) + // we can subtract a duration from a date, time, or datetime + && !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount" + + " (it is not an instance of 'java.time.TemporalAmount')" ); + } + break; + default: + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number')" ); + } + } + else { + throw new SemanticException( "Operand of " + op.getOperatorSqlText() + + " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" ); + } + } + } + +// public static void assertNumeric(SqmExpression expression, BinaryArithmeticOperator op) { +// final SqmExpressible nodeType = expression.getNodeType(); +// if ( nodeType != null ) { +// final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); +// if ( !Number.class.isAssignableFrom( javaType ) +// && !Temporal.class.isAssignableFrom( javaType ) +// && !TemporalAmount.class.isAssignableFrom( javaType ) +// && !java.util.Date.class.isAssignableFrom( javaType ) ) { +// throw new SemanticException( "Operand of " + op.getOperatorSqlText() +// + " is of type '" + nodeType.getTypeName() + "' which is not a numeric type" +// + " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" ); +// } +// } +// } + + public static void assertDuration(SqmExpression expression) { + final SqmExpressible nodeType = expression.getNodeType(); + if ( nodeType != null ) { + final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); + if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) { + throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration" + + " (it is not an instance of 'java.time.TemporalAmount')" ); + } + } + } + + public static void assertNumeric(SqmExpression expression, UnaryArithmeticOperator op) { + final SqmExpressible nodeType = expression.getNodeType(); + if ( nodeType != null ) { + final Class javaType = nodeType.getExpressibleJavaType().getJavaTypeClass(); + if ( !Number.class.isAssignableFrom( javaType ) ) { + throw new SemanticException( "Operand of " + op.getOperatorChar() + + " is of type '" + nodeType.getTypeName() + "' which is not a numeric type" + + " (it is not an instance of 'java.lang.Number')" ); + } + } + } }