Fix type inference for binary arithmetic expressions

This commit is contained in:
Christian Beikov 2022-01-30 14:57:26 +01:00
parent af42f3a76c
commit ce5951a948
2 changed files with 194 additions and 78 deletions

View File

@ -411,6 +411,10 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
private SqmJoin<?, ?> currentlyProcessingJoin; private SqmJoin<?, ?> currentlyProcessingJoin;
private final Stack<Clause> currentClauseStack = new StandardStack<>(); private final Stack<Clause> currentClauseStack = new StandardStack<>();
private final Stack<Supplier<MappingModelExpressible<?>>> inferrableTypeAccessStack = new StandardStack<>(
() -> null
);
private boolean inTypeInference;
private SqmByUnit appliedByUnit; private SqmByUnit appliedByUnit;
private Expression adjustedTimestamp; private Expression adjustedTimestamp;
@ -3890,15 +3894,24 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
} }
} }
private MappingModelExpressible<?> resolveInferredType() {
final Supplier<MappingModelExpressible<?>> inferableTypeAccess = inferrableTypeAccessStack.getCurrent();
if ( inTypeInference || inferableTypeAccess == null ) {
return null;
}
inTypeInference = true;
final MappingModelExpressible<?> inferredType = inferableTypeAccess.get();
inTypeInference = false;
return inferredType;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// General expressions // General expressions
@Override @Override
public Expression visitLiteral(SqmLiteral<?> literal) { public Expression visitLiteral(SqmLiteral<?> literal) {
final Supplier<MappingModelExpressible<?>> inferableTypeAccess = inferrableTypeAccessStack.getCurrent();
if ( literal instanceof SqmLiteralNull ) { if ( literal instanceof SqmLiteralNull ) {
MappingModelExpressible<?> mappingModelExpressible = inferableTypeAccess.get(); MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( mappingModelExpressible == null ) { if ( mappingModelExpressible == null ) {
mappingModelExpressible = determineCurrentExpressible( literal ); mappingModelExpressible = determineCurrentExpressible( literal );
} }
@ -3924,7 +3937,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
return new SqlTuple( expressions, mappingModelExpressible); return new SqlTuple( expressions, mappingModelExpressible);
} }
final MappingModelExpressible<?> inferableExpressible = inferableTypeAccess.get(); final MappingModelExpressible<?> inferableExpressible = resolveInferredType();
if ( inferableExpressible instanceof ConvertibleModelPart ) { if ( inferableExpressible instanceof ConvertibleModelPart ) {
final ConvertibleModelPart convertibleModelPart = (ConvertibleModelPart) inferableExpressible; final ConvertibleModelPart convertibleModelPart = (ConvertibleModelPart) inferableExpressible;
@ -4234,57 +4247,30 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
); );
} }
// protected MappingModelExpressible<?> lenientlyResolveMappingExpressible(SqmExpressible<?> nodeType) {
// return resolveMappingExpressible( nodeType );
// }
//
// protected MappingModelExpressible<?> resolveMappingExpressible(SqmExpressible<?> nodeType) {
// final MappingModelExpressible<?> valueMapping = getCreationContext().getDomainModel().resolveMappingExpressible(
// nodeType,
// this::findTableGroupByPath
// );
//
// if ( valueMapping == null ) {
// final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent();
// if ( currentExpressibleSupplier != null ) {
// return currentExpressibleSupplier.get();
// }
// }
//
// if ( valueMapping == null ) {
// throw new ConversionException( "Could not determine ValueMapping for SqmExpressible: " + nodeType );
// }
//
// return valueMapping;
// }
protected MappingModelExpressible<?> determineValueMapping(SqmExpression<?> sqmExpression) { protected MappingModelExpressible<?> determineValueMapping(SqmExpression<?> sqmExpression) {
if ( sqmExpression instanceof SqmParameter ) { if ( sqmExpression instanceof SqmParameter ) {
return determineValueMapping( (SqmParameter<?>) sqmExpression ); return determineValueMapping( (SqmParameter<?>) sqmExpression );
} }
else if ( sqmExpression instanceof SqmPath ) {
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
if ( sqmExpression instanceof SqmPath ) {
log.debugf( "Determining mapping-model type for SqmPath : %s ", sqmExpression ); log.debugf( "Determining mapping-model type for SqmPath : %s ", sqmExpression );
prepareReusablePath( (SqmPath<?>) sqmExpression, () -> null ); prepareReusablePath( (SqmPath<?>) sqmExpression, () -> null );
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
return SqmMappingModelHelper.resolveMappingModelExpressible( return SqmMappingModelHelper.resolveMappingModelExpressible(
sqmExpression, sqmExpression,
domainModel, domainModel,
getFromClauseAccess()::findTableGroup getFromClauseAccess()::findTableGroup
); );
} }
// The model type of an enum literal is always inferred // The model type of an enum literal is always inferred
if ( sqmExpression instanceof SqmEnumLiteral<?> ) { else if ( sqmExpression instanceof SqmEnumLiteral<?> ) {
final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent(); final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( currentExpressibleSupplier != null ) { if ( mappingModelExpressible != null ) {
return currentExpressibleSupplier.get(); return mappingModelExpressible;
} }
} }
else if ( sqmExpression instanceof SqmSubQuery<?> ) {
if ( sqmExpression instanceof SqmSubQuery<?> ) {
final SqmSubQuery<?> subQuery = (SqmSubQuery<?>) sqmExpression; final SqmSubQuery<?> subQuery = (SqmSubQuery<?>) sqmExpression;
final SqmSelectClause selectClause = subQuery.getQuerySpec().getSelectClause(); final SqmSelectClause selectClause = subQuery.getQuerySpec().getSelectClause();
if ( selectClause.getSelections().size() == 1 ) { if ( selectClause.getSelections().size() == 1 ) {
@ -4319,6 +4305,9 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
sqmExpressible = selectionNodeType; sqmExpressible = selectionNodeType;
} }
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
final MappingModelExpressible<?> expressible = domainModel.resolveMappingExpressible(sqmExpressible, this::findTableGroupByPath ); final MappingModelExpressible<?> expressible = domainModel.resolveMappingExpressible(sqmExpressible, this::findTableGroupByPath );
if ( expressible != null ) { if ( expressible != null ) {
@ -4326,7 +4315,10 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
} }
try { try {
return inferrableTypeAccessStack.getCurrent().get(); final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( mappingModelExpressible != null ) {
return mappingModelExpressible;
}
} }
catch (Exception ignore) { catch (Exception ignore) {
return null; return null;
@ -4341,15 +4333,18 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
// We can't determine the type of the expression // We can't determine the type of the expression
return null; return null;
} }
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
final MappingModelExpressible<?> valueMapping = domainModel.resolveMappingExpressible( final MappingModelExpressible<?> valueMapping = domainModel.resolveMappingExpressible(
nodeType, nodeType,
this::findTableGroupByPath this::findTableGroupByPath
); );
if ( valueMapping == null ) { if ( valueMapping == null ) {
final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent(); final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( currentExpressibleSupplier != null ) { if ( mappingModelExpressible != null ) {
return currentExpressibleSupplier.get(); return mappingModelExpressible;
} }
} }
@ -4361,17 +4356,14 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
} }
protected MappingModelExpressible<?> getInferredValueMapping() { protected MappingModelExpressible<?> getInferredValueMapping() {
final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent(); final MappingModelExpressible<?> inferredMapping = resolveInferredType();
if ( currentExpressibleSupplier != null ) { if ( inferredMapping != null ) {
final MappingModelExpressible<?> inferredMapping = currentExpressibleSupplier.get(); if ( inferredMapping instanceof PluralAttributeMapping ) {
if ( inferredMapping != null ) { return ( (PluralAttributeMapping) inferredMapping ).getElementDescriptor();
if ( inferredMapping instanceof PluralAttributeMapping ) { }
return ( (PluralAttributeMapping) inferredMapping ).getElementDescriptor(); else if ( !( inferredMapping instanceof JavaObjectType ) ) {
} // Never report back the "object type" as inferred type and instead rely on the value type
else if ( !( inferredMapping instanceof JavaObjectType ) ) { return inferredMapping;
// Never report back the "object type" as inferred type and instead rely on the value type
return inferredMapping;
}
} }
} }
return null; return null;
@ -4452,10 +4444,6 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
throw new ConversionException( "Could not determine ValueMapping for SqmParameter: " + sqmParameter ); throw new ConversionException( "Could not determine ValueMapping for SqmParameter: " + sqmParameter );
} }
protected final Stack<Supplier<MappingModelExpressible<?>>> inferrableTypeAccessStack = new StandardStack<>(
() -> null
);
private void resolveSqmParameter( private void resolveSqmParameter(
SqmParameter<?> expression, SqmParameter<?> expression,
MappingModelExpressible<?> valueMapping, MappingModelExpressible<?> valueMapping,
@ -4513,7 +4501,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
final List<SqmExpression<?>> groupedExpressions = sqmTuple.getGroupedExpressions(); final List<SqmExpression<?>> groupedExpressions = sqmTuple.getGroupedExpressions();
final int size = groupedExpressions.size(); final int size = groupedExpressions.size();
final List<Expression> expressions = new ArrayList<>( size ); final List<Expression> expressions = new ArrayList<>( size );
final MappingModelExpressible<?> mappingModelExpressible = inferrableTypeAccessStack.getCurrent().get(); final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
final EmbeddableMappingType embeddableMappingType; final EmbeddableMappingType embeddableMappingType;
if ( mappingModelExpressible instanceof ValueMapping ) { if ( mappingModelExpressible instanceof ValueMapping ) {
embeddableMappingType = (EmbeddableMappingType) ( (ValueMapping) mappingModelExpressible).getMappedType(); embeddableMappingType = (EmbeddableMappingType) ( (ValueMapping) mappingModelExpressible).getMappedType();
@ -4667,23 +4655,33 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
else if ( temporalTypeToLeft != null && temporalTypeToRight != null ) { else if ( temporalTypeToLeft != null && temporalTypeToRight != null ) {
return transformDatetimeArithmetic( expression ); return transformDatetimeArithmetic( expression );
} }
else if ( durationToRight && appliedByUnit != null ) {
return new BinaryArithmeticExpression(
toSqlExpression( leftOperand.accept( this ) ),
expression.getOperator(),
toSqlExpression( rightOperand.accept( this ) ),
//after distributing the 'by unit' operator
//we always get a Long value back
(BasicValuedMapping) appliedByUnit.getNodeType()
);
}
else { else {
return new BinaryArithmeticExpression( // Infer one operand type through the other
toSqlExpression( leftOperand.accept( this ) ), inferrableTypeAccessStack.push( () -> determineValueMapping( rightOperand ) );
expression.getOperator(), final Expression lhs = toSqlExpression( leftOperand.accept( this ) );
toSqlExpression( rightOperand.accept( this ) ), inferrableTypeAccessStack.pop();
getExpressionType( expression ) inferrableTypeAccessStack.push( () -> determineValueMapping( leftOperand ) );
); final Expression rhs = toSqlExpression( rightOperand.accept( this ) );
inferrableTypeAccessStack.pop();
if ( durationToRight && appliedByUnit != null ) {
return new BinaryArithmeticExpression(
lhs,
expression.getOperator(),
rhs,
//after distributing the 'by unit' operator
//we always get a Long value back
(BasicValuedMapping) appliedByUnit.getNodeType()
);
}
else {
return new BinaryArithmeticExpression(
lhs,
expression.getOperator(),
rhs,
getExpressionType( expression )
);
}
} }
} }
@ -5281,14 +5279,14 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
@Override @Override
public Expression visitParameterizedEntityTypeExpression(SqmParameterizedEntityType<?> sqmExpression) { public Expression visitParameterizedEntityTypeExpression(SqmParameterizedEntityType<?> sqmExpression) {
assert inferrableTypeAccessStack.getCurrent().get() instanceof EntityDiscriminatorMapping; assert resolveInferredType() instanceof EntityDiscriminatorMapping;
return (Expression) sqmExpression.getDiscriminatorSource().accept( this ); return (Expression) sqmExpression.getDiscriminatorSource().accept( this );
} }
@SuppressWarnings({"raw","unchecked"}) @SuppressWarnings({"raw","unchecked"})
@Override @Override
public Object visitEnumLiteral(SqmEnumLiteral<?> sqmEnumLiteral) { public Object visitEnumLiteral(SqmEnumLiteral<?> sqmEnumLiteral) {
final BasicValuedMapping inferrableType = (BasicValuedMapping) inferrableTypeAccessStack.getCurrent().get(); final BasicValuedMapping inferrableType = (BasicValuedMapping) resolveInferredType();
if ( inferrableType instanceof ConvertibleModelPart ) { if ( inferrableType instanceof ConvertibleModelPart ) {
final ConvertibleModelPart inferredPart = (ConvertibleModelPart) inferrableType; final ConvertibleModelPart inferredPart = (ConvertibleModelPart) inferrableType;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")

View File

@ -0,0 +1,118 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.orm.test.hql;
import java.util.List;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.testing.SkipForDialect;
import org.hibernate.testing.TestForIssue;
import org.hibernate.testing.junit4.BaseCoreFunctionalTestCase;
import org.junit.Before;
import org.junit.Test;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.Id;
import jakarta.persistence.TypedQuery;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hibernate.testing.transaction.TransactionUtil.doInHibernate;
import static org.junit.Assert.assertThat;
/**
* @author Christian Beikov
*/
public class InferenceTest extends BaseCoreFunctionalTestCase {
private Person person;
@Override
protected Class[] getAnnotatedClasses() {
return new Class[] {
Person.class
};
}
@Before
public void setUp() {
doInHibernate( this::sessionFactory, session -> {
person = new Person();
person.setName("Johannes");
person.setSurname("Buehler");
session.persist(person);
} );
}
@Test
public void testBinaryArithmeticInference() {
doInHibernate( this::sessionFactory, session -> {
TypedQuery<Person> query = session.createQuery( "from Person p where p.id + 1 < :param", Person.class );
query.setParameter("param", 10);
List<Person> resultList = query.getResultList();
assertThat(resultList, hasItem(person));
} );
}
@Entity(name = "Person")
public static class Person {
@Id
@GeneratedValue
private Integer id;
@Column
private String name;
@Column
private String surname;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getSurname() {
return surname;
}
public void setSurname(String surname) {
this.surname = surname;
}
public Integer getId() {
return id;
}
public void setId(Integer id) {
this.id = id;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Person )) return false;
Person person = (Person) o;
return id != null ? id.equals(person.id) : person.id == null;
}
@Override
public int hashCode() {
return id != null ? id.hashCode() : 0;
}
}
}