HHH-18929 generate check constraints for discriminator columns

This commit is contained in:
Gavin King 2024-12-12 20:32:05 +01:00
parent 18d6a993b2
commit 7f8b685973
7 changed files with 137 additions and 86 deletions

View File

@ -28,6 +28,10 @@ public @interface DiscriminatorOptions {
* instances of a root entity and its subtypes. This is useful if * instances of a root entity and its subtypes. This is useful if
* there are discriminator column values which do <em>not</em> * there are discriminator column values which do <em>not</em>
* map to any subtype of the root entity type. * map to any subtype of the root entity type.
* <p>
* This setting has the side effect of suppressing the generation
* of a {@code check} constraint in the DDL for the discriminator
* column.
* *
* @return {@code true} if allowed discriminator values must always * @return {@code true} if allowed discriminator values must always
* be explicitly enumerated * be explicitly enumerated

View File

@ -0,0 +1,97 @@
/*
* SPDX-License-Identifier: LGPL-2.1-or-later
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.boot.model.internal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.hibernate.MappingException;
import org.hibernate.boot.spi.SecondPass;
import org.hibernate.dialect.Dialect;
import org.hibernate.mapping.CheckConstraint;
import org.hibernate.mapping.Column;
import org.hibernate.mapping.PersistentClass;
import org.hibernate.mapping.Selectable;
import org.hibernate.mapping.Subclass;
import org.hibernate.persister.entity.DiscriminatorHelper;
public class DiscriminatorColumnSecondPass implements SecondPass {
private final String rootEntityName;
private final Dialect dialect;
public DiscriminatorColumnSecondPass(String rootEntityName, Dialect dialect) {
this.rootEntityName = rootEntityName;
this.dialect = dialect;
}
@Override
public void doSecondPass(Map<String, PersistentClass> persistentClasses) throws MappingException {
final PersistentClass rootClass = persistentClasses.get( rootEntityName );
if ( hasNullDiscriminatorValue( rootClass ) ) {
for ( Selectable selectable: rootClass.getDiscriminator().getSelectables() ) {
if ( selectable instanceof Column column ) {
column.setNullable( true );
}
}
}
if ( !hasNotNullDiscriminatorValue( rootClass ) // a "not null" discriminator is a catch-all
&& !rootClass.getDiscriminator().hasFormula() // can't add check constraints to formulas
&& !rootClass.isForceDiscriminator() ) { // the usecase for "forced" discriminators is that there are some rogue values
final Column column = rootClass.getDiscriminator().getColumns().get( 0 );
column.addCheckConstraint( new CheckConstraint( checkConstraint( rootClass, column ) ) );
}
}
private boolean hasNullDiscriminatorValue(PersistentClass rootClass) {
if ( rootClass.isDiscriminatorValueNull() ) {
return true;
}
for ( Subclass subclass : rootClass.getSubclasses() ) {
if ( subclass.isDiscriminatorValueNull() ) {
return true;
}
}
return false;
}
private boolean hasNotNullDiscriminatorValue(PersistentClass rootClass) {
if ( rootClass.isDiscriminatorValueNotNull() ) {
return true;
}
for ( Subclass subclass : rootClass.getSubclasses() ) {
if ( subclass.isDiscriminatorValueNotNull() ) {
return true;
}
}
return false;
}
private String checkConstraint(PersistentClass rootClass, Column column) {
return dialect.getCheckCondition(
column.getQuotedName( dialect ),
discriminatorValues( rootClass ),
column.getType().getJdbcType()
);
}
private static List<String> discriminatorValues(PersistentClass rootClass) {
final List<String> values = new ArrayList<>();
if ( !rootClass.isAbstract()
&& !rootClass.isDiscriminatorValueNull()
&& !rootClass.isDiscriminatorValueNotNull() ) {
values.add( DiscriminatorHelper.getDiscriminatorValue( rootClass ).toString() );
}
for ( Subclass subclass : rootClass.getSubclasses() ) {
if ( !subclass.isAbstract()
&& !subclass.isDiscriminatorValueNull()
&& !subclass.isDiscriminatorValueNotNull() ) {
values.add( DiscriminatorHelper.getDiscriminatorValue( subclass ).toString() );
}
}
return values;
}
}

View File

@ -958,7 +958,9 @@ public class EntityBinder {
rootClass.setPolymorphic( true ); rootClass.setPolymorphic( true );
final String rootEntityName = rootClass.getEntityName(); final String rootEntityName = rootClass.getEntityName();
LOG.tracev( "Setting discriminator for entity {0}", rootEntityName); LOG.tracev( "Setting discriminator for entity {0}", rootEntityName);
getMetadataCollector().addSecondPass( new NullableDiscriminatorColumnSecondPass( rootEntityName ) ); getMetadataCollector()
.addSecondPass( new DiscriminatorColumnSecondPass( rootEntityName,
context.getMetadataCollector().getDatabase().getDialect() ) );
} }
} }

View File

@ -1,46 +0,0 @@
/*
* SPDX-License-Identifier: LGPL-2.1-or-later
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.boot.model.internal;
import java.util.Map;
import org.hibernate.MappingException;
import org.hibernate.boot.spi.SecondPass;
import org.hibernate.mapping.Column;
import org.hibernate.mapping.PersistentClass;
import org.hibernate.mapping.Selectable;
import org.hibernate.mapping.Subclass;
public class NullableDiscriminatorColumnSecondPass implements SecondPass {
private final String rootEntityName;
public NullableDiscriminatorColumnSecondPass(String rootEntityName) {
this.rootEntityName = rootEntityName;
}
@Override
public void doSecondPass(Map<String, PersistentClass> persistentClasses) throws MappingException {
final PersistentClass rootPersistenceClass = persistentClasses.get( rootEntityName );
if ( hasNullDiscriminatorValue( rootPersistenceClass ) ) {
for ( Selectable selectable: rootPersistenceClass.getDiscriminator().getSelectables() ) {
if ( selectable instanceof Column ) {
( (Column) selectable ).setNullable( true );
}
}
}
}
private boolean hasNullDiscriminatorValue(PersistentClass rootPersistenceClass) {
if ( rootPersistenceClass.isDiscriminatorValueNull() ) {
return true;
}
for ( Subclass subclass : rootPersistenceClass.getSubclasses() ) {
if ( subclass.isDiscriminatorValueNull() ) {
return true;
}
}
return false;
}
}

View File

@ -189,6 +189,7 @@ import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalAccessor; import java.time.temporal.TemporalAccessor;
import java.time.temporal.TemporalAmount; import java.time.temporal.TemporalAmount;
import java.util.Calendar; import java.util.Calendar;
import java.util.Collection;
import java.util.Date; import java.util.Date;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
@ -839,11 +840,17 @@ public abstract class Dialect implements ConversionContext, TypeContributor, Fun
} }
/** /**
* Generate a check condition for column with the given set of values. * Generate a SQL {@code check} condition for the given column,
* constraining to the given values.
* *
* @apiNote Only supports TINYINT, SMALLINT and (VAR)CHAR * @return a SQL expression that will occur in a {@code check} constraint
*
* @apiNote Only supports {@code TINYINT}, {@code SMALLINT}, {@code CHAR},
* and {@code VARCHAR}
*
* @since 7.0
*/ */
public String getCheckCondition(String columnName, Set<?> valueSet, JdbcType jdbcType) { public String getCheckCondition(String columnName, Collection<?> valueSet, JdbcType jdbcType) {
final boolean isCharacterJdbcType = isCharacterType( jdbcType.getJdbcTypeCode() ); final boolean isCharacterJdbcType = isCharacterType( jdbcType.getJdbcTypeCode() );
assert isCharacterJdbcType || isIntegral( jdbcType.getJdbcTypeCode() ); assert isCharacterJdbcType || isIntegral( jdbcType.getJdbcTypeCode() );
@ -856,11 +863,12 @@ public abstract class Dialect implements ConversionContext, TypeContributor, Fun
nullIsValid = true; nullIsValid = true;
continue; continue;
} }
check.append( separator );
if ( isCharacterJdbcType ) { if ( isCharacterJdbcType ) {
check.append( separator ).append('\'').append( value ).append('\''); check.append('\'').append( value ).append('\'');
} }
else { else {
check.append( separator ).append( value ); check.append( value );
} }
separator = ","; separator = ",";
} }

View File

@ -12,6 +12,7 @@ import java.time.Duration;
import java.time.temporal.TemporalAccessor; import java.time.temporal.TemporalAccessor;
import java.time.temporal.TemporalAmount; import java.time.temporal.TemporalAmount;
import java.util.Calendar; import java.util.Calendar;
import java.util.Collection;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -1681,7 +1682,7 @@ public class DialectDelegateWrapper extends Dialect {
} }
@Override @Override
public String getCheckCondition(String columnName, Set<?> valueSet, JdbcType jdbcType) { public String getCheckCondition(String columnName, Collection<?> valueSet, JdbcType jdbcType) {
return wrapped.getCheckCondition( columnName, valueSet, jdbcType ); return wrapped.getCheckCondition( columnName, valueSet, jdbcType );
} }

View File

@ -37,9 +37,9 @@ public class DiscriminatorHelper {
* and the {@link org.hibernate.type.descriptor.jdbc.JdbcType}. * and the {@link org.hibernate.type.descriptor.jdbc.JdbcType}.
*/ */
static BasicType<?> getDiscriminatorType(PersistentClass persistentClass) { static BasicType<?> getDiscriminatorType(PersistentClass persistentClass) {
Type discriminatorType = persistentClass.getDiscriminator().getType(); final Type discriminatorType = persistentClass.getDiscriminator().getType();
if ( discriminatorType instanceof BasicType ) { if ( discriminatorType instanceof BasicType<?> basicType ) {
return (BasicType<?>) discriminatorType; return basicType;
} }
else { else {
throw new MappingException( "Illegal discriminator type: " + discriminatorType.getName() ); throw new MappingException( "Illegal discriminator type: " + discriminatorType.getName() );
@ -47,16 +47,16 @@ public class DiscriminatorHelper {
} }
public static BasicType<?> getDiscriminatorType(Component component) { public static BasicType<?> getDiscriminatorType(Component component) {
Type discriminatorType = component.getDiscriminator().getType(); final Type discriminatorType = component.getDiscriminator().getType();
if ( discriminatorType instanceof BasicType ) { if ( discriminatorType instanceof BasicType<?> basicType ) {
return (BasicType<?>) discriminatorType; return basicType;
} }
else { else {
throw new MappingException( "Illegal discriminator type: " + discriminatorType.getName() ); throw new MappingException( "Illegal discriminator type: " + discriminatorType.getName() );
} }
} }
static String getDiscriminatorSQLValue(PersistentClass persistentClass, Dialect dialect) { public static String getDiscriminatorSQLValue(PersistentClass persistentClass, Dialect dialect) {
if ( persistentClass.isDiscriminatorValueNull() ) { if ( persistentClass.isDiscriminatorValueNull() ) {
return InFragment.NULL; return InFragment.NULL;
} }
@ -69,16 +69,18 @@ public class DiscriminatorHelper {
} }
private static Object parseDiscriminatorValue(PersistentClass persistentClass) { private static Object parseDiscriminatorValue(PersistentClass persistentClass) {
BasicType<?> discriminatorType = getDiscriminatorType( persistentClass ); final BasicType<?> discriminatorType = getDiscriminatorType( persistentClass );
final String discriminatorValue = persistentClass.getDiscriminatorValue();
try { try {
return discriminatorType.getJavaTypeDescriptor().fromString( persistentClass.getDiscriminatorValue() ); return discriminatorType.getJavaTypeDescriptor().fromString( discriminatorValue );
} }
catch ( Exception e ) { catch ( Exception e ) {
throw new MappingException( "Could not parse discriminator value", e ); throw new MappingException( "Could not parse discriminator value '" + discriminatorValue
+ "' as discriminator type '" + discriminatorType.getName() + "'", e );
} }
} }
static Object getDiscriminatorValue(PersistentClass persistentClass) { public static Object getDiscriminatorValue(PersistentClass persistentClass) {
if ( persistentClass.isDiscriminatorValueNull() ) { if ( persistentClass.isDiscriminatorValueNull() ) {
return NULL_DISCRIMINATOR; return NULL_DISCRIMINATOR;
} }
@ -101,10 +103,7 @@ public class DiscriminatorHelper {
); );
} }
public static <T> String jdbcLiteral( public static <T> String jdbcLiteral(T value, JdbcLiteralFormatter<T> formatter, Dialect dialect) {
T value,
JdbcLiteralFormatter<T> formatter,
Dialect dialect) {
try { try {
return formatter.toJdbcLiteral( value, dialect, null ); return formatter.toJdbcLiteral( value, dialect, null );
} }
@ -119,26 +118,12 @@ public class DiscriminatorHelper {
* domain types, or to {@link StandardBasicTypes#CLASS Class} for non-inherited ones. * domain types, or to {@link StandardBasicTypes#CLASS Class} for non-inherited ones.
*/ */
public static <T> SqmExpressible<? super T> getDiscriminatorType( public static <T> SqmExpressible<? super T> getDiscriminatorType(
SqmPathSource<T> domainType, SqmPathSource<T> domainType, NodeBuilder nodeBuilder) {
NodeBuilder nodeBuilder) {
final SqmPathSource<?> subPathSource = domainType.findSubPathSource( DISCRIMINATOR_ROLE_NAME ); final SqmPathSource<?> subPathSource = domainType.findSubPathSource( DISCRIMINATOR_ROLE_NAME );
final SqmExpressible<?> type; final SqmExpressible<?> type = subPathSource != null
if ( subPathSource != null ) { ? subPathSource.getSqmPathType()
type = subPathSource.getSqmPathType(); : nodeBuilder.getTypeConfiguration().getBasicTypeRegistry().resolve( StandardBasicTypes.CLASS );
}
else {
type = nodeBuilder.getTypeConfiguration()
.getBasicTypeRegistry()
.resolve( StandardBasicTypes.CLASS );
}
//noinspection unchecked //noinspection unchecked
return (SqmExpressible<? super T>) type; return (SqmExpressible<? super T>) type;
} }
static String discriminatorLiteral(JdbcLiteralFormatter<Object> formatter, Dialect dialect, Object value) {
return value == NULL_DISCRIMINATOR || value == NOT_NULL_DISCRIMINATOR
? null
: jdbcLiteral( value, formatter, dialect );
}
} }