HHH-16891 typechecking for arithmetic expressions

This commit is contained in:
Gavin King 2023-08-19 22:32:59 +02:00
parent 392d539c8c
commit bf297e0e87
3 changed files with 131 additions and 52 deletions

View File

@ -98,9 +98,9 @@ import org.hibernate.query.sqm.function.NamedSqmFunctionDescriptor;
import org.hibernate.query.sqm.function.SqmFunctionDescriptor; import org.hibernate.query.sqm.function.SqmFunctionDescriptor;
import org.hibernate.query.sqm.internal.ParameterCollector; import org.hibernate.query.sqm.internal.ParameterCollector;
import org.hibernate.query.sqm.internal.SqmCreationProcessingStateImpl; import org.hibernate.query.sqm.internal.SqmCreationProcessingStateImpl;
import org.hibernate.query.sqm.internal.SqmCriteriaNodeBuilder;
import org.hibernate.query.sqm.internal.SqmDmlCreationProcessingState; import org.hibernate.query.sqm.internal.SqmDmlCreationProcessingState;
import org.hibernate.query.sqm.internal.SqmQueryPartCreationProcessingStateStandardImpl; import org.hibernate.query.sqm.internal.SqmQueryPartCreationProcessingStateStandardImpl;
import org.hibernate.query.sqm.internal.TypecheckUtil;
import org.hibernate.query.sqm.produce.function.FunctionArgumentException; import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.spi.ParameterDeclarationContext; import org.hibernate.query.sqm.spi.ParameterDeclarationContext;
@ -2946,8 +2946,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
final SqmExpression<?> left = (SqmExpression<?>) ctx.expression(0).accept(this); final SqmExpression<?> left = (SqmExpression<?>) ctx.expression(0).accept(this);
final SqmExpression<?> right = (SqmExpression<?>) ctx.expression(1).accept(this); final SqmExpression<?> right = (SqmExpression<?>) ctx.expression(1).accept(this);
final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.additiveOperator().accept(this); final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.additiveOperator().accept(this);
SqmCriteriaNodeBuilder.assertNumeric( left, operator ); TypecheckUtil.assertOperable( left, right, operator );
SqmCriteriaNodeBuilder.assertNumeric( right, operator );
return new SqmBinaryArithmetic<>( return new SqmBinaryArithmetic<>(
operator, operator,
@ -2967,8 +2966,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
final SqmExpression<?> left = (SqmExpression<?>) ctx.expression(0).accept( this ); final SqmExpression<?> left = (SqmExpression<?>) ctx.expression(0).accept( this );
final SqmExpression<?> right = (SqmExpression<?>) ctx.expression(1).accept( this ); final SqmExpression<?> right = (SqmExpression<?>) ctx.expression(1).accept( this );
final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.multiplicativeOperator().accept( this ); final BinaryArithmeticOperator operator = (BinaryArithmeticOperator) ctx.multiplicativeOperator().accept( this );
SqmCriteriaNodeBuilder.assertNumeric( left, operator ); TypecheckUtil.assertOperable( left, right, operator );
SqmCriteriaNodeBuilder.assertNumeric( right, operator );
if ( operator == BinaryArithmeticOperator.MODULO ) { if ( operator == BinaryArithmeticOperator.MODULO ) {
return getFunctionDescriptor("mod").generateSqmExpression( return getFunctionDescriptor("mod").generateSqmExpression(
@ -3009,7 +3007,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override @Override
public Object visitFromDurationExpression(HqlParser.FromDurationExpressionContext ctx) { public Object visitFromDurationExpression(HqlParser.FromDurationExpressionContext ctx) {
SqmExpression<?> expression = (SqmExpression<?>) ctx.expression().accept( this ); SqmExpression<?> expression = (SqmExpression<?>) ctx.expression().accept( this );
TypecheckUtil.assertDuration( expression );
return new SqmByUnit( return new SqmByUnit(
toDurationUnit( (SqmExtractUnit<?>) ctx.datetimeField().accept( this ) ), toDurationUnit( (SqmExtractUnit<?>) ctx.datetimeField().accept( this ) ),
expression, expression,
@ -3022,11 +3020,8 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
public SqmUnaryOperation<?> visitUnaryExpression(HqlParser.UnaryExpressionContext ctx) { public SqmUnaryOperation<?> visitUnaryExpression(HqlParser.UnaryExpressionContext ctx) {
final SqmExpression<?> expression = (SqmExpression<?>) ctx.expression().accept(this); final SqmExpression<?> expression = (SqmExpression<?>) ctx.expression().accept(this);
final UnaryArithmeticOperator operator = (UnaryArithmeticOperator) ctx.signOperator().accept(this); final UnaryArithmeticOperator operator = (UnaryArithmeticOperator) ctx.signOperator().accept(this);
SqmCriteriaNodeBuilder.assertNumeric( expression, operator ); TypecheckUtil.assertNumeric( expression, operator );
return new SqmUnaryOperation<>( return new SqmUnaryOperation<>( operator, expression );
operator,
expression
);
} }
@Override @Override

View File

@ -17,9 +17,7 @@ import java.time.Instant;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.LocalTime; import java.time.LocalTime;
import java.time.temporal.Temporal;
import java.time.temporal.TemporalAccessor; import java.time.temporal.TemporalAccessor;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@ -49,7 +47,6 @@ import org.hibernate.metamodel.spi.MappingMetamodelImplementor;
import org.hibernate.query.BindableType; import org.hibernate.query.BindableType;
import org.hibernate.query.NullPrecedence; import org.hibernate.query.NullPrecedence;
import org.hibernate.query.ReturnableType; import org.hibernate.query.ReturnableType;
import org.hibernate.query.SemanticException;
import org.hibernate.query.SortDirection; import org.hibernate.query.SortDirection;
import org.hibernate.query.criteria.HibernateCriteriaBuilder; import org.hibernate.query.criteria.HibernateCriteriaBuilder;
import org.hibernate.query.criteria.JpaCoalesce; import org.hibernate.query.criteria.JpaCoalesce;
@ -1932,44 +1929,6 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext,
); );
} }
public static void assertNumeric(SqmExpression<?> expression, BinaryArithmeticOperator op) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !Number.class.isAssignableFrom( javaType )
&& !Temporal.class.isAssignableFrom( javaType )
&& !TemporalAmount.class.isAssignableFrom( javaType )
&& !java.util.Date.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" );
}
}
}
public static void assertDuration(SqmExpression<?> expression) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
}
}
}
public static void assertNumeric(SqmExpression<?> expression, UnaryArithmeticOperator op) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !Number.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorChar()
+ " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
}
}
}
@Override @Override
public SqmPredicate equal(Expression<?> x, Expression<?> y) { public SqmPredicate equal(Expression<?> x, Expression<?> y) {
assertComparable( x, y, getNodeBuilder().getSessionFactory() ); assertComparable( x, y, getNodeBuilder().getSessionFactory() );

View File

@ -16,8 +16,10 @@ import org.hibernate.metamodel.model.domain.internal.DiscriminatorSqmPathSource;
import org.hibernate.metamodel.model.domain.internal.EmbeddedSqmPathSource; import org.hibernate.metamodel.model.domain.internal.EmbeddedSqmPathSource;
import org.hibernate.persister.entity.EntityPersister; import org.hibernate.persister.entity.EntityPersister;
import org.hibernate.query.SemanticException; import org.hibernate.query.SemanticException;
import org.hibernate.query.sqm.BinaryArithmeticOperator;
import org.hibernate.query.sqm.SqmExpressible; import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.SqmPathSource; import org.hibernate.query.sqm.SqmPathSource;
import org.hibernate.query.sqm.UnaryArithmeticOperator;
import org.hibernate.query.sqm.tree.SqmTypedNode; import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.query.sqm.tree.domain.SqmPath; import org.hibernate.query.sqm.tree.domain.SqmPath;
import org.hibernate.query.sqm.tree.expression.SqmExpression; import org.hibernate.query.sqm.tree.expression.SqmExpression;
@ -25,6 +27,9 @@ import org.hibernate.query.sqm.tree.expression.SqmLiteralNull;
import org.hibernate.type.BasicType; import org.hibernate.type.BasicType;
import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.JdbcType;
import java.time.temporal.Temporal;
import java.time.temporal.TemporalAmount;
import static org.hibernate.type.descriptor.java.JavaTypeHelper.isUnknown; import static org.hibernate.type.descriptor.java.JavaTypeHelper.isUnknown;
/** /**
@ -360,4 +365,124 @@ public class TypecheckUtil {
} }
} }
} }
public static void assertOperable(SqmExpression<?> left, SqmExpression<?> right, BinaryArithmeticOperator op) {
final SqmExpressible<?> leftNodeType = left.getNodeType();
final SqmExpressible<?> rightNodeType = right.getNodeType();
if ( leftNodeType != null && rightNodeType != null ) {
final Class<?> leftJavaType = leftNodeType.getExpressibleJavaType().getJavaTypeClass();
final Class<?> rightJavaType = rightNodeType.getExpressibleJavaType().getJavaTypeClass();
if ( Number.class.isAssignableFrom( leftJavaType ) ) {
// left operand is a number
switch (op) {
case MULTIPLY:
if ( !Number.class.isAssignableFrom( rightJavaType )
// we can scale a duration by a number
&& !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number' or 'java.time.TemporalAmount')" );
}
break;
default:
if ( !Number.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
}
break;
}
}
else if ( TemporalAmount.class.isAssignableFrom( leftJavaType ) ) {
// left operand is a duration
switch (op) {
case ADD:
case SUBTRACT:
// we can add/subtract durations
if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
}
break;
default:
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
}
}
else if ( Temporal.class.isAssignableFrom( leftJavaType )
|| java.util.Date.class.isAssignableFrom( leftJavaType ) ) {
// left operand is a date, time, or datetime
switch (op) {
case ADD:
// we can add a duration to date, time, or datetime
if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
}
break;
case SUBTRACT:
// we can subtract dates, times, or datetimes
if ( !Temporal.class.isAssignableFrom( rightJavaType )
&& !java.util.Date.class.isAssignableFrom( rightJavaType )
// we can subtract a duration from a date, time, or datetime
&& !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
}
break;
default:
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
}
}
else {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" );
}
}
}
// public static void assertNumeric(SqmExpression<?> expression, BinaryArithmeticOperator op) {
// final SqmExpressible<?> nodeType = expression.getNodeType();
// if ( nodeType != null ) {
// final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
// if ( !Number.class.isAssignableFrom( javaType )
// && !Temporal.class.isAssignableFrom( javaType )
// && !TemporalAmount.class.isAssignableFrom( javaType )
// && !java.util.Date.class.isAssignableFrom( javaType ) ) {
// throw new SemanticException( "Operand of " + op.getOperatorSqlText()
// + " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
// + " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" );
// }
// }
// }
public static void assertDuration(SqmExpression<?> expression) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
}
}
}
public static void assertNumeric(SqmExpression<?> expression, UnaryArithmeticOperator op) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !Number.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorChar()
+ " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
}
}
}
} }