Fix type inference for binary arithmetic expressions

This commit is contained in:
Christian Beikov 2022-01-30 14:57:26 +01:00
parent af42f3a76c
commit ce5951a948
2 changed files with 194 additions and 78 deletions

View File

@ -411,6 +411,10 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
private SqmJoin<?, ?> currentlyProcessingJoin;
private final Stack<Clause> currentClauseStack = new StandardStack<>();
private final Stack<Supplier<MappingModelExpressible<?>>> inferrableTypeAccessStack = new StandardStack<>(
() -> null
);
private boolean inTypeInference;
private SqmByUnit appliedByUnit;
private Expression adjustedTimestamp;
@ -3890,15 +3894,24 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
}
}
private MappingModelExpressible<?> resolveInferredType() {
final Supplier<MappingModelExpressible<?>> inferableTypeAccess = inferrableTypeAccessStack.getCurrent();
if ( inTypeInference || inferableTypeAccess == null ) {
return null;
}
inTypeInference = true;
final MappingModelExpressible<?> inferredType = inferableTypeAccess.get();
inTypeInference = false;
return inferredType;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// General expressions
@Override
public Expression visitLiteral(SqmLiteral<?> literal) {
final Supplier<MappingModelExpressible<?>> inferableTypeAccess = inferrableTypeAccessStack.getCurrent();
if ( literal instanceof SqmLiteralNull ) {
MappingModelExpressible<?> mappingModelExpressible = inferableTypeAccess.get();
MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( mappingModelExpressible == null ) {
mappingModelExpressible = determineCurrentExpressible( literal );
}
@ -3924,7 +3937,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
return new SqlTuple( expressions, mappingModelExpressible);
}
final MappingModelExpressible<?> inferableExpressible = inferableTypeAccess.get();
final MappingModelExpressible<?> inferableExpressible = resolveInferredType();
if ( inferableExpressible instanceof ConvertibleModelPart ) {
final ConvertibleModelPart convertibleModelPart = (ConvertibleModelPart) inferableExpressible;
@ -4234,57 +4247,30 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
);
}
// protected MappingModelExpressible<?> lenientlyResolveMappingExpressible(SqmExpressible<?> nodeType) {
// return resolveMappingExpressible( nodeType );
// }
//
// protected MappingModelExpressible<?> resolveMappingExpressible(SqmExpressible<?> nodeType) {
// final MappingModelExpressible<?> valueMapping = getCreationContext().getDomainModel().resolveMappingExpressible(
// nodeType,
// this::findTableGroupByPath
// );
//
// if ( valueMapping == null ) {
// final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent();
// if ( currentExpressibleSupplier != null ) {
// return currentExpressibleSupplier.get();
// }
// }
//
// if ( valueMapping == null ) {
// throw new ConversionException( "Could not determine ValueMapping for SqmExpressible: " + nodeType );
// }
//
// return valueMapping;
// }
protected MappingModelExpressible<?> determineValueMapping(SqmExpression<?> sqmExpression) {
if ( sqmExpression instanceof SqmParameter ) {
return determineValueMapping( (SqmParameter<?>) sqmExpression );
}
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
if ( sqmExpression instanceof SqmPath ) {
else if ( sqmExpression instanceof SqmPath ) {
log.debugf( "Determining mapping-model type for SqmPath : %s ", sqmExpression );
prepareReusablePath( (SqmPath<?>) sqmExpression, () -> null );
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
return SqmMappingModelHelper.resolveMappingModelExpressible(
sqmExpression,
domainModel,
getFromClauseAccess()::findTableGroup
);
}
// The model type of an enum literal is always inferred
if ( sqmExpression instanceof SqmEnumLiteral<?> ) {
final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent();
if ( currentExpressibleSupplier != null ) {
return currentExpressibleSupplier.get();
else if ( sqmExpression instanceof SqmEnumLiteral<?> ) {
final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( mappingModelExpressible != null ) {
return mappingModelExpressible;
}
}
if ( sqmExpression instanceof SqmSubQuery<?> ) {
else if ( sqmExpression instanceof SqmSubQuery<?> ) {
final SqmSubQuery<?> subQuery = (SqmSubQuery<?>) sqmExpression;
final SqmSelectClause selectClause = subQuery.getQuerySpec().getSelectClause();
if ( selectClause.getSelections().size() == 1 ) {
@ -4319,6 +4305,9 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
sqmExpressible = selectionNodeType;
}
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
final MappingModelExpressible<?> expressible = domainModel.resolveMappingExpressible(sqmExpressible, this::findTableGroupByPath );
if ( expressible != null ) {
@ -4326,7 +4315,10 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
}
try {
return inferrableTypeAccessStack.getCurrent().get();
final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( mappingModelExpressible != null ) {
return mappingModelExpressible;
}
}
catch (Exception ignore) {
return null;
@ -4341,15 +4333,18 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
// We can't determine the type of the expression
return null;
}
final MappingMetamodel domainModel = creationContext.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
final MappingModelExpressible<?> valueMapping = domainModel.resolveMappingExpressible(
nodeType,
this::findTableGroupByPath
);
if ( valueMapping == null ) {
final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent();
if ( currentExpressibleSupplier != null ) {
return currentExpressibleSupplier.get();
final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
if ( mappingModelExpressible != null ) {
return mappingModelExpressible;
}
}
@ -4361,17 +4356,14 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
}
protected MappingModelExpressible<?> getInferredValueMapping() {
final Supplier<MappingModelExpressible<?>> currentExpressibleSupplier = inferrableTypeAccessStack.getCurrent();
if ( currentExpressibleSupplier != null ) {
final MappingModelExpressible<?> inferredMapping = currentExpressibleSupplier.get();
if ( inferredMapping != null ) {
if ( inferredMapping instanceof PluralAttributeMapping ) {
return ( (PluralAttributeMapping) inferredMapping ).getElementDescriptor();
}
else if ( !( inferredMapping instanceof JavaObjectType ) ) {
// Never report back the "object type" as inferred type and instead rely on the value type
return inferredMapping;
}
final MappingModelExpressible<?> inferredMapping = resolveInferredType();
if ( inferredMapping != null ) {
if ( inferredMapping instanceof PluralAttributeMapping ) {
return ( (PluralAttributeMapping) inferredMapping ).getElementDescriptor();
}
else if ( !( inferredMapping instanceof JavaObjectType ) ) {
// Never report back the "object type" as inferred type and instead rely on the value type
return inferredMapping;
}
}
return null;
@ -4452,10 +4444,6 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
throw new ConversionException( "Could not determine ValueMapping for SqmParameter: " + sqmParameter );
}
protected final Stack<Supplier<MappingModelExpressible<?>>> inferrableTypeAccessStack = new StandardStack<>(
() -> null
);
private void resolveSqmParameter(
SqmParameter<?> expression,
MappingModelExpressible<?> valueMapping,
@ -4513,7 +4501,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
final List<SqmExpression<?>> groupedExpressions = sqmTuple.getGroupedExpressions();
final int size = groupedExpressions.size();
final List<Expression> expressions = new ArrayList<>( size );
final MappingModelExpressible<?> mappingModelExpressible = inferrableTypeAccessStack.getCurrent().get();
final MappingModelExpressible<?> mappingModelExpressible = resolveInferredType();
final EmbeddableMappingType embeddableMappingType;
if ( mappingModelExpressible instanceof ValueMapping ) {
embeddableMappingType = (EmbeddableMappingType) ( (ValueMapping) mappingModelExpressible).getMappedType();
@ -4667,23 +4655,33 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
else if ( temporalTypeToLeft != null && temporalTypeToRight != null ) {
return transformDatetimeArithmetic( expression );
}
else if ( durationToRight && appliedByUnit != null ) {
return new BinaryArithmeticExpression(
toSqlExpression( leftOperand.accept( this ) ),
expression.getOperator(),
toSqlExpression( rightOperand.accept( this ) ),
//after distributing the 'by unit' operator
//we always get a Long value back
(BasicValuedMapping) appliedByUnit.getNodeType()
);
}
else {
return new BinaryArithmeticExpression(
toSqlExpression( leftOperand.accept( this ) ),
expression.getOperator(),
toSqlExpression( rightOperand.accept( this ) ),
getExpressionType( expression )
);
// Infer one operand type through the other
inferrableTypeAccessStack.push( () -> determineValueMapping( rightOperand ) );
final Expression lhs = toSqlExpression( leftOperand.accept( this ) );
inferrableTypeAccessStack.pop();
inferrableTypeAccessStack.push( () -> determineValueMapping( leftOperand ) );
final Expression rhs = toSqlExpression( rightOperand.accept( this ) );
inferrableTypeAccessStack.pop();
if ( durationToRight && appliedByUnit != null ) {
return new BinaryArithmeticExpression(
lhs,
expression.getOperator(),
rhs,
//after distributing the 'by unit' operator
//we always get a Long value back
(BasicValuedMapping) appliedByUnit.getNodeType()
);
}
else {
return new BinaryArithmeticExpression(
lhs,
expression.getOperator(),
rhs,
getExpressionType( expression )
);
}
}
}
@ -5281,14 +5279,14 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
@Override
public Expression visitParameterizedEntityTypeExpression(SqmParameterizedEntityType<?> sqmExpression) {
assert inferrableTypeAccessStack.getCurrent().get() instanceof EntityDiscriminatorMapping;
assert resolveInferredType() instanceof EntityDiscriminatorMapping;
return (Expression) sqmExpression.getDiscriminatorSource().accept( this );
}
@SuppressWarnings({"raw","unchecked"})
@Override
public Object visitEnumLiteral(SqmEnumLiteral<?> sqmEnumLiteral) {
final BasicValuedMapping inferrableType = (BasicValuedMapping) inferrableTypeAccessStack.getCurrent().get();
final BasicValuedMapping inferrableType = (BasicValuedMapping) resolveInferredType();
if ( inferrableType instanceof ConvertibleModelPart ) {
final ConvertibleModelPart inferredPart = (ConvertibleModelPart) inferrableType;
@SuppressWarnings("unchecked")

View File

@ -0,0 +1,118 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.orm.test.hql;
import java.util.List;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.testing.SkipForDialect;
import org.hibernate.testing.TestForIssue;
import org.hibernate.testing.junit4.BaseCoreFunctionalTestCase;
import org.junit.Before;
import org.junit.Test;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.Id;
import jakarta.persistence.TypedQuery;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hibernate.testing.transaction.TransactionUtil.doInHibernate;
import static org.junit.Assert.assertThat;
/**
* @author Christian Beikov
*/
public class InferenceTest extends BaseCoreFunctionalTestCase {
private Person person;
@Override
protected Class[] getAnnotatedClasses() {
return new Class[] {
Person.class
};
}
@Before
public void setUp() {
doInHibernate( this::sessionFactory, session -> {
person = new Person();
person.setName("Johannes");
person.setSurname("Buehler");
session.persist(person);
} );
}
@Test
public void testBinaryArithmeticInference() {
doInHibernate( this::sessionFactory, session -> {
TypedQuery<Person> query = session.createQuery( "from Person p where p.id + 1 < :param", Person.class );
query.setParameter("param", 10);
List<Person> resultList = query.getResultList();
assertThat(resultList, hasItem(person));
} );
}
@Entity(name = "Person")
public static class Person {
@Id
@GeneratedValue
private Integer id;
@Column
private String name;
@Column
private String surname;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getSurname() {
return surname;
}
public void setSurname(String surname) {
this.surname = surname;
}
public Integer getId() {
return id;
}
public void setId(Integer id) {
this.id = id;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Person )) return false;
Person person = (Person) o;
return id != null ? id.equals(person.id) : person.id == null;
}
@Override
public int hashCode() {
return id != null ? id.hashCode() : 0;
}
}
}