HHH-16802 typecheck subquery predicates (#6801)

This commit is contained in:
Gavin King 2023-06-15 11:51:08 +02:00 committed by GitHub
parent 9d052413fc
commit 96941f3775
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 155 additions and 17 deletions

View File

@ -2645,8 +2645,8 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@SuppressWarnings({ "unchecked", "rawtypes" }) @SuppressWarnings({ "unchecked", "rawtypes" })
public SqmPredicate visitInPredicate(HqlParser.InPredicateContext ctx) { public SqmPredicate visitInPredicate(HqlParser.InPredicateContext ctx) {
final boolean negated = ctx.getChildCount() == 4; final boolean negated = ctx.getChildCount() == 4;
final SqmExpression<?> testExpression = (SqmExpression<?>) ctx.getChild( 0 ).accept( this ); final SqmExpression<?> testExpression = (SqmExpression<?>) ctx.expression().accept( this );
final HqlParser.InListContext inListContext = (HqlParser.InListContext) ctx.getChild( ctx.getChildCount() - 1 ); final HqlParser.InListContext inListContext = ctx.inList();
if ( inListContext instanceof HqlParser.ExplicitTupleInListContext ) { if ( inListContext instanceof HqlParser.ExplicitTupleInListContext ) {
final HqlParser.ExplicitTupleInListContext tupleExpressionListContext = (HqlParser.ExplicitTupleInListContext) inListContext; final HqlParser.ExplicitTupleInListContext tupleExpressionListContext = (HqlParser.ExplicitTupleInListContext) inListContext;
final int size = tupleExpressionListContext.getChildCount(); final int size = tupleExpressionListContext.getChildCount();
@ -2707,9 +2707,11 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
} }
else if ( inListContext instanceof HqlParser.SubqueryInListContext ) { else if ( inListContext instanceof HqlParser.SubqueryInListContext ) {
final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext; final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext;
final SqmSubQuery<?> subquery = visitSubquery(subQueryOrParamInListContext.subquery());
( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( testExpression, subquery );
return new SqmInSubQueryPredicate( return new SqmInSubQueryPredicate(
testExpression, testExpression,
visitSubquery( subQueryOrParamInListContext.subquery() ), subquery,
negated, negated,
creationContext.getNodeBuilder() creationContext.getNodeBuilder()
); );
@ -2718,7 +2720,8 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
if ( getCreationOptions().useStrictJpaCompliance() ) { if ( getCreationOptions().useStrictJpaCompliance() ) {
throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION ); throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION );
} }
final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext = (HqlParser.PersistentCollectionReferenceInListContext) inListContext; final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext =
(HqlParser.PersistentCollectionReferenceInListContext) inListContext;
return new SqmInSubQueryPredicate<>( return new SqmInSubQueryPredicate<>(
testExpression, testExpression,
createCollectionReferenceSubQuery( createCollectionReferenceSubQuery(

View File

@ -2030,14 +2030,20 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext,
} }
public void assertComparable(Expression<?> x, Expression<?> y) { public void assertComparable(Expression<?> x, Expression<?> y) {
final SqmExpressible<?> lhsType = ( (SqmExpression<?>) x ).getNodeType(); SqmExpression<?> left = (SqmExpression<?>) x;
final SqmExpressible<?> rhsType = ( (SqmExpression<?>) y ).getNodeType(); SqmExpression<?> right = (SqmExpression<?>) y;
if ( !areTypesComparable( lhsType, rhsType ) ) { if ( left.getTupleLength() != null && right.getTupleLength() != null
&& left.getTupleLength().intValue() != right.getTupleLength().intValue() ) {
throw new SemanticException( "Cannot compare tuples of different lengths" );
}
final SqmExpressible<?> leftType = left.getNodeType();
final SqmExpressible<?> rightType = right.getNodeType();
if ( !areTypesComparable( leftType, rightType ) ) {
throw new SemanticException( throw new SemanticException(
String.format( String.format(
"Can't compare test expression of type [%s] with element of type [%s]", "Cannot compare left expression of type [%s] with right expression of type [%s]",
lhsType, leftType,
rhsType rightType
) )
); );
} }

View File

@ -6,6 +6,8 @@
*/ */
package org.hibernate.query.sqm.tree.domain; package org.hibernate.query.sqm.tree.domain;
import jakarta.persistence.metamodel.Attribute;
import jakarta.persistence.metamodel.SingularAttribute;
import org.hibernate.metamodel.model.domain.DomainType; import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.metamodel.model.domain.EmbeddableDomainType; import org.hibernate.metamodel.model.domain.EmbeddableDomainType;
import org.hibernate.metamodel.model.domain.EntityDomainType; import org.hibernate.metamodel.model.domain.EntityDomainType;
@ -20,6 +22,8 @@ import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.spi.NavigablePath; import org.hibernate.spi.NavigablePath;
import org.hibernate.type.descriptor.java.JavaType; import org.hibernate.type.descriptor.java.JavaType;
import java.util.Set;
/** /**
* @author Steve Ebersole * @author Steve Ebersole
*/ */
@ -74,6 +78,21 @@ public class SqmEmbeddedValuedSimplePath<T>
return this; return this;
} }
@Override
public Integer getTupleLength() {
final EmbeddableDomainType<?> sqmPathType = (EmbeddableDomainType<?>) getReferencedPathSource().getSqmPathType();
final Set<? extends SingularAttribute<?, ?>> attributes = sqmPathType.getSingularAttributes();
return length(attributes);
}
private int length(Set<? extends SingularAttribute<?, ?>> attributes) {
int length = 0;
for (Attribute<?, ?> attribute : attributes) {
length += get(attribute.getName()).getTupleLength();
}
return length;
}
@Override @Override
public DomainType<T> getSqmType() { public DomainType<T> getSqmType() {
return getReferencedPathSource().getSqmType(); return getReferencedPathSource().getSqmType();

View File

@ -78,6 +78,12 @@ public class JpaCriteriaParameter<T>
return null; return null;
} }
@Override
public Integer getTupleLength() {
// TODO: we should be able to do much better than this!
return null;
}
@Override @Override
public boolean allowsMultiValuedBinding() { public boolean allowsMultiValuedBinding() {
return allowsMultiValuedBinding; return allowsMultiValuedBinding;

View File

@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression;
import org.hibernate.query.sqm.NodeBuilder; import org.hibernate.query.sqm.NodeBuilder;
import org.hibernate.query.sqm.SemanticQueryWalker; import org.hibernate.query.sqm.SemanticQueryWalker;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.tree.SqmCopyContext; import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.query.sqm.tree.select.SqmSubQuery; import org.hibernate.query.sqm.tree.select.SqmSubQuery;
@ -23,6 +24,16 @@ public class SqmAny<T> extends AbstractSqmExpression<T> {
this.subquery = subquery; this.subquery = subquery;
} }
@Override
public SqmExpressible<T> getNodeType() {
return subquery.getNodeType();
}
@Override
public Integer getTupleLength() {
return subquery.getTupleLength();
}
@Override @Override
public SqmAny<T> copy(SqmCopyContext context) { public SqmAny<T> copy(SqmCopyContext context) {
final SqmAny<T> existing = context.getCopy( this ); final SqmAny<T> existing = context.getCopy( this );

View File

@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression;
import org.hibernate.query.sqm.NodeBuilder; import org.hibernate.query.sqm.NodeBuilder;
import org.hibernate.query.sqm.SemanticQueryWalker; import org.hibernate.query.sqm.SemanticQueryWalker;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.tree.SqmCopyContext; import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.query.sqm.tree.select.SqmSubQuery; import org.hibernate.query.sqm.tree.select.SqmSubQuery;
@ -23,6 +24,16 @@ public class SqmEvery<T> extends AbstractSqmExpression<T> {
this.subquery = subquery; this.subquery = subquery;
} }
@Override
public SqmExpressible<T> getNodeType() {
return subquery.getNodeType();
}
@Override
public Integer getTupleLength() {
return subquery.getTupleLength();
}
@Override @Override
public SqmEvery<T> copy(SqmCopyContext context) { public SqmEvery<T> copy(SqmCopyContext context) {
final SqmEvery<T> existing = context.getCopy( this ); final SqmEvery<T> existing = context.getCopy( this );

View File

@ -58,6 +58,10 @@ public interface SqmExpression<T> extends SqmSelectableNode<T>, JpaExpression<T>
jpaSelectionConsumer.accept( this ); jpaSelectionConsumer.accept( this );
} }
default Integer getTupleLength() {
return 1;
}
@Override @Override
SqmExpression<Long> asLong(); SqmExpression<Long> asLong();

View File

@ -88,4 +88,9 @@ public class SqmNamedParameter<T> extends AbstractSqmParameter<T> {
? getName().compareTo( ( (SqmNamedParameter<?>) anotherParameter ).getName() ) ? getName().compareTo( ( (SqmNamedParameter<?>) anotherParameter ).getName() )
: -1; : -1;
} }
@Override
public Integer getTupleLength() {
return null;
}
} }

View File

@ -110,4 +110,8 @@ public class SqmTuple<T>
return getGroupedExpressions(); return getGroupedExpressions();
} }
@Override
public Integer getTupleLength() {
return groupedExpressions.size();
}
} }

View File

@ -149,6 +149,12 @@ public class SqmSubQuery<T> extends AbstractSqmSelectQuery<T> implements SqmSele
return statement; return statement;
} }
@Override
public Integer getTupleLength() {
final SqmSelectClause selectClause = getQuerySpec().getSelectClause();
return selectClause == null ? null : selectClause.getSelectionItems().size();
}
@Override @Override
public SqmCteStatement<?> getCteStatement(String cteLabel) { public SqmCteStatement<?> getCteStatement(String cteLabel) {
final SqmCteStatement<?> cteCriteria = super.getCteStatement( cteLabel ); final SqmCteStatement<?> cteCriteria = super.getCteStatement( cteLabel );

View File

@ -528,7 +528,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
log.info( "query against Department with a subquery on Salesperson in the APAC reqion..." ); log.info( "query against Department with a subquery on Salesperson in the APAC reqion..." );
List departments = session.createQuery( List departments = session.createQuery(
"select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)" "select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)"
).setParameter( 1, "steve" ).list(); ).setParameter( 1, "steve" ).list();
assertEquals( "Incorrect department count", 1, departments.size() ); assertEquals( "Incorrect department count", 1, departments.size() );
@ -537,7 +537,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "region" ).setParameter( "region", "Foobar" ); session.enableFilter( "region" ).setParameter( "region", "Foobar" );
departments = session.createQuery( departments = session.createQuery(
"select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)" ) "select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)" )
.setParameter( 1, "steve" ) .setParameter( 1, "steve" )
.list(); .list();
@ -560,7 +560,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() ); session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
orders = session.createQuery( orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product.id in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 1, orders.size() ); assertEquals( "Incorrect orders count", 1, orders.size() );
@ -577,7 +577,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
); );
orders = session.createQuery( orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 0, orders.size() ); assertEquals( "Incorrect orders count", 0, orders.size() );
@ -589,7 +589,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() ); session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
orders = session.createQuery( orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 1, orders.size() ); assertEquals( "Incorrect orders count", 1, orders.size() );
@ -601,7 +601,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() ); session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
orders = session.createQuery( orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" ) "select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list(); .setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 1, orders.size() ); assertEquals( "Incorrect orders count", 1, orders.size() );

View File

@ -0,0 +1,63 @@
package org.hibernate.orm.test.hql;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.hibernate.query.SemanticException;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.fail;
@SessionFactory
@DomainModel(annotatedClasses = SubqueryPredicateTypingTest.Book.class)
public class SubqueryPredicateTypingTest {
@Test void test(SessionFactoryScope scope) {
scope.inSession( s -> {
// these should succeed
s.createSelectionQuery("from Book where title in (select title from Book)").getResultList();
s.createSelectionQuery("from Book where title = any (select title from Book)").getResultList();
// test tuple length errors
try {
s.createSelectionQuery("from Book where title = any (select title, isbn from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where title = every (select title, isbn from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where title in (select title, isbn from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
// test typing errors
try {
s.createSelectionQuery("from Book where 1 = any (select title from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where 1 = every (select title from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where 1 in (select title from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
});
}
@Entity(name = "Book")
static class Book {
@Id String isbn;
String title;
}
}

View File

@ -61,7 +61,7 @@ public class LikeEscapeParameterTest {
session -> { session -> {
session.createQuery( session.createQuery(
"select s from TestEntity s where s.name like ?1 escape '\\'" + "select s from TestEntity s where s.name like ?1 escape '\\'" +
" or s in (" + " or s.id in (" +
" select distinct t.id from TestEntity t where t.name like ?2 escape '\\'" + " select distinct t.id from TestEntity t where t.name like ?2 escape '\\'" +
" )", " )",
TestEntity.class TestEntity.class