HHH-16875 be a bit more forgiving when type checking expressions involving unknown HQL functions

Let's not reject expressions like:

    function('current_user') = 'username'

also add QueryArgumentException
This commit is contained in:
Gavin King 2023-07-04 19:17:26 +02:00
parent a2defad7a4
commit cd02a961c8
9 changed files with 116 additions and 51 deletions

View File

@ -0,0 +1,35 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.query;
/**
* An error that occurs binding an argument to a query parameter.
* Usually indicates that the argument is of a type not assignable
* to the type of the parameter.
*
* @since 6.3
*
* @author Gavin King
*/
public class QueryArgumentException extends IllegalArgumentException {
private final Class<?> parameterType;
private final Object argument;
public QueryArgumentException(String message, Class<?> parameterType, Object argument) {
super(message);
this.parameterType = parameterType;
this.argument = argument;
}
public Class<?> getParameterType() {
return parameterType;
}
public Object getArgument() {
return argument;
}
}

View File

@ -213,6 +213,9 @@ import org.hibernate.type.BasicType;
import org.hibernate.type.descriptor.java.JavaType; import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType; import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType;
import org.hibernate.type.descriptor.java.spi.UnknownBasicJavaType;
import org.hibernate.type.descriptor.jdbc.ObjectJdbcType;
import org.hibernate.type.internal.BasicTypeImpl;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import jakarta.persistence.criteria.Predicate; import jakarta.persistence.criteria.Predicate;
@ -237,6 +240,7 @@ import static org.hibernate.grammars.hql.HqlParser.ListaggFunctionContext;
import static org.hibernate.grammars.hql.HqlParser.OnOverflowClauseContext; import static org.hibernate.grammars.hql.HqlParser.OnOverflowClauseContext;
import static org.hibernate.grammars.hql.HqlParser.PLUS; import static org.hibernate.grammars.hql.HqlParser.PLUS;
import static org.hibernate.grammars.hql.HqlParser.UNION; import static org.hibernate.grammars.hql.HqlParser.UNION;
import static org.hibernate.internal.util.QuotingHelper.unquoteStringLiteral;
import static org.hibernate.query.sqm.TemporalUnit.DATE; import static org.hibernate.query.sqm.TemporalUnit.DATE;
import static org.hibernate.query.sqm.TemporalUnit.DAY_OF_MONTH; import static org.hibernate.query.sqm.TemporalUnit.DAY_OF_MONTH;
import static org.hibernate.query.sqm.TemporalUnit.DAY_OF_WEEK; import static org.hibernate.query.sqm.TemporalUnit.DAY_OF_WEEK;
@ -2577,7 +2581,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
else { else {
assert child instanceof TerminalNode; assert child instanceof TerminalNode;
final TerminalNode terminalNode = (TerminalNode) child; final TerminalNode terminalNode = (TerminalNode) child;
final String escape = QuotingHelper.unquoteStringLiteral( terminalNode.getText() ); final String escape = unquoteStringLiteral( terminalNode.getText() );
if ( escape.length() != 1 ) { if ( escape.length() != 1 ) {
throw new SemanticException( throw new SemanticException(
"Escape character literals must have exactly a single character, but found: " + escape "Escape character literals must have exactly a single character, but found: " + escape
@ -3450,7 +3454,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
final TerminalNode firstChild = (TerminalNode) ctx.getChild( 0 ); final TerminalNode firstChild = (TerminalNode) ctx.getChild( 0 );
final String timezoneText; final String timezoneText;
if ( firstChild.getSymbol().getType() == HqlParser.STRING_LITERAL ) { if ( firstChild.getSymbol().getType() == HqlParser.STRING_LITERAL ) {
timezoneText = QuotingHelper.unquoteStringLiteral( ctx.getText() ); timezoneText = unquoteStringLiteral( ctx.getText() );
} }
else { else {
timezoneText = ctx.getText(); timezoneText = ctx.getText();
@ -3617,7 +3621,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
private SqmLiteral<String> stringLiteral(String text) { private SqmLiteral<String> stringLiteral(String text) {
return new SqmLiteral<>( return new SqmLiteral<>(
QuotingHelper.unquoteStringLiteral( text ), unquoteStringLiteral( text ),
resolveExpressibleTypeBasic( String.class ), resolveExpressibleTypeBasic( String.class ),
creationContext.getNodeBuilder() creationContext.getNodeBuilder()
); );
@ -3880,11 +3884,11 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override @Override
public SqmExpression<?> visitJpaNonstandardFunction(HqlParser.JpaNonstandardFunctionContext ctx) { public SqmExpression<?> visitJpaNonstandardFunction(HqlParser.JpaNonstandardFunctionContext ctx) {
final String functionName = QuotingHelper.unquoteStringLiteral( ctx.getChild( 2 ).getText() ).toLowerCase(); final String functionName = unquoteStringLiteral( ctx.jpaNonstandardFunctionName().getText() ).toLowerCase();
final List<SqmTypedNode<?>> functionArguments; final List<SqmTypedNode<?>> functionArguments;
if ( ctx.getChildCount() > 4 ) { if ( ctx.getChildCount() > 4 ) {
//noinspection unchecked //noinspection unchecked
functionArguments = (List<SqmTypedNode<?>>) ctx.getChild( 4 ).accept( this ); functionArguments = (List<SqmTypedNode<?>>) ctx.genericFunctionArguments().accept( this );
} }
else { else {
functionArguments = emptyList(); functionArguments = emptyList();
@ -3897,7 +3901,10 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
true, true,
null, null,
StandardFunctionReturnTypeResolvers.invariant( StandardFunctionReturnTypeResolvers.invariant(
resolveExpressibleTypeBasic( Object.class ) new BasicTypeImpl<>(
new UnknownBasicJavaType<>( Object.class ),
ObjectJdbcType.INSTANCE
)
), ),
null null
); );
@ -4396,7 +4403,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override @Override
public Object visitFormat(HqlParser.FormatContext ctx) { public Object visitFormat(HqlParser.FormatContext ctx) {
final String format = QuotingHelper.unquoteStringLiteral( ctx.getChild( 0 ).getText() ); final String format = unquoteStringLiteral( ctx.getChild( 0 ).getText() );
return new SqmFormat( return new SqmFormat(
format, format,
resolveExpressibleTypeBasic( String.class ), resolveExpressibleTypeBasic( String.class ),
@ -4898,7 +4905,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override @Override
public SqmLiteral<Character> visitTrimCharacter(HqlParser.TrimCharacterContext ctx) { public SqmLiteral<Character> visitTrimCharacter(HqlParser.TrimCharacterContext ctx) {
final String trimCharText = ctx != null final String trimCharText = ctx != null
? QuotingHelper.unquoteStringLiteral( ctx.getText() ) ? unquoteStringLiteral( ctx.getText() )
: " "; // JPA says space is the default : " "; // JPA says space is the default
if ( trimCharText.length() != 1 ) { if ( trimCharText.length() != 1 ) {

View File

@ -12,6 +12,7 @@ import java.util.Date;
import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.query.BindableType; import org.hibernate.query.BindableType;
import org.hibernate.query.QueryArgumentException;
import org.hibernate.query.sqm.SqmExpressible; import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.type.descriptor.java.JavaType; import org.hibernate.type.descriptor.java.JavaType;
@ -80,14 +81,16 @@ public class QueryParameterBindingValidator {
bind, bind,
temporalPrecision temporalPrecision
) ) { ) ) {
throw new IllegalArgumentException( throw new QueryArgumentException(
String.format( String.format(
"Argument [%s] of type [%s] did not match parameter type [%s (%s)]", "Argument [%s] of type [%s] did not match parameter type [%s (%s)]",
bind, bind,
bind.getClass().getName(), bind.getClass().getName(),
parameterJavaType.getName(), parameterJavaType.getName(),
extractName( temporalPrecision ) extractName( temporalPrecision )
) ),
parameterJavaType,
bind
); );
} }
} }
@ -104,13 +107,16 @@ public class QueryParameterBindingValidator {
// validate the elements... // validate the elements...
for ( Object element : value ) { for ( Object element : value ) {
if ( !isValidBindValue( parameterType, element, temporalType ) ) { if ( !isValidBindValue( parameterType, element, temporalType ) ) {
throw new IllegalArgumentException( throw new QueryArgumentException(
String.format( String.format(
"Parameter value element [%s] did not match expected type [%s (%s)]", "Parameter value element [%s] did not match expected type [%s (%s)]",
element, element,
parameterType.getName(), parameterType.getName(),
extractName( temporalType ) extractName( temporalType )
) )
,
parameterType,
element
); );
} }
} }
@ -193,12 +199,14 @@ public class QueryParameterBindingValidator {
Object value, Object value,
TemporalType temporalType) { TemporalType temporalType) {
if ( !parameterType.isArray() ) { if ( !parameterType.isArray() ) {
throw new IllegalArgumentException( throw new QueryArgumentException(
String.format( String.format(
"Encountered array-valued parameter binding, but was expecting [%s (%s)]", "Encountered array-valued parameter binding, but was expecting [%s (%s)]",
parameterType.getName(), parameterType.getName(),
extractName( temporalType ) extractName( temporalType )
) ),
parameterType,
value
); );
} }
@ -206,13 +214,15 @@ public class QueryParameterBindingValidator {
// we have a primitive array. we validate that the actual array has the component type (type of elements) // we have a primitive array. we validate that the actual array has the component type (type of elements)
// we expect based on the component type of the parameter specification // we expect based on the component type of the parameter specification
if ( !parameterType.getComponentType().isAssignableFrom( value.getClass().getComponentType() ) ) { if ( !parameterType.getComponentType().isAssignableFrom( value.getClass().getComponentType() ) ) {
throw new IllegalArgumentException( throw new QueryArgumentException(
String.format( String.format(
"Primitive array-valued parameter bind value type [%s] did not match expected type [%s (%s)]", "Primitive array-valued parameter bind value type [%s] did not match expected type [%s (%s)]",
value.getClass().getComponentType().getName(), value.getClass().getComponentType().getName(),
parameterType.getName(), parameterType.getName(),
extractName( temporalType ) extractName( temporalType )
) ),
parameterType,
value
); );
} }
} }
@ -222,13 +232,15 @@ public class QueryParameterBindingValidator {
final Object[] array = (Object[]) value; final Object[] array = (Object[]) value;
for ( Object element : array ) { for ( Object element : array ) {
if ( !isValidBindValue( parameterType.getComponentType(), element, temporalType ) ) { if ( !isValidBindValue( parameterType.getComponentType(), element, temporalType ) ) {
throw new IllegalArgumentException( throw new QueryArgumentException(
String.format( String.format(
"Array-valued parameter value element [%s] did not match expected type [%s (%s)]", "Array-valued parameter value element [%s] did not match expected type [%s (%s)]",
element, element,
parameterType.getName(), parameterType.getName(),
extractName( temporalType ) extractName( temporalType )
) ),
parameterType,
array
); );
} }
} }

View File

@ -25,6 +25,8 @@ import org.hibernate.query.sqm.tree.expression.SqmLiteralNull;
import org.hibernate.type.BasicType; import org.hibernate.type.BasicType;
import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.JdbcType;
import static org.hibernate.type.descriptor.java.JavaTypeHelper.isUnknown;
/** /**
* Functions for typechecking comparison expressions and assignments in the SQM tree. * Functions for typechecking comparison expressions and assignments in the SQM tree.
* A comparison expression is any predicate like {@code x = y} or {@code x > y}. An * A comparison expression is any predicate like {@code x = y} or {@code x > y}. An
@ -285,7 +287,8 @@ public class TypecheckUtil {
} }
private static boolean isSameJavaType(SqmExpressible<?> leftType, SqmExpressible<?> rightType) { private static boolean isSameJavaType(SqmExpressible<?> leftType, SqmExpressible<?> rightType) {
return leftType.getRelationalJavaType() == rightType.getRelationalJavaType() return isUnknown( leftType.getExpressibleJavaType() ) || isUnknown( rightType.getExpressibleJavaType() )
|| leftType.getRelationalJavaType() == rightType.getRelationalJavaType()
|| leftType.getExpressibleJavaType() == rightType.getExpressibleJavaType() || leftType.getExpressibleJavaType() == rightType.getExpressibleJavaType()
|| leftType.getBindableJavaType() == rightType.getBindableJavaType(); || leftType.getBindableJavaType() == rightType.getBindableJavaType();
} }

View File

@ -22,6 +22,7 @@ import org.hibernate.query.sqm.tree.expression.SqmExtractUnit;
import org.hibernate.query.sqm.tree.expression.SqmTrimSpecification; import org.hibernate.query.sqm.tree.expression.SqmTrimSpecification;
import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.type.BasicType;
import org.hibernate.type.JavaObjectType; import org.hibernate.type.JavaObjectType;
import org.hibernate.type.descriptor.java.JavaType; import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.java.spi.JdbcTypeRecommendationException; import org.hibernate.type.descriptor.java.spi.JdbcTypeRecommendationException;
@ -42,6 +43,7 @@ import static org.hibernate.type.SqlTypes.isIntegral;
import static org.hibernate.type.SqlTypes.isNumericType; import static org.hibernate.type.SqlTypes.isNumericType;
import static org.hibernate.type.SqlTypes.isSpatialType; import static org.hibernate.type.SqlTypes.isSpatialType;
import static org.hibernate.type.SqlTypes.isTemporalType; import static org.hibernate.type.SqlTypes.isTemporalType;
import static org.hibernate.type.descriptor.java.JavaTypeHelper.isUnknown;
/** /**
@ -132,6 +134,7 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
JdbcTypeIndicators indicators, JdbcTypeIndicators indicators,
FunctionParameterType type, FunctionParameterType type,
JavaType<?> javaType) { JavaType<?> javaType) {
if ( !isUnknown( javaType ) ) {
DomainType<?> domainType = argument.getExpressible().getSqmType(); DomainType<?> domainType = argument.getExpressible().getSqmType();
if ( domainType instanceof JdbcMapping ) { if ( domainType instanceof JdbcMapping ) {
checkArgumentType( checkArgumentType(
@ -154,6 +157,7 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
} }
} }
} }
}
private int getJdbcType(JdbcTypeIndicators indicators, JavaType<?> javaType) { private int getJdbcType(JdbcTypeIndicators indicators, JavaType<?> javaType) {
if ( javaType.getJavaTypeClass().isEnum() ) { if ( javaType.getJavaTypeClass().isEnum() ) {
@ -177,21 +181,31 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
@Override @Override
public void validateSqlTypes(List<? extends SqlAstNode> arguments, String functionName) { public void validateSqlTypes(List<? extends SqlAstNode> arguments, String functionName) {
int count = 0; int count = 0;
for (SqlAstNode argument : arguments) { for ( SqlAstNode argument : arguments ) {
if (argument instanceof Expression) { if ( argument instanceof Expression ) {
JdbcMappingContainer expressionType = ((Expression) argument).getExpressionType(); final Expression expression = (Expression) argument;
final JdbcMappingContainer expressionType = expression.getExpressionType();
if (expressionType != null) { if (expressionType != null) {
if (expressionType instanceof JavaObjectType) { if ( isUnknownExpressionType( expressionType ) ) {
count += expressionType.getJdbcTypeCount(); count += expressionType.getJdbcTypeCount();
} }
else { else {
count = validateArgument(count, expressionType, functionName); count = validateArgument( count, expressionType, functionName );
} }
} }
} }
} }
} }
/**
* We can't validate some expressions involving parameters / unknown functions.
*/
private static boolean isUnknownExpressionType(JdbcMappingContainer expressionType) {
return expressionType instanceof JavaObjectType
|| expressionType instanceof BasicType
&& isUnknown( ((BasicType<?>) expressionType).getJavaTypeDescriptor() );
}
private int validateArgument(int count, JdbcMappingContainer expressionType, String functionName) { private int validateArgument(int count, JdbcMappingContainer expressionType, String functionName) {
final int jdbcTypeCount = expressionType.getJdbcTypeCount(); final int jdbcTypeCount = expressionType.getJdbcTypeCount();
for ( int i = 0; i < jdbcTypeCount; i++ ) { for ( int i = 0; i < jdbcTypeCount; i++ ) {
@ -275,7 +289,7 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
private void throwError(FunctionParameterType type, Type javaType, String functionName, int count) { private void throwError(FunctionParameterType type, Type javaType, String functionName, int count) {
throw new FunctionArgumentException( throw new FunctionArgumentException(
String.format( String.format(
"Parameter %d of function %s() has type %s, but argument is of type %s", "Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
count, count,
functionName, functionName,
type, type,

View File

@ -5349,10 +5349,10 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
); );
} }
else { else {
throw new SqlTreeCreationException( throw new SemanticException(
String.format( String.format(
Locale.ROOT, Locale.ROOT,
"QueryLiteral type [`%s`] did not match domain Java-type [`%s`] nor JDBC Java-type [`%s`]", "Literal type '%s' did not match domain type '%s' nor converted type '%s'",
value.getClass(), value.getClass(),
valueConverter.getDomainJavaType().getJavaTypeClass().getName(), valueConverter.getDomainJavaType().getJavaTypeClass().getName(),
valueConverter.getRelationalJavaType().getJavaTypeClass().getName() valueConverter.getRelationalJavaType().getJavaTypeClass().getName()

View File

@ -14,8 +14,6 @@ import java.util.function.Supplier;
import org.hibernate.annotations.Immutable; import org.hibernate.annotations.Immutable;
import org.hibernate.annotations.Mutability; import org.hibernate.annotations.Mutability;
import org.hibernate.internal.util.ReflectHelper; import org.hibernate.internal.util.ReflectHelper;
import org.hibernate.resource.beans.spi.ManagedBean;
import org.hibernate.resource.beans.spi.ManagedBeanRegistry;
import org.hibernate.type.descriptor.java.EnumJavaType; import org.hibernate.type.descriptor.java.EnumJavaType;
import org.hibernate.type.descriptor.java.ImmutableMutabilityPlan; import org.hibernate.type.descriptor.java.ImmutableMutabilityPlan;
import org.hibernate.type.descriptor.java.JavaType; import org.hibernate.type.descriptor.java.JavaType;

View File

@ -1,9 +1,8 @@
package org.hibernate.orm.test.jpa.criteria; package org.hibernate.orm.test.jpa.criteria;
import org.hibernate.query.SemanticException; import org.hibernate.query.QueryArgumentException;
import org.hibernate.testing.TestForIssue; import org.hibernate.testing.TestForIssue;
import org.hibernate.testing.orm.junit.EntityManagerFactoryScope; import org.hibernate.testing.orm.junit.EntityManagerFactoryScope;
import org.hibernate.testing.orm.junit.FailureExpected;
import org.hibernate.testing.orm.junit.Jpa; import org.hibernate.testing.orm.junit.Jpa;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -39,7 +38,7 @@ public class ObjectParameterTypeForEmbeddableTest {
); );
} }
@Test @FailureExpected(reason = "This query is in my opinion not well-typed, and should be rejected") @Test
public void testSettingParameterOfTypeObject(EntityManagerFactoryScope scope) { public void testSettingParameterOfTypeObject(EntityManagerFactoryScope scope) {
scope.inTransaction( scope.inTransaction(
entityManager -> { entityManager -> {
@ -77,8 +76,8 @@ public class ObjectParameterTypeForEmbeddableTest {
@Test @Test
public void testSettingParameterOfTypeWrongType(EntityManagerFactoryScope scope) { public void testSettingParameterOfTypeWrongType(EntityManagerFactoryScope scope) {
SemanticException thrown = assertThrows( QueryArgumentException thrown = assertThrows(
SemanticException.class, () -> QueryArgumentException.class, () ->
scope.inTransaction( scope.inTransaction(
entityManager -> { entityManager -> {
final CriteriaBuilder cb = entityManager.getCriteriaBuilder(); final CriteriaBuilder cb = entityManager.getCriteriaBuilder();
@ -94,7 +93,7 @@ public class ObjectParameterTypeForEmbeddableTest {
) )
); );
assertThat( thrown.getMessage() ).startsWith( "Cannot compare left expression" ); assertThat( thrown.getMessage() ).contains( "did not match parameter type" );
} }
@Entity(name = "TestEntity") @Entity(name = "TestEntity")

View File

@ -21,7 +21,6 @@ import org.hibernate.engine.spi.SessionImplementor;
import org.hibernate.internal.util.ExceptionHelper; import org.hibernate.internal.util.ExceptionHelper;
import org.hibernate.query.Query; import org.hibernate.query.Query;
import org.hibernate.query.SemanticException; import org.hibernate.query.SemanticException;
import org.hibernate.sql.ast.SqlTreeCreationException;
import org.hibernate.testing.orm.junit.DomainModel; import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactory;
@ -29,9 +28,7 @@ import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;