diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java index bb93a6d7db..19e557ae5f 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java @@ -434,7 +434,7 @@ public abstract class BaseSqmToSqlAstConverter extends Base */ private Map>> trackedFetchSelectionsForGroup = Collections.emptyMap(); - private final Map collectionFilterPredicates = new HashMap<>(); + private final Map collectionFilterPredicates = new IdentityHashMap<>(); private List> orderByFragments; private final SqlAliasBaseManager sqlAliasBaseManager = new SqlAliasBaseManager(); @@ -1838,18 +1838,27 @@ private TableGroup findTableGroupByPath(NavigablePath navigablePath) { } protected void applyCollectionFilterPredicates(QuerySpec sqlQuerySpec) { - final List roots = sqlQuerySpec.getFromClause().getRoots(); - if ( roots != null && roots.size() == 1 ) { - final TableGroup root = roots.get( 0 ); - - if ( CollectionHelper.isNotEmpty( collectionFilterPredicates ) ) { - root.getTableGroupJoins().forEach( (tableGroupJoin) -> - collectionFilterPredicates.forEach( (alias, predicates) -> { - if ( tableGroupJoin.getJoinedGroup().getGroupAlias().equals( alias ) ) { - tableGroupJoin.applyPredicate( predicates.getPredicate() ); - } - } ) - ); + if ( CollectionHelper.isNotEmpty( collectionFilterPredicates ) ) { + final FromClauseAccess fromClauseAccess = getFromClauseAccess(); + OUTER: + for ( Map.Entry entry : collectionFilterPredicates.entrySet() ) { + final TableGroup parentTableGroup = fromClauseAccess.findTableGroup( entry.getKey().getParent() ); + if ( parentTableGroup == null ) { + // Since we only keep a single map, this could return null for collections of subqueries + continue; + } + for ( TableGroupJoin tableGroupJoin : parentTableGroup.getTableGroupJoins() ) { + if ( tableGroupJoin.getJoinedGroup().getNavigablePath() == entry.getKey() ) { + tableGroupJoin.applyPredicate( entry.getValue().getPredicate() ); + continue OUTER; + } + } + for ( TableGroupJoin tableGroupJoin : parentTableGroup.getNestedTableGroupJoins() ) { + if ( tableGroupJoin.getJoinedGroup().getNavigablePath() == entry.getKey() ) { + tableGroupJoin.applyPredicate( entry.getValue().getPredicate() ); + continue OUTER; + } + } } } } @@ -2347,6 +2356,10 @@ protected void consumeFromClauseCorrelatedRoot(SqmRoot sqmRoot) { fromClauseIndex.register( from, tableGroup ); registerPluralTableGroupParts( tableGroup ); + // Note that we do not need to register the correlated table group to the from clause + // because that is never "rendered" in the subquery anyway. + // Any table group joins added to the correlated table group are added to the query spec + // as roots anyway, so nothing to worry about log.tracef( "Resolved SqmRoot [%s] to correlated TableGroup [%s]", sqmRoot, tableGroup ); consumeExplicitJoins( from, tableGroup ); @@ -2516,8 +2529,15 @@ protected void registerTreatUsage(SqmFrom sqmFrom, TableGroup tableGroup) else { return; } + final TableGroup actualTableGroup; + if ( tableGroup instanceof PluralTableGroup ) { + actualTableGroup = ( (PluralTableGroup) tableGroup ).getElementTableGroup(); + } + else { + actualTableGroup = tableGroup; + } final Set treatedEntityNames = tableGroupTreatUsages.computeIfAbsent( - tableGroup, + actualTableGroup, tg -> new HashSet<>( 1 ) ); treatedEntityNames.add( treatedType.getHibernateEntityName() ); @@ -2545,7 +2565,7 @@ protected void registerTypeUsage(DiscriminatorSqmPath path) { } final int subclassTableSpan = persister.getSubclassTableSpan(); for ( int i = 0; i < subclassTableSpan; i++ ) { - tableGroup.resolveTableReference( persister.getSubclassTableName( i ) ); + tableGroup.resolveTableReference( null, persister.getSubclassTableName( i ), false ); } } @@ -2669,11 +2689,11 @@ private TableGroup consumeAttributeJoin( pluralAttributeMapping.applyBaseRestrictions( (predicate) -> { - final PredicateCollector existing = collectionFilterPredicates.get( joinedTableGroup.getGroupAlias() ); + final PredicateCollector existing = collectionFilterPredicates.get( joinedTableGroup.getNavigablePath() ); final PredicateCollector collector; if ( existing == null ) { collector = new PredicateCollector( predicate ); - collectionFilterPredicates.put( joinedTableGroup.getGroupAlias(), collector ); + collectionFilterPredicates.put( joinedTableGroup.getNavigablePath(), collector ); } else { collector = existing; @@ -6754,7 +6774,7 @@ private Fetch buildFetch( .getCollectionType() .getAssociatedJoinable( getCreationContext().getSessionFactory() ); joinable.applyBaseRestrictions( - (predicate) -> addCollectionFilterPredicate( tableGroup.getGroupAlias(), predicate ), + (predicate) -> addCollectionFilterPredicate( tableGroup.getNavigablePath(), predicate ), tableGroup, true, getLoadQueryInfluencers().getEnabledFilters(), @@ -6804,13 +6824,13 @@ private Fetch buildFetch( } } - private void addCollectionFilterPredicate(String groupAlias, Predicate predicate) { - final PredicateCollector existing = collectionFilterPredicates.get( groupAlias ); + private void addCollectionFilterPredicate(NavigablePath navigablePath, Predicate predicate) { + final PredicateCollector existing = collectionFilterPredicates.get( navigablePath ); if ( existing != null ) { existing.applyPredicate( predicate ); } else { - collectionFilterPredicates.put( groupAlias, new PredicateCollector( predicate ) ); + collectionFilterPredicates.put( navigablePath, new PredicateCollector( predicate ) ); } }