diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java index a3c427ce12..2700459fce 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java @@ -2645,8 +2645,8 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem @SuppressWarnings({ "unchecked", "rawtypes" }) public SqmPredicate visitInPredicate(HqlParser.InPredicateContext ctx) { final boolean negated = ctx.getChildCount() == 4; - final SqmExpression testExpression = (SqmExpression) ctx.getChild( 0 ).accept( this ); - final HqlParser.InListContext inListContext = (HqlParser.InListContext) ctx.getChild( ctx.getChildCount() - 1 ); + final SqmExpression testExpression = (SqmExpression) ctx.expression().accept( this ); + final HqlParser.InListContext inListContext = ctx.inList(); if ( inListContext instanceof HqlParser.ExplicitTupleInListContext ) { final HqlParser.ExplicitTupleInListContext tupleExpressionListContext = (HqlParser.ExplicitTupleInListContext) inListContext; final int size = tupleExpressionListContext.getChildCount(); @@ -2707,9 +2707,11 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem } else if ( inListContext instanceof HqlParser.SubqueryInListContext ) { final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext; + final SqmSubQuery subquery = visitSubquery(subQueryOrParamInListContext.subquery()); + ( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( testExpression, subquery ); return new SqmInSubQueryPredicate( testExpression, - visitSubquery( subQueryOrParamInListContext.subquery() ), + subquery, negated, creationContext.getNodeBuilder() ); @@ -2718,7 +2720,8 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem if ( getCreationOptions().useStrictJpaCompliance() ) { throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION ); } - final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext = (HqlParser.PersistentCollectionReferenceInListContext) inListContext; + final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext = + (HqlParser.PersistentCollectionReferenceInListContext) inListContext; return new SqmInSubQueryPredicate<>( testExpression, createCollectionReferenceSubQuery( diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java index 9b41075de9..de2425c692 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java @@ -2030,14 +2030,20 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext, } public void assertComparable(Expression x, Expression y) { - final SqmExpressible lhsType = ( (SqmExpression) x ).getNodeType(); - final SqmExpressible rhsType = ( (SqmExpression) y ).getNodeType(); - if ( !areTypesComparable( lhsType, rhsType ) ) { + SqmExpression left = (SqmExpression) x; + SqmExpression right = (SqmExpression) y; + if ( left.getTupleLength() != null && right.getTupleLength() != null + && left.getTupleLength().intValue() != right.getTupleLength().intValue() ) { + throw new SemanticException( "Cannot compare tuples of different lengths" ); + } + final SqmExpressible leftType = left.getNodeType(); + final SqmExpressible rightType = right.getNodeType(); + if ( !areTypesComparable( leftType, rightType ) ) { throw new SemanticException( String.format( - "Can't compare test expression of type [%s] with element of type [%s]", - lhsType, - rhsType + "Cannot compare left expression of type [%s] with right expression of type [%s]", + leftType, + rightType ) ); } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/domain/SqmEmbeddedValuedSimplePath.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/domain/SqmEmbeddedValuedSimplePath.java index 8f3d77e780..6c0f3a728e 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/domain/SqmEmbeddedValuedSimplePath.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/domain/SqmEmbeddedValuedSimplePath.java @@ -6,6 +6,8 @@ */ package org.hibernate.query.sqm.tree.domain; +import jakarta.persistence.metamodel.Attribute; +import jakarta.persistence.metamodel.SingularAttribute; import org.hibernate.metamodel.model.domain.DomainType; import org.hibernate.metamodel.model.domain.EmbeddableDomainType; import org.hibernate.metamodel.model.domain.EntityDomainType; @@ -20,6 +22,8 @@ import org.hibernate.query.sqm.tree.SqmCopyContext; import org.hibernate.spi.NavigablePath; import org.hibernate.type.descriptor.java.JavaType; +import java.util.Set; + /** * @author Steve Ebersole */ @@ -74,6 +78,21 @@ public class SqmEmbeddedValuedSimplePath return this; } + @Override + public Integer getTupleLength() { + final EmbeddableDomainType sqmPathType = (EmbeddableDomainType) getReferencedPathSource().getSqmPathType(); + final Set> attributes = sqmPathType.getSingularAttributes(); + return length(attributes); + } + + private int length(Set> attributes) { + int length = 0; + for (Attribute attribute : attributes) { + length += get(attribute.getName()).getTupleLength(); + } + return length; + } + @Override public DomainType getSqmType() { return getReferencedPathSource().getSqmType(); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/JpaCriteriaParameter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/JpaCriteriaParameter.java index 5ed8b1abff..ea4e7fca32 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/JpaCriteriaParameter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/JpaCriteriaParameter.java @@ -78,6 +78,12 @@ public class JpaCriteriaParameter return null; } + @Override + public Integer getTupleLength() { + // TODO: we should be able to do much better than this! + return null; + } + @Override public boolean allowsMultiValuedBinding() { return allowsMultiValuedBinding; diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmAny.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmAny.java index 41a3fc35d2..e9d6ed4b39 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmAny.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmAny.java @@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression; import org.hibernate.query.sqm.NodeBuilder; import org.hibernate.query.sqm.SemanticQueryWalker; +import org.hibernate.query.sqm.SqmExpressible; import org.hibernate.query.sqm.tree.SqmCopyContext; import org.hibernate.query.sqm.tree.select.SqmSubQuery; @@ -23,6 +24,16 @@ public class SqmAny extends AbstractSqmExpression { this.subquery = subquery; } + @Override + public SqmExpressible getNodeType() { + return subquery.getNodeType(); + } + + @Override + public Integer getTupleLength() { + return subquery.getTupleLength(); + } + @Override public SqmAny copy(SqmCopyContext context) { final SqmAny existing = context.getCopy( this ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmEvery.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmEvery.java index e7b6280b12..22e9e35a2a 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmEvery.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmEvery.java @@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression; import org.hibernate.query.sqm.NodeBuilder; import org.hibernate.query.sqm.SemanticQueryWalker; +import org.hibernate.query.sqm.SqmExpressible; import org.hibernate.query.sqm.tree.SqmCopyContext; import org.hibernate.query.sqm.tree.select.SqmSubQuery; @@ -23,6 +24,16 @@ public class SqmEvery extends AbstractSqmExpression { this.subquery = subquery; } + @Override + public SqmExpressible getNodeType() { + return subquery.getNodeType(); + } + + @Override + public Integer getTupleLength() { + return subquery.getTupleLength(); + } + @Override public SqmEvery copy(SqmCopyContext context) { final SqmEvery existing = context.getCopy( this ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmExpression.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmExpression.java index adb8719b3e..54999fc56b 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmExpression.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmExpression.java @@ -58,6 +58,10 @@ public interface SqmExpression extends SqmSelectableNode, JpaExpression jpaSelectionConsumer.accept( this ); } + default Integer getTupleLength() { + return 1; + } + @Override SqmExpression asLong(); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmNamedParameter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmNamedParameter.java index 723961c1e8..08d66e9eaf 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmNamedParameter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmNamedParameter.java @@ -88,4 +88,9 @@ public class SqmNamedParameter extends AbstractSqmParameter { ? getName().compareTo( ( (SqmNamedParameter) anotherParameter ).getName() ) : -1; } + + @Override + public Integer getTupleLength() { + return null; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmTuple.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmTuple.java index 72761bf553..57502bc223 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmTuple.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/SqmTuple.java @@ -110,4 +110,8 @@ public class SqmTuple return getGroupedExpressions(); } + @Override + public Integer getTupleLength() { + return groupedExpressions.size(); + } } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java index 676f4835dd..16278ccb19 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java @@ -149,6 +149,12 @@ public class SqmSubQuery extends AbstractSqmSelectQuery implements SqmSele return statement; } + @Override + public Integer getTupleLength() { + final SqmSelectClause selectClause = getQuerySpec().getSelectClause(); + return selectClause == null ? null : selectClause.getSelectionItems().size(); + } + @Override public SqmCteStatement getCteStatement(String cteLabel) { final SqmCteStatement cteCriteria = super.getCteStatement( cteLabel ); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/filter/DynamicFilterTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/filter/DynamicFilterTest.java index 9010f4f41e..4b5dcaf753 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/filter/DynamicFilterTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/filter/DynamicFilterTest.java @@ -528,7 +528,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase { log.info( "query against Department with a subquery on Salesperson in the APAC reqion..." ); List departments = session.createQuery( - "select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)" + "select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)" ).setParameter( 1, "steve" ).list(); assertEquals( "Incorrect department count", 1, departments.size() ); @@ -537,7 +537,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase { session.enableFilter( "region" ).setParameter( "region", "Foobar" ); departments = session.createQuery( - "select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)" ) + "select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)" ) .setParameter( 1, "steve" ) .list(); @@ -560,7 +560,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase { session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() ); orders = session.createQuery( - "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) + "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product.id in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); assertEquals( "Incorrect orders count", 1, orders.size() ); @@ -577,7 +577,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase { ); orders = session.createQuery( - "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) + "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" ) .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); assertEquals( "Incorrect orders count", 0, orders.size() ); @@ -589,7 +589,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase { session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() ); orders = session.createQuery( - "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) + "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" ) .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); assertEquals( "Incorrect orders count", 1, orders.size() ); @@ -601,7 +601,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase { session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() ); orders = session.createQuery( - "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) + "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" ) .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); assertEquals( "Incorrect orders count", 1, orders.size() ); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/hql/SubqueryPredicateTypingTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/SubqueryPredicateTypingTest.java new file mode 100644 index 0000000000..96e9071f1d --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/SubqueryPredicateTypingTest.java @@ -0,0 +1,63 @@ +package org.hibernate.orm.test.hql; + +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import org.hibernate.query.SemanticException; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.fail; + +@SessionFactory +@DomainModel(annotatedClasses = SubqueryPredicateTypingTest.Book.class) +public class SubqueryPredicateTypingTest { + @Test void test(SessionFactoryScope scope) { + scope.inSession( s -> { + // these should succeed + s.createSelectionQuery("from Book where title in (select title from Book)").getResultList(); + s.createSelectionQuery("from Book where title = any (select title from Book)").getResultList(); + + // test tuple length errors + try { + s.createSelectionQuery("from Book where title = any (select title, isbn from Book)").getResultList(); + fail(); + } + catch (SemanticException se) {} + try { + s.createSelectionQuery("from Book where title = every (select title, isbn from Book)").getResultList(); + fail(); + } + catch (SemanticException se) {} + try { + s.createSelectionQuery("from Book where title in (select title, isbn from Book)").getResultList(); + fail(); + } + catch (SemanticException se) {} + + // test typing errors + try { + s.createSelectionQuery("from Book where 1 = any (select title from Book)").getResultList(); + fail(); + } + catch (SemanticException se) {} + try { + s.createSelectionQuery("from Book where 1 = every (select title from Book)").getResultList(); + fail(); + } + catch (SemanticException se) {} + try { + s.createSelectionQuery("from Book where 1 in (select title from Book)").getResultList(); + fail(); + } + catch (SemanticException se) {} + }); + } + + @Entity(name = "Book") + static class Book { + @Id String isbn; + String title; + } +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/LikeEscapeParameterTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/LikeEscapeParameterTest.java index 4917ac4080..463b03910c 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/LikeEscapeParameterTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/LikeEscapeParameterTest.java @@ -61,7 +61,7 @@ public class LikeEscapeParameterTest { session -> { session.createQuery( "select s from TestEntity s where s.name like ?1 escape '\\'" + - " or s in (" + + " or s.id in (" + " select distinct t.id from TestEntity t where t.name like ?2 escape '\\'" + " )", TestEntity.class