From 5c8849c82422d3771e33ab800bca1f24e963878e Mon Sep 17 00:00:00 2001 From: Andrea Boriero Date: Fri, 18 Jun 2021 11:10:23 +0200 Subject: [PATCH] Fix issue with sql rendering of null discriminators --- .../entity/SingleTableEntityPersister.java | 52 +++++++++++++++---- .../sql/ast/spi/AbstractSqlAstTranslator.java | 35 +++++++++---- 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/hibernate-core/src/main/java/org/hibernate/persister/entity/SingleTableEntityPersister.java b/hibernate-core/src/main/java/org/hibernate/persister/entity/SingleTableEntityPersister.java index 0049a3943c..c594733a63 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/entity/SingleTableEntityPersister.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/entity/SingleTableEntityPersister.java @@ -22,6 +22,7 @@ import org.hibernate.MappingException; import org.hibernate.boot.model.relational.Database; import org.hibernate.cache.spi.access.EntityDataAccess; import org.hibernate.cache.spi.access.NaturalIdDataAccess; +import org.hibernate.dialect.Dialect; import org.hibernate.engine.jdbc.env.spi.JdbcEnvironment; import org.hibernate.engine.spi.ExecuteUpdateResultCheckStyle; import org.hibernate.engine.spi.SessionFactoryImplementor; @@ -55,6 +56,8 @@ import org.hibernate.sql.ast.tree.expression.QueryLiteral; import org.hibernate.sql.ast.tree.from.TableGroup; import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate; import org.hibernate.sql.ast.tree.predicate.InListPredicate; +import org.hibernate.sql.ast.tree.predicate.Junction; +import org.hibernate.sql.ast.tree.predicate.NullnessPredicate; import org.hibernate.sql.ast.tree.predicate.Predicate; import org.hibernate.type.AssociationType; import org.hibernate.type.BasicType; @@ -193,13 +196,14 @@ public class SingleTableEntityPersister extends AbstractEntityPersister { Iterator joinIter = persistentClass.getJoinClosureIterator(); int j = 1; + final Dialect dialect = factory.getJdbcServices().getDialect(); while ( joinIter.hasNext() ) { Join join = joinIter.next(); qualifiedTableNames[j] = determineTableName( join.getTable(), jdbcEnvironment ); isInverseTable[j] = join.isInverse(); isNullableTable[j] = join.isOptional(); cascadeDeleteEnabled[j] = join.getKey().isCascadeDeleteEnabled() && - factory.getJdbcServices().getDialect().supportsCascadeDelete(); + dialect.supportsCascadeDelete(); customSQLInsert[j] = join.getCustomSQLInsert(); insertCallable[j] = customSQLInsert[j] != null && join.isCustomInsertCallable(); @@ -222,7 +226,7 @@ public class SingleTableEntityPersister extends AbstractEntityPersister { int i = 0; while ( iter.hasNext() ) { Column col = (Column) iter.next(); - keyColumnNames[j][i++] = col.getQuotedName( factory.getJdbcServices().getDialect() ); + keyColumnNames[j][i++] = col.getQuotedName( dialect ); } j++; @@ -280,7 +284,7 @@ public class SingleTableEntityPersister extends AbstractEntityPersister { int i = 0; while ( iter.hasNext() ) { Column col = (Column) iter.next(); - keyCols[i++] = col.getQuotedName( factory.getJdbcServices().getDialect() ); + keyCols[i++] = col.getQuotedName( dialect ); } joinKeyColumns.add( keyCols ); } @@ -310,7 +314,7 @@ public class SingleTableEntityPersister extends AbstractEntityPersister { Formula formula = (Formula) selectable; discriminatorFormula = formula.getFormula(); discriminatorFormulaTemplate = formula.getTemplate( - factory.getJdbcServices().getDialect(), + dialect, factory.getQueryEngine().getSqmFunctionRegistry() ); discriminatorColumnName = null; @@ -320,13 +324,13 @@ public class SingleTableEntityPersister extends AbstractEntityPersister { } else { Column column = (Column) selectable; - discriminatorColumnName = column.getQuotedName( factory.getJdbcServices().getDialect() ); - discriminatorColumnReaders = column.getReadExpr( factory.getJdbcServices().getDialect() ); + discriminatorColumnName = column.getQuotedName( dialect ); + discriminatorColumnReaders = column.getReadExpr( dialect ); discriminatorColumnReaderTemplate = column.getTemplate( - factory.getJdbcServices().getDialect(), + dialect, factory.getQueryEngine().getSqmFunctionRegistry() ); - discriminatorAlias = column.getAlias( factory.getJdbcServices().getDialect(), persistentClass.getRootTable() ); + discriminatorAlias = column.getAlias( dialect, persistentClass.getRootTable() ); discriminatorFormula = null; discriminatorFormulaTemplate = null; } @@ -345,7 +349,8 @@ public class SingleTableEntityPersister extends AbstractEntityPersister { discriminatorInsertable = persistentClass.isDiscriminatorInsertable() && !discrimValue.hasFormula(); try { discriminatorValue = discriminatorType.stringToObject( persistentClass.getDiscriminatorValue() ); - discriminatorSQLValue = ((DiscriminatorType) discriminatorType).objectToSQLString( discriminatorValue, factory.getJdbcServices().getDialect() ); + discriminatorSQLValue = ((DiscriminatorType) discriminatorType) + .objectToSQLString( discriminatorValue, dialect ); } catch (ClassCastException cce) { throw new MappingException( "Illegal discriminator type: " + discriminatorType.getName() ); @@ -983,10 +988,35 @@ public class SingleTableEntityPersister extends AbstractEntityPersister { if ( hasSubclasses() ) { final Object[] discriminatorValues = fullDiscriminatorValues(); final List values = new ArrayList<>( discriminatorValues.length ); + boolean hasNull = false, hasNonNull = false; for ( Object discriminatorValue : discriminatorValues ) { - values.add( new QueryLiteral<>( discriminatorValue, discriminatorType ) ); + if ( discriminatorValue == NULL_DISCRIMINATOR ) { + hasNull = true; + } + else if ( discriminatorValue == NOT_NULL_DISCRIMINATOR ) { + hasNonNull = true; + } + else { + values.add( new QueryLiteral<>( discriminatorValue, discriminatorType ) ); + } } - return new InListPredicate( sqlExpression, values ); + final Predicate p = new InListPredicate( sqlExpression, values ); + if ( hasNull || hasNonNull ) { + final Junction junction = new Junction( + Junction.Nature.DISJUNCTION + ); + + // This essentially means we need to select everything, so we don't need a predicate at all + // so we return an empty Junction + if ( hasNull && hasNonNull ) { + return junction; + } + + junction.add( new NullnessPredicate( sqlExpression ) ); + junction.add( p ); + return junction; + } + return p; } return new ComparisonPredicate( sqlExpression, diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java index 1c1ed0deea..331e6cfce6 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java @@ -3933,7 +3933,8 @@ public abstract class AbstractSqlAstTranslator implemen @Override public void visitInListPredicate(InListPredicate inListPredicate) { - if ( inListPredicate.getListExpressions().isEmpty() ) { + final List listExpressions = inListPredicate.getListExpressions(); + if ( listExpressions.isEmpty() ) { appendSql( "false" ); return; } @@ -3947,7 +3948,7 @@ public abstract class AbstractSqlAstTranslator implemen } appendSql( " in (" ); String separator = NO_SEPARATOR; - for ( Expression expression : inListPredicate.getListExpressions() ) { + for ( Expression expression : listExpressions ) { appendSql( separator ); getTuple( expression ).getExpressions().get( 0 ).accept( this ); separator = COMA_SEPARATOR; @@ -3966,7 +3967,7 @@ public abstract class AbstractSqlAstTranslator implemen } appendSql( " in (" ); String separator = NO_SEPARATOR; - for ( Expression expression : inListPredicate.getListExpressions() ) { + for ( Expression expression : listExpressions ) { appendSql( separator ); renderExpressionsAsSubquery( getTuple( expression ).getExpressions() @@ -3977,7 +3978,7 @@ public abstract class AbstractSqlAstTranslator implemen } else { String separator = NO_SEPARATOR; - for ( Expression expression : inListPredicate.getListExpressions() ) { + for ( Expression expression : listExpressions ) { appendSql( separator ); emulateTupleComparison( lhsTuple.getExpressions(), @@ -3995,7 +3996,7 @@ public abstract class AbstractSqlAstTranslator implemen appendSql( " not" ); } appendSql( " in (" ); - renderCommaSeparated( inListPredicate.getListExpressions() ); + renderCommaSeparated( listExpressions ); appendSql( CLOSE_PARENTHESIS ); } } @@ -4005,7 +4006,7 @@ public abstract class AbstractSqlAstTranslator implemen appendSql( " not" ); } appendSql( " in (" ); - renderCommaSeparated( inListPredicate.getListExpressions() ); + renderCommaSeparated( listExpressions ); appendSql( CLOSE_PARENTHESIS ); } } @@ -4218,14 +4219,26 @@ public abstract class AbstractSqlAstTranslator implemen return; } - final String separator = junction.getNature() == Junction.Nature.CONJUNCTION - ? " and " - : " or "; + final Junction.Nature nature = junction.getNature(); + final String separator = nature == Junction.Nature.CONJUNCTION + ? " and " + : " or "; final List predicates = junction.getPredicates(); - predicates.get( 0 ).accept( this ); + visitJunctionPredicate( nature, predicates.get( 0 ) ); for ( int i = 1; i < predicates.size(); i++ ) { appendSql( separator ); - predicates.get( i ).accept( this ); + visitJunctionPredicate( nature, predicates.get( i ) ); + } + } + + private void visitJunctionPredicate(Junction.Nature nature, Predicate p) { + if ( p instanceof Junction && nature != ( (Junction) p ).getNature() ) { + appendSql( '(' ); + p.accept( this ); + appendSql( ')' ); + } + else { + p.accept( this ); } }