HHH-16802 typecheck subquery predicates (#6801)

This commit is contained in:
Gavin King 2023-06-15 11:51:08 +02:00 committed by GitHub
parent 9d052413fc
commit 96941f3775
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 155 additions and 17 deletions

View File

@ -2645,8 +2645,8 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@SuppressWarnings({ "unchecked", "rawtypes" })
public SqmPredicate visitInPredicate(HqlParser.InPredicateContext ctx) {
final boolean negated = ctx.getChildCount() == 4;
final SqmExpression<?> testExpression = (SqmExpression<?>) ctx.getChild( 0 ).accept( this );
final HqlParser.InListContext inListContext = (HqlParser.InListContext) ctx.getChild( ctx.getChildCount() - 1 );
final SqmExpression<?> testExpression = (SqmExpression<?>) ctx.expression().accept( this );
final HqlParser.InListContext inListContext = ctx.inList();
if ( inListContext instanceof HqlParser.ExplicitTupleInListContext ) {
final HqlParser.ExplicitTupleInListContext tupleExpressionListContext = (HqlParser.ExplicitTupleInListContext) inListContext;
final int size = tupleExpressionListContext.getChildCount();
@ -2707,9 +2707,11 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
}
else if ( inListContext instanceof HqlParser.SubqueryInListContext ) {
final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext;
final SqmSubQuery<?> subquery = visitSubquery(subQueryOrParamInListContext.subquery());
( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( testExpression, subquery );
return new SqmInSubQueryPredicate(
testExpression,
visitSubquery( subQueryOrParamInListContext.subquery() ),
subquery,
negated,
creationContext.getNodeBuilder()
);
@ -2718,7 +2720,8 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
if ( getCreationOptions().useStrictJpaCompliance() ) {
throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION );
}
final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext = (HqlParser.PersistentCollectionReferenceInListContext) inListContext;
final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext =
(HqlParser.PersistentCollectionReferenceInListContext) inListContext;
return new SqmInSubQueryPredicate<>(
testExpression,
createCollectionReferenceSubQuery(

View File

@ -2030,14 +2030,20 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext,
}
public void assertComparable(Expression<?> x, Expression<?> y) {
final SqmExpressible<?> lhsType = ( (SqmExpression<?>) x ).getNodeType();
final SqmExpressible<?> rhsType = ( (SqmExpression<?>) y ).getNodeType();
if ( !areTypesComparable( lhsType, rhsType ) ) {
SqmExpression<?> left = (SqmExpression<?>) x;
SqmExpression<?> right = (SqmExpression<?>) y;
if ( left.getTupleLength() != null && right.getTupleLength() != null
&& left.getTupleLength().intValue() != right.getTupleLength().intValue() ) {
throw new SemanticException( "Cannot compare tuples of different lengths" );
}
final SqmExpressible<?> leftType = left.getNodeType();
final SqmExpressible<?> rightType = right.getNodeType();
if ( !areTypesComparable( leftType, rightType ) ) {
throw new SemanticException(
String.format(
"Can't compare test expression of type [%s] with element of type [%s]",
lhsType,
rhsType
"Cannot compare left expression of type [%s] with right expression of type [%s]",
leftType,
rightType
)
);
}

View File

@ -6,6 +6,8 @@
*/
package org.hibernate.query.sqm.tree.domain;
import jakarta.persistence.metamodel.Attribute;
import jakarta.persistence.metamodel.SingularAttribute;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.metamodel.model.domain.EmbeddableDomainType;
import org.hibernate.metamodel.model.domain.EntityDomainType;
@ -20,6 +22,8 @@ import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.spi.NavigablePath;
import org.hibernate.type.descriptor.java.JavaType;
import java.util.Set;
/**
* @author Steve Ebersole
*/
@ -74,6 +78,21 @@ public class SqmEmbeddedValuedSimplePath<T>
return this;
}
@Override
public Integer getTupleLength() {
final EmbeddableDomainType<?> sqmPathType = (EmbeddableDomainType<?>) getReferencedPathSource().getSqmPathType();
final Set<? extends SingularAttribute<?, ?>> attributes = sqmPathType.getSingularAttributes();
return length(attributes);
}
private int length(Set<? extends SingularAttribute<?, ?>> attributes) {
int length = 0;
for (Attribute<?, ?> attribute : attributes) {
length += get(attribute.getName()).getTupleLength();
}
return length;
}
@Override
public DomainType<T> getSqmType() {
return getReferencedPathSource().getSqmType();

View File

@ -78,6 +78,12 @@ public class JpaCriteriaParameter<T>
return null;
}
@Override
public Integer getTupleLength() {
// TODO: we should be able to do much better than this!
return null;
}
@Override
public boolean allowsMultiValuedBinding() {
return allowsMultiValuedBinding;

View File

@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression;
import org.hibernate.query.sqm.NodeBuilder;
import org.hibernate.query.sqm.SemanticQueryWalker;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.query.sqm.tree.select.SqmSubQuery;
@ -23,6 +24,16 @@ public class SqmAny<T> extends AbstractSqmExpression<T> {
this.subquery = subquery;
}
@Override
public SqmExpressible<T> getNodeType() {
return subquery.getNodeType();
}
@Override
public Integer getTupleLength() {
return subquery.getTupleLength();
}
@Override
public SqmAny<T> copy(SqmCopyContext context) {
final SqmAny<T> existing = context.getCopy( this );

View File

@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression;
import org.hibernate.query.sqm.NodeBuilder;
import org.hibernate.query.sqm.SemanticQueryWalker;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.query.sqm.tree.select.SqmSubQuery;
@ -23,6 +24,16 @@ public class SqmEvery<T> extends AbstractSqmExpression<T> {
this.subquery = subquery;
}
@Override
public SqmExpressible<T> getNodeType() {
return subquery.getNodeType();
}
@Override
public Integer getTupleLength() {
return subquery.getTupleLength();
}
@Override
public SqmEvery<T> copy(SqmCopyContext context) {
final SqmEvery<T> existing = context.getCopy( this );

View File

@ -58,6 +58,10 @@ public interface SqmExpression<T> extends SqmSelectableNode<T>, JpaExpression<T>
jpaSelectionConsumer.accept( this );
}
default Integer getTupleLength() {
return 1;
}
@Override
SqmExpression<Long> asLong();

View File

@ -88,4 +88,9 @@ public class SqmNamedParameter<T> extends AbstractSqmParameter<T> {
? getName().compareTo( ( (SqmNamedParameter<?>) anotherParameter ).getName() )
: -1;
}
@Override
public Integer getTupleLength() {
return null;
}
}

View File

@ -110,4 +110,8 @@ public class SqmTuple<T>
return getGroupedExpressions();
}
@Override
public Integer getTupleLength() {
return groupedExpressions.size();
}
}

View File

@ -149,6 +149,12 @@ public class SqmSubQuery<T> extends AbstractSqmSelectQuery<T> implements SqmSele
return statement;
}
@Override
public Integer getTupleLength() {
final SqmSelectClause selectClause = getQuerySpec().getSelectClause();
return selectClause == null ? null : selectClause.getSelectionItems().size();
}
@Override
public SqmCteStatement<?> getCteStatement(String cteLabel) {
final SqmCteStatement<?> cteCriteria = super.getCteStatement( cteLabel );

View File

@ -528,7 +528,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
log.info( "query against Department with a subquery on Salesperson in the APAC reqion..." );
List departments = session.createQuery(
"select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)"
"select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)"
).setParameter( 1, "steve" ).list();
assertEquals( "Incorrect department count", 1, departments.size() );
@ -537,7 +537,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "region" ).setParameter( "region", "Foobar" );
departments = session.createQuery(
"select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)" )
"select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)" )
.setParameter( 1, "steve" )
.list();
@ -560,7 +560,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product.id in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 1, orders.size() );
@ -577,7 +577,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
);
orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 0, orders.size() );
@ -589,7 +589,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 1, orders.size() );
@ -601,7 +601,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
orders = session.createQuery(
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
assertEquals( "Incorrect orders count", 1, orders.size() );

View File

@ -0,0 +1,63 @@
package org.hibernate.orm.test.hql;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.hibernate.query.SemanticException;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.fail;
@SessionFactory
@DomainModel(annotatedClasses = SubqueryPredicateTypingTest.Book.class)
public class SubqueryPredicateTypingTest {
@Test void test(SessionFactoryScope scope) {
scope.inSession( s -> {
// these should succeed
s.createSelectionQuery("from Book where title in (select title from Book)").getResultList();
s.createSelectionQuery("from Book where title = any (select title from Book)").getResultList();
// test tuple length errors
try {
s.createSelectionQuery("from Book where title = any (select title, isbn from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where title = every (select title, isbn from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where title in (select title, isbn from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
// test typing errors
try {
s.createSelectionQuery("from Book where 1 = any (select title from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where 1 = every (select title from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
try {
s.createSelectionQuery("from Book where 1 in (select title from Book)").getResultList();
fail();
}
catch (SemanticException se) {}
});
}
@Entity(name = "Book")
static class Book {
@Id String isbn;
String title;
}
}

View File

@ -61,7 +61,7 @@ public class LikeEscapeParameterTest {
session -> {
session.createQuery(
"select s from TestEntity s where s.name like ?1 escape '\\'" +
" or s in (" +
" or s.id in (" +
" select distinct t.id from TestEntity t where t.name like ?2 escape '\\'" +
" )",
TestEntity.class