add to the root query the PK columns of the subclasses tables

This commit is contained in:
Andrea Boriero 2019-11-07 11:29:48 +00:00
parent 0a6fd5ba46
commit 0f2e5dca8c
5 changed files with 67 additions and 32 deletions

View File

@ -6,10 +6,13 @@
*/
package org.hibernate.metamodel.mapping.internal;
import java.util.List;
import org.hibernate.persister.entity.EntityPersister;
import org.hibernate.query.sqm.sql.SqlExpressionResolver;
import org.hibernate.sql.ast.spi.SqlSelection;
import org.hibernate.sql.ast.tree.expression.CaseSearchedExpression;
import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.from.TableGroup;
import org.hibernate.sql.results.spi.DomainResultCreationState;
import org.hibernate.type.BasicType;
@ -20,21 +23,36 @@ import org.hibernate.type.BasicType;
public class JoinedSubclassDiscriminatorMappingImpl extends EntityDiscriminatorMappingImpl {
private final CaseSearchedExpression caseSearchedExpression;
private final List<ColumnReference> columnReferences;
public JoinedSubclassDiscriminatorMappingImpl(
EntityPersister entityDescriptor,
String tableExpression,
CaseSearchedExpression caseSearchedExpression,
List<ColumnReference> columnReferences,
BasicType mappingType) {
super( entityDescriptor, tableExpression, mappingType );
this.caseSearchedExpression = caseSearchedExpression;
}
this.columnReferences = columnReferences;
}
@Override
protected SqlSelection resolveSqlSelection(TableGroup tableGroup, DomainResultCreationState creationState) {
final SqlExpressionResolver expressionResolver = creationState.getSqlAstCreationState()
.getSqlExpressionResolver();
// need to add the columns of the ids used in the case expression
columnReferences.forEach(
columnReference ->
expressionResolver.resolveSqlSelection(
columnReference,
getMappedTypeDescriptor().getMappedJavaTypeDescriptor(),
creationState.getSqlAstCreationState()
.getCreationContext()
.getDomainModel()
.getTypeConfiguration()
)
);
return expressionResolver.resolveSqlSelection(
expressionResolver.resolveSqlExpression(

View File

@ -62,16 +62,15 @@ import org.hibernate.sql.ast.tree.expression.QueryLiteral;
import org.hibernate.sql.ast.tree.from.TableGroup;
import org.hibernate.sql.ast.tree.from.TableReference;
import org.hibernate.sql.ast.tree.from.TableReferenceJoin;
import org.hibernate.sql.ast.tree.predicate.CasePredicate;
import org.hibernate.sql.ast.tree.predicate.NullnessPredicate;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.sql.results.internal.domain.entity.JoinedSubclassResultImpl;
import org.hibernate.sql.results.spi.DomainResult;
import org.hibernate.sql.results.spi.DomainResultCreationState;
import org.hibernate.sql.results.spi.Fetchable;
import org.hibernate.type.BasicType;
import org.hibernate.type.DiscriminatorType;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.StringType;
import org.hibernate.type.Type;
import org.jboss.logging.Logger;
@ -1196,15 +1195,24 @@ public class JoinedSubclassEntityPersister extends AbstractEntityPersister {
}
public EntityDiscriminatorMapping getDiscriminatorMapping(TableGroup tableGroup) {
CaseSearchedExpressionInfo info = getCaseSearchedExpression( tableGroup );
return new JoinedSubclassDiscriminatorMappingImpl(
this,
getRootTableName(),
getCaseSearchedExpression( tableGroup ),
info.caseSearchedExpression,
info.columnReferences,
(BasicType) getDiscriminatorType()
);
}
private CaseSearchedExpression getCaseSearchedExpression(TableGroup entityTableGroup) {
private class CaseSearchedExpressionInfo{
CaseSearchedExpression caseSearchedExpression;
List<ColumnReference> columnReferences = new ArrayList<>( );
}
private CaseSearchedExpressionInfo getCaseSearchedExpression(TableGroup entityTableGroup) {
CaseSearchedExpressionInfo info = new CaseSearchedExpressionInfo();
final TableReference primaryTableReference = entityTableGroup.getPrimaryTableReference();
final List<TableReferenceJoin> tableReferenceJoins = entityTableGroup.getTableReferenceJoins();
final BasicType discriminatorType = (BasicType) getDiscriminatorType();
@ -1215,28 +1223,26 @@ public class JoinedSubclassEntityPersister extends AbstractEntityPersister {
tableReferenceJoins.forEach(
tableReferenceJoin -> {
final TableReference joinedTableReference = tableReferenceJoin.getJoinedTableReference();
final EntityPersister entityDescriptor = getFactory().getMetamodel()
.findEntityDescriptor( subclassNameByTableName.get( joinedTableReference.getTableExpression() ) );
if ( entityDescriptor instanceof JoinedSubclassEntityPersister ) {
addWhen(
caseSearchedExpression,
joinedTableReference,
( (JoinedSubclassEntityPersister) entityDescriptor )
.getIdentifierColumnReferenceForCaseExpression( joinedTableReference ),
discriminatorType
);
}
final ColumnReference identifierColumnReference = getIdentifierColumnReference( joinedTableReference );
info.columnReferences.add( identifierColumnReference );
addWhen(
caseSearchedExpression,
joinedTableReference,
identifierColumnReference,
discriminatorType
);
}
);
addWhen(
caseSearchedExpression,
primaryTableReference,
getIdentifierColumnReferenceForCaseExpression( primaryTableReference ),
getIdentifierColumnReference( primaryTableReference ),
discriminatorType
);
return caseSearchedExpression;
info.caseSearchedExpression = caseSearchedExpression;
return info;
}
private void addWhen(
@ -1244,7 +1250,7 @@ public class JoinedSubclassEntityPersister extends AbstractEntityPersister {
TableReference table,
ColumnReference identifierColumnReference,
BasicType resultType) {
final Predicate predicate = new NullnessPredicate( identifierColumnReference, true );
final CasePredicate predicate = new NullnessPredicate( identifierColumnReference, true );
final Expression expression =
new QueryLiteral<>(
discriminatorValuesByTableName.get( table.getTableExpression() ),
@ -1255,14 +1261,13 @@ public class JoinedSubclassEntityPersister extends AbstractEntityPersister {
caseSearchedExpression.when( predicate, expression );
}
private ColumnReference getIdentifierColumnReferenceForCaseExpression(TableReference primaryTableReference) {
List<JdbcMapping> jdbcMappings = getIdentifierMapping().getJdbcMappings( getFactory().getTypeConfiguration() );
JdbcMapping jdbcMapping = jdbcMappings.get( 0 );
private ColumnReference getIdentifierColumnReference(TableReference tableReference) {
final List<JdbcMapping> jdbcMappings = getIdentifierMapping().getJdbcMappings( getFactory().getTypeConfiguration() );
return new ColumnReference(
primaryTableReference.getIdentificationVariable(),
tableReference.getIdentificationVariable(),
getIdentifierColumnNames()[0],
jdbcMapping,
jdbcMappings.get( 0 ),
getFactory()
);
}

View File

@ -112,6 +112,7 @@ import org.hibernate.sql.ast.tree.from.TableGroup;
import org.hibernate.sql.ast.tree.from.TableGroupJoin;
import org.hibernate.sql.ast.tree.from.TableGroupJoinProducer;
import org.hibernate.sql.ast.tree.predicate.BetweenPredicate;
import org.hibernate.sql.ast.tree.predicate.CasePredicate;
import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate;
import org.hibernate.sql.ast.tree.predicate.GroupedPredicate;
import org.hibernate.sql.ast.tree.predicate.InListPredicate;
@ -1371,7 +1372,7 @@ public abstract class BaseSqmToSqlAstConverter
for ( SqmCaseSearched.WhenFragment whenFragment : expression.getWhenFragments() ) {
result.when(
(Predicate) whenFragment.getPredicate().accept( this ),
(CasePredicate) whenFragment.getPredicate().accept( this ),
(Expression) whenFragment.getResult().accept( this )
);
}

View File

@ -25,15 +25,22 @@ public class DecodeCaseExpressionWalker implements CaseExpressionWalker {
List<CaseSearchedExpression.WhenFragment> whenFragments = caseSearchedExpression.getWhenFragments();
int caseNumber = whenFragments.size();
CaseSearchedExpression.WhenFragment firstWhenFragment = null;
for ( int i = 0; i < caseNumber; i++ ) {
final CaseSearchedExpression.WhenFragment whenFragment = whenFragments.get( i );
if ( i != 0 ) {
sqlBuffer.append( ", " );
whenFragment.getPredicate().getLeftHandExpression().accept( sqlAstWalker );
sqlBuffer.append( ", " );
whenFragment.getResult().accept( sqlAstWalker );
}
else {
whenFragment.getPredicate().getLeftHandExpression().accept( sqlAstWalker );
firstWhenFragment = whenFragment;
}
whenFragment.getPredicate().getLeftHandExpression().accept( sqlAstWalker );
sqlBuffer.append( ", " );
whenFragment.getResult().accept( sqlAstWalker );
}
sqlBuffer.append( ", " );
firstWhenFragment.getResult().accept( sqlAstWalker );
Expression otherwise = caseSearchedExpression.getOtherwise();
if ( otherwise != null ) {
@ -41,7 +48,7 @@ public class DecodeCaseExpressionWalker implements CaseExpressionWalker {
otherwise.accept( sqlAstWalker );
}
sqlBuffer.append( "')'" );
sqlBuffer.append( ')' );
final String columnExpression = caseSearchedExpression.getColumnExpression();

View File

@ -18,7 +18,6 @@ import org.hibernate.persister.entity.EntityPersister;
import org.hibernate.persister.entity.JoinedSubclassEntityPersister;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.FailureExpected;
import org.hibernate.testing.orm.junit.ServiceRegistry;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
@ -30,6 +29,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertTrue;
/**
* @author Andrea Boriero
@ -88,22 +88,26 @@ public class JoinedInheritanceTest {
).list();
assertThat( results.size(), is( 2 ) );
boolean foundDomesticCustomer = false;
boolean foundForeignCustomer = false;
for ( Customer result : results ) {
if ( result.getId() == 1 ) {
assertThat( result, instanceOf( DomesticCustomer.class ) );
final DomesticCustomer customer = (DomesticCustomer) result;
assertThat( customer.getName(), is( "domestic" ) );
assertThat( ( customer ).getTaxId(), is( "123" ) );
foundDomesticCustomer = true;
}
else {
assertThat( result.getId(), is( 2 ) );
final ForeignCustomer customer = (ForeignCustomer) result;
assertThat( customer.getName(), is( "foreign" ) );
assertThat( ( customer ).getVat(), is( "987" ) );
foundForeignCustomer = true;
}
}
assertTrue( foundDomesticCustomer );
assertTrue( foundForeignCustomer );
}
}
);