HHH-16802 typecheck subquery predicates (#6801)
This commit is contained in:
parent
9d052413fc
commit
96941f3775
|
@ -2645,8 +2645,8 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
|
|||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
public SqmPredicate visitInPredicate(HqlParser.InPredicateContext ctx) {
|
||||
final boolean negated = ctx.getChildCount() == 4;
|
||||
final SqmExpression<?> testExpression = (SqmExpression<?>) ctx.getChild( 0 ).accept( this );
|
||||
final HqlParser.InListContext inListContext = (HqlParser.InListContext) ctx.getChild( ctx.getChildCount() - 1 );
|
||||
final SqmExpression<?> testExpression = (SqmExpression<?>) ctx.expression().accept( this );
|
||||
final HqlParser.InListContext inListContext = ctx.inList();
|
||||
if ( inListContext instanceof HqlParser.ExplicitTupleInListContext ) {
|
||||
final HqlParser.ExplicitTupleInListContext tupleExpressionListContext = (HqlParser.ExplicitTupleInListContext) inListContext;
|
||||
final int size = tupleExpressionListContext.getChildCount();
|
||||
|
@ -2707,9 +2707,11 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
|
|||
}
|
||||
else if ( inListContext instanceof HqlParser.SubqueryInListContext ) {
|
||||
final HqlParser.SubqueryInListContext subQueryOrParamInListContext = (HqlParser.SubqueryInListContext) inListContext;
|
||||
final SqmSubQuery<?> subquery = visitSubquery(subQueryOrParamInListContext.subquery());
|
||||
( (SqmCriteriaNodeBuilder) creationContext.getNodeBuilder() ).assertComparable( testExpression, subquery );
|
||||
return new SqmInSubQueryPredicate(
|
||||
testExpression,
|
||||
visitSubquery( subQueryOrParamInListContext.subquery() ),
|
||||
subquery,
|
||||
negated,
|
||||
creationContext.getNodeBuilder()
|
||||
);
|
||||
|
@ -2718,7 +2720,8 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
|
|||
if ( getCreationOptions().useStrictJpaCompliance() ) {
|
||||
throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION );
|
||||
}
|
||||
final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext = (HqlParser.PersistentCollectionReferenceInListContext) inListContext;
|
||||
final HqlParser.PersistentCollectionReferenceInListContext collectionReferenceInListContext =
|
||||
(HqlParser.PersistentCollectionReferenceInListContext) inListContext;
|
||||
return new SqmInSubQueryPredicate<>(
|
||||
testExpression,
|
||||
createCollectionReferenceSubQuery(
|
||||
|
|
|
@ -2030,14 +2030,20 @@ public class SqmCriteriaNodeBuilder implements NodeBuilder, SqmCreationContext,
|
|||
}
|
||||
|
||||
public void assertComparable(Expression<?> x, Expression<?> y) {
|
||||
final SqmExpressible<?> lhsType = ( (SqmExpression<?>) x ).getNodeType();
|
||||
final SqmExpressible<?> rhsType = ( (SqmExpression<?>) y ).getNodeType();
|
||||
if ( !areTypesComparable( lhsType, rhsType ) ) {
|
||||
SqmExpression<?> left = (SqmExpression<?>) x;
|
||||
SqmExpression<?> right = (SqmExpression<?>) y;
|
||||
if ( left.getTupleLength() != null && right.getTupleLength() != null
|
||||
&& left.getTupleLength().intValue() != right.getTupleLength().intValue() ) {
|
||||
throw new SemanticException( "Cannot compare tuples of different lengths" );
|
||||
}
|
||||
final SqmExpressible<?> leftType = left.getNodeType();
|
||||
final SqmExpressible<?> rightType = right.getNodeType();
|
||||
if ( !areTypesComparable( leftType, rightType ) ) {
|
||||
throw new SemanticException(
|
||||
String.format(
|
||||
"Can't compare test expression of type [%s] with element of type [%s]",
|
||||
lhsType,
|
||||
rhsType
|
||||
"Cannot compare left expression of type [%s] with right expression of type [%s]",
|
||||
leftType,
|
||||
rightType
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
*/
|
||||
package org.hibernate.query.sqm.tree.domain;
|
||||
|
||||
import jakarta.persistence.metamodel.Attribute;
|
||||
import jakarta.persistence.metamodel.SingularAttribute;
|
||||
import org.hibernate.metamodel.model.domain.DomainType;
|
||||
import org.hibernate.metamodel.model.domain.EmbeddableDomainType;
|
||||
import org.hibernate.metamodel.model.domain.EntityDomainType;
|
||||
|
@ -20,6 +22,8 @@ import org.hibernate.query.sqm.tree.SqmCopyContext;
|
|||
import org.hibernate.spi.NavigablePath;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @author Steve Ebersole
|
||||
*/
|
||||
|
@ -74,6 +78,21 @@ public class SqmEmbeddedValuedSimplePath<T>
|
|||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTupleLength() {
|
||||
final EmbeddableDomainType<?> sqmPathType = (EmbeddableDomainType<?>) getReferencedPathSource().getSqmPathType();
|
||||
final Set<? extends SingularAttribute<?, ?>> attributes = sqmPathType.getSingularAttributes();
|
||||
return length(attributes);
|
||||
}
|
||||
|
||||
private int length(Set<? extends SingularAttribute<?, ?>> attributes) {
|
||||
int length = 0;
|
||||
for (Attribute<?, ?> attribute : attributes) {
|
||||
length += get(attribute.getName()).getTupleLength();
|
||||
}
|
||||
return length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DomainType<T> getSqmType() {
|
||||
return getReferencedPathSource().getSqmType();
|
||||
|
|
|
@ -78,6 +78,12 @@ public class JpaCriteriaParameter<T>
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTupleLength() {
|
||||
// TODO: we should be able to do much better than this!
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean allowsMultiValuedBinding() {
|
||||
return allowsMultiValuedBinding;
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression;
|
|||
|
||||
import org.hibernate.query.sqm.NodeBuilder;
|
||||
import org.hibernate.query.sqm.SemanticQueryWalker;
|
||||
import org.hibernate.query.sqm.SqmExpressible;
|
||||
import org.hibernate.query.sqm.tree.SqmCopyContext;
|
||||
import org.hibernate.query.sqm.tree.select.SqmSubQuery;
|
||||
|
||||
|
@ -23,6 +24,16 @@ public class SqmAny<T> extends AbstractSqmExpression<T> {
|
|||
this.subquery = subquery;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SqmExpressible<T> getNodeType() {
|
||||
return subquery.getNodeType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTupleLength() {
|
||||
return subquery.getTupleLength();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SqmAny<T> copy(SqmCopyContext context) {
|
||||
final SqmAny<T> existing = context.getCopy( this );
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.hibernate.query.sqm.tree.expression;
|
|||
|
||||
import org.hibernate.query.sqm.NodeBuilder;
|
||||
import org.hibernate.query.sqm.SemanticQueryWalker;
|
||||
import org.hibernate.query.sqm.SqmExpressible;
|
||||
import org.hibernate.query.sqm.tree.SqmCopyContext;
|
||||
import org.hibernate.query.sqm.tree.select.SqmSubQuery;
|
||||
|
||||
|
@ -23,6 +24,16 @@ public class SqmEvery<T> extends AbstractSqmExpression<T> {
|
|||
this.subquery = subquery;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SqmExpressible<T> getNodeType() {
|
||||
return subquery.getNodeType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTupleLength() {
|
||||
return subquery.getTupleLength();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SqmEvery<T> copy(SqmCopyContext context) {
|
||||
final SqmEvery<T> existing = context.getCopy( this );
|
||||
|
|
|
@ -58,6 +58,10 @@ public interface SqmExpression<T> extends SqmSelectableNode<T>, JpaExpression<T>
|
|||
jpaSelectionConsumer.accept( this );
|
||||
}
|
||||
|
||||
default Integer getTupleLength() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
SqmExpression<Long> asLong();
|
||||
|
||||
|
|
|
@ -88,4 +88,9 @@ public class SqmNamedParameter<T> extends AbstractSqmParameter<T> {
|
|||
? getName().compareTo( ( (SqmNamedParameter<?>) anotherParameter ).getName() )
|
||||
: -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTupleLength() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -110,4 +110,8 @@ public class SqmTuple<T>
|
|||
return getGroupedExpressions();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTupleLength() {
|
||||
return groupedExpressions.size();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -149,6 +149,12 @@ public class SqmSubQuery<T> extends AbstractSqmSelectQuery<T> implements SqmSele
|
|||
return statement;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTupleLength() {
|
||||
final SqmSelectClause selectClause = getQuerySpec().getSelectClause();
|
||||
return selectClause == null ? null : selectClause.getSelectionItems().size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SqmCteStatement<?> getCteStatement(String cteLabel) {
|
||||
final SqmCteStatement<?> cteCriteria = super.getCteStatement( cteLabel );
|
||||
|
|
|
@ -528,7 +528,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
|
|||
log.info( "query against Department with a subquery on Salesperson in the APAC reqion..." );
|
||||
|
||||
List departments = session.createQuery(
|
||||
"select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)"
|
||||
"select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)"
|
||||
).setParameter( 1, "steve" ).list();
|
||||
|
||||
assertEquals( "Incorrect department count", 1, departments.size() );
|
||||
|
@ -537,7 +537,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
|
|||
|
||||
session.enableFilter( "region" ).setParameter( "region", "Foobar" );
|
||||
departments = session.createQuery(
|
||||
"select d from Department as d where d.id in (select s.department from Salesperson s where s.name = ?1)" )
|
||||
"select d from Department as d where d in (select s.department from Salesperson s where s.name = ?1)" )
|
||||
.setParameter( 1, "steve" )
|
||||
.list();
|
||||
|
||||
|
@ -560,7 +560,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
|
|||
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
|
||||
|
||||
orders = session.createQuery(
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product.id in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
|
||||
|
||||
assertEquals( "Incorrect orders count", 1, orders.size() );
|
||||
|
@ -577,7 +577,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
|
|||
);
|
||||
|
||||
orders = session.createQuery(
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
|
||||
|
||||
assertEquals( "Incorrect orders count", 0, orders.size() );
|
||||
|
@ -589,7 +589,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
|
|||
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
|
||||
|
||||
orders = session.createQuery(
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
|
||||
|
||||
assertEquals( "Incorrect orders count", 1, orders.size() );
|
||||
|
@ -601,7 +601,7 @@ public class DynamicFilterTest extends BaseNonConfigCoreFunctionalTestCase {
|
|||
session.enableFilter( "effectiveDate" ).setParameter( "asOfDate", testData.lastMonth.getTime() );
|
||||
|
||||
orders = session.createQuery(
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p.id from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
"select o from Order as o where exists (select li.id from LineItem li where li.quantity >= ?1 and li.product in (select p from Product p where p.name = ?2)) and o.buyer = ?3" )
|
||||
.setParameter( 1, 1L ).setParameter( 2, "Acme Hair Gel" ).setParameter( 3, "gavin" ).list();
|
||||
|
||||
assertEquals( "Incorrect orders count", 1, orders.size() );
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
package org.hibernate.orm.test.hql;
|
||||
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import org.hibernate.query.SemanticException;
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@SessionFactory
|
||||
@DomainModel(annotatedClasses = SubqueryPredicateTypingTest.Book.class)
|
||||
public class SubqueryPredicateTypingTest {
|
||||
@Test void test(SessionFactoryScope scope) {
|
||||
scope.inSession( s -> {
|
||||
// these should succeed
|
||||
s.createSelectionQuery("from Book where title in (select title from Book)").getResultList();
|
||||
s.createSelectionQuery("from Book where title = any (select title from Book)").getResultList();
|
||||
|
||||
// test tuple length errors
|
||||
try {
|
||||
s.createSelectionQuery("from Book where title = any (select title, isbn from Book)").getResultList();
|
||||
fail();
|
||||
}
|
||||
catch (SemanticException se) {}
|
||||
try {
|
||||
s.createSelectionQuery("from Book where title = every (select title, isbn from Book)").getResultList();
|
||||
fail();
|
||||
}
|
||||
catch (SemanticException se) {}
|
||||
try {
|
||||
s.createSelectionQuery("from Book where title in (select title, isbn from Book)").getResultList();
|
||||
fail();
|
||||
}
|
||||
catch (SemanticException se) {}
|
||||
|
||||
// test typing errors
|
||||
try {
|
||||
s.createSelectionQuery("from Book where 1 = any (select title from Book)").getResultList();
|
||||
fail();
|
||||
}
|
||||
catch (SemanticException se) {}
|
||||
try {
|
||||
s.createSelectionQuery("from Book where 1 = every (select title from Book)").getResultList();
|
||||
fail();
|
||||
}
|
||||
catch (SemanticException se) {}
|
||||
try {
|
||||
s.createSelectionQuery("from Book where 1 in (select title from Book)").getResultList();
|
||||
fail();
|
||||
}
|
||||
catch (SemanticException se) {}
|
||||
});
|
||||
}
|
||||
|
||||
@Entity(name = "Book")
|
||||
static class Book {
|
||||
@Id String isbn;
|
||||
String title;
|
||||
}
|
||||
}
|
|
@ -61,7 +61,7 @@ public class LikeEscapeParameterTest {
|
|||
session -> {
|
||||
session.createQuery(
|
||||
"select s from TestEntity s where s.name like ?1 escape '\\'" +
|
||||
" or s in (" +
|
||||
" or s.id in (" +
|
||||
" select distinct t.id from TestEntity t where t.name like ?2 escape '\\'" +
|
||||
" )",
|
||||
TestEntity.class
|
||||
|
|
Loading…
Reference in New Issue