HHH-16472 Allow null discriminators for treated left/full joins

This commit is contained in:
Marco Belladelli 2023-05-15 14:57:15 +02:00
parent 04684da054
commit a8fe62ebb3
1 changed files with 37 additions and 6 deletions

View File

@ -4983,7 +4983,6 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
if ( entityNames.isEmpty() ) { if ( entityNames.isEmpty() ) {
continue; continue;
} }
registerTypeUsage( tableGroup );
final ModelPartContainer modelPart = tableGroup.getModelPart(); final ModelPartContainer modelPart = tableGroup.getModelPart();
final EntityMappingType entityMapping; final EntityMappingType entityMapping;
@ -5000,12 +4999,17 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
tableGroup, tableGroup,
this this
); );
// We need to check if this is a treated left or full join, which case we should
// allow null discriminator values to maintain correct semantics
final TableGroupJoin join = getParentTableGroupJoin( tableGroup );
final boolean allowNulls = join != null && ( join.getJoinType() == SqlAstJoinType.LEFT || join.getJoinType() == SqlAstJoinType.FULL );
registerTypeUsage( tableGroup ); registerTypeUsage( tableGroup );
predicate = combinePredicates( predicate = combinePredicates(
predicate, predicate,
createTreatTypeRestriction( createTreatTypeRestriction(
typeExpression, typeExpression,
entityNames entityNames,
allowNulls
) )
); );
} }
@ -5013,6 +5017,20 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
return predicate; return predicate;
} }
private TableGroupJoin getParentTableGroupJoin(TableGroup tableGroup) {
final NavigablePath parentNavigablePath = tableGroup.getNavigablePath().getParent();
if ( parentNavigablePath != null ) {
final TableGroup parentTableGroup = getFromClauseIndex().findTableGroup( parentNavigablePath );
if ( parentTableGroup instanceof PluralTableGroup ) {
return getParentTableGroupJoin( parentTableGroup );
}
else if ( parentTableGroup != null ) {
return parentTableGroup.findTableGroupJoin( tableGroup );
}
}
return null;
}
private Set<String> determineEntityNamesForTreatTypeRestriction( private Set<String> determineEntityNamesForTreatTypeRestriction(
EntityMappingType partMappingType, EntityMappingType partMappingType,
Map<String, EntityNameUse> entityNameUses) { Map<String, EntityNameUse> entityNameUses) {
@ -5079,13 +5097,18 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
registerTypeUsage( discriminatorSqmPath ); registerTypeUsage( discriminatorSqmPath );
return createTreatTypeRestriction( return createTreatTypeRestriction(
DiscriminatorPathInterpretation.from( discriminatorSqmPath, this ), DiscriminatorPathInterpretation.from( discriminatorSqmPath, this ),
subclassEntityNames subclassEntityNames,
false
); );
} }
private Predicate createTreatTypeRestriction(Expression typeExpression, Set<String> subclassEntityNames) { private Predicate createTreatTypeRestriction(
Expression typeExpression,
Set<String> subclassEntityNames,
boolean allowNulls) {
final Predicate discriminatorPredicate;
if ( subclassEntityNames.size() == 1 ) { if ( subclassEntityNames.size() == 1 ) {
return new ComparisonPredicate( discriminatorPredicate = new ComparisonPredicate(
typeExpression, typeExpression,
ComparisonOperator.EQUAL, ComparisonOperator.EQUAL,
new EntityTypeLiteral( domainModel.findEntityDescriptor( subclassEntityNames.iterator().next() ) ) new EntityTypeLiteral( domainModel.findEntityDescriptor( subclassEntityNames.iterator().next() ) )
@ -5096,8 +5119,16 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
for ( String subclassEntityName : subclassEntityNames ) { for ( String subclassEntityName : subclassEntityNames ) {
typeLiterals.add( new EntityTypeLiteral( domainModel.findEntityDescriptor( subclassEntityName ) ) ); typeLiterals.add( new EntityTypeLiteral( domainModel.findEntityDescriptor( subclassEntityName ) ) );
} }
return new InListPredicate( typeExpression, typeLiterals ); discriminatorPredicate = new InListPredicate( typeExpression, typeLiterals );
} }
if ( allowNulls ) {
return new Junction(
Junction.Nature.DISJUNCTION,
List.of( discriminatorPredicate, new NullnessPredicate( typeExpression ) ),
getBooleanType()
);
}
return discriminatorPredicate;
} }
private MappingModelExpressible<?> resolveInferredType() { private MappingModelExpressible<?> resolveInferredType() {