HHH-18793 Add JSON aggregate support for MySQL

This commit is contained in:
Christian Beikov 2024-11-05 19:49:46 +01:00
parent a2c714e8f8
commit 1f5f778358
10 changed files with 384 additions and 11 deletions

View File

@ -16,6 +16,8 @@ import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.cfg.Environment;
import org.hibernate.dialect.*;
import org.hibernate.dialect.aggregate.AggregateSupport;
import org.hibernate.dialect.aggregate.MySQLAggregateSupport;
import org.hibernate.dialect.function.CommonFunctionFactory;
import org.hibernate.dialect.identity.IdentityColumnSupport;
import org.hibernate.dialect.identity.MySQLIdentityColumnSupport;
@ -263,7 +265,10 @@ public class MySQLLegacyDialect extends Dialect {
//MySQL doesn't let you cast to DOUBLE/FLOAT
//but don't just return 'decimal' because
//the default scale is 0 (no decimal places)
return "decimal($p,$s)";
return getMySQLVersion().isSameOrAfter( 8, 0, 17 )
// In newer versions of MySQL, casting to float/double is supported
? super.castType( sqlTypeCode )
: "decimal($p,$s)";
case CHAR:
case NCHAR:
case VARCHAR:
@ -385,6 +390,13 @@ public class MySQLLegacyDialect extends Dialect {
ddlTypeRegistry.addDescriptor( new NativeOrdinalEnumDdlTypeImpl( this ) );
}
@Override
public AggregateSupport getAggregateSupport() {
return getMySQLVersion().isSameOrAfter( 5, 7 )
? MySQLAggregateSupport.JSON_INSTANCE
: super.getAggregateSupport();
}
@Deprecated
protected static int getCharacterSetBytesPerCharacter(DatabaseMetaData databaseMetaData) {
if ( databaseMetaData != null ) {

View File

@ -83,6 +83,7 @@ import org.hibernate.type.descriptor.java.ByteArrayJavaType;
import org.hibernate.type.descriptor.java.CharacterArrayJavaType;
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.JdbcTypeConstructor;
import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcTypeConstructor;
import org.hibernate.type.descriptor.jdbc.JsonAsStringArrayJdbcTypeConstructor;
import org.hibernate.type.descriptor.jdbc.JsonAsStringJdbcType;
@ -101,6 +102,7 @@ import org.hibernate.usertype.CompositeUserType;
import jakarta.persistence.AttributeConverter;
import static org.hibernate.internal.util.collections.CollectionHelper.mutableJoin;
import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForArray;
import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForDuration;
import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForInstant;
import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForUuid;
@ -771,6 +773,14 @@ public class MetadataBuildingProcess {
jdbcTypeRegistry.addTypeConstructor( XmlAsStringArrayJdbcTypeConstructor.INSTANCE );
}
}
if ( jdbcTypeRegistry.getConstructor( SqlTypes.ARRAY ) == null ) {
// Default the array constructor to e.g. JSON_ARRAY/XML_ARRAY if needed
final JdbcTypeConstructor constructor =
jdbcTypeRegistry.getConstructor( getPreferredSqlTypeCodeForArray( serviceRegistry ) );
if ( constructor != null ) {
jdbcTypeRegistry.addTypeConstructor( SqlTypes.ARRAY, constructor );
}
}
final int preferredSqlTypeCodeForDuration = getPreferredSqlTypeCodeForDuration( serviceRegistry );
if ( preferredSqlTypeCodeForDuration != SqlTypes.INTERVAL_SECOND ) {

View File

@ -322,6 +322,7 @@ public class JsonHelper {
appender.append( '"' );
break;
case SqlTypes.ARRAY:
case SqlTypes.JSON_ARRAY:
final int length = Array.getLength( value );
appender.append( '[' );
if ( length != 0 ) {

View File

@ -21,6 +21,8 @@ import org.hibernate.PessimisticLockException;
import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.cfg.Environment;
import org.hibernate.dialect.aggregate.AggregateSupport;
import org.hibernate.dialect.aggregate.MySQLAggregateSupport;
import org.hibernate.dialect.function.CommonFunctionFactory;
import org.hibernate.dialect.identity.IdentityColumnSupport;
import org.hibernate.dialect.identity.MySQLIdentityColumnSupport;
@ -316,12 +318,15 @@ public class MySQLDialect extends Dialect {
// MySQL doesn't let you cast to DOUBLE/FLOAT
// but don't just return 'decimal' because
// the default scale is 0 (no decimal places)
case FLOAT, REAL, DOUBLE -> "decimal($p,$s)";
case FLOAT, REAL, DOUBLE -> getMySQLVersion().isSameOrAfter( 8, 0, 17 )
// In newer versions of MySQL, casting to float/double is supported
? super.castType( sqlTypeCode )
: "decimal($p,$s)";
// MySQL doesn't let you cast to TEXT/LONGTEXT
case CHAR, VARCHAR, LONG32VARCHAR -> "char";
case NCHAR, NVARCHAR, LONG32NVARCHAR -> "char character set utf8mb4";
case CHAR, VARCHAR, LONG32VARCHAR, CLOB -> "char";
case NCHAR, NVARCHAR, LONG32NVARCHAR, NCLOB -> "char character set utf8mb4";
// MySQL doesn't let you cast to BLOB/TINYBLOB/LONGBLOB
case BINARY, VARBINARY, LONG32VARBINARY -> "binary";
case BINARY, VARBINARY, LONG32VARBINARY, BLOB -> "binary";
default -> super.castType(sqlTypeCode);
};
}
@ -433,6 +438,11 @@ public class MySQLDialect extends Dialect {
ddlTypeRegistry.addDescriptor( new NativeOrdinalEnumDdlTypeImpl( this ) );
}
@Override
public AggregateSupport getAggregateSupport() {
return MySQLAggregateSupport.valueOf( this );
}
@Deprecated(since="6.4")
protected static int getCharacterSetBytesPerCharacter(DatabaseMetaData databaseMetaData) {
if ( databaseMetaData != null ) {

View File

@ -48,6 +48,11 @@ import java.util.Locale;
*/
public class MySQLSqlAstTranslator<T extends JdbcOperation> extends AbstractSqlAstTranslator<T> {
/**
* On MySQL, 1GB or {@code 2^30 - 1} is the maximum size that a char value can be casted.
*/
private static final int MAX_CHAR_SIZE = (1 << 30) - 1;
public MySQLSqlAstTranslator(SessionFactoryImplementor sessionFactory, Statement statement) {
super( sessionFactory, statement );
}
@ -64,7 +69,7 @@ public class MySQLSqlAstTranslator<T extends JdbcOperation> extends AbstractSqlA
private static String getSqlType(CastTarget castTarget, String sqlType, Dialect dialect) {
if ( sqlType != null ) {
int parenthesesIndex = sqlType.indexOf( '(' );
final String baseName = parenthesesIndex == -1 ? sqlType : sqlType.substring( 0, parenthesesIndex );
final String baseName = parenthesesIndex == -1 ? sqlType : sqlType.substring( 0, parenthesesIndex ).trim();
switch ( baseName.toLowerCase( Locale.ROOT ) ) {
case "bit":
return "unsigned";
@ -76,6 +81,9 @@ public class MySQLSqlAstTranslator<T extends JdbcOperation> extends AbstractSqlA
case "float":
case "real":
case "double precision":
if ( ((MySQLDialect) dialect).getMySQLVersion().isSameOrAfter( 8, 0, 17 ) ) {
return sqlType;
}
final int precision = castTarget.getPrecision() == null
? dialect.getDefaultDecimalPrecision()
: castTarget.getPrecision();
@ -85,6 +93,10 @@ public class MySQLSqlAstTranslator<T extends JdbcOperation> extends AbstractSqlA
case "varchar":
case "nchar":
case "nvarchar":
case "text":
case "mediumtext":
case "longtext":
case "enum":
if ( castTarget.getLength() == null ) {
// TODO: this is ugly and fragile, but could easily be handled in a DdlType
if ( castTarget.getJdbcMapping().getJdbcJavaType().getJavaType() == Character.class ) {
@ -94,9 +106,11 @@ public class MySQLSqlAstTranslator<T extends JdbcOperation> extends AbstractSqlA
return "char";
}
}
return "char(" + castTarget.getLength() + ")";
return castTarget.getLength() > MAX_CHAR_SIZE ? "char" : "char(" + castTarget.getLength() + ")";
case "binary":
case "varbinary":
case "mediumblob":
case "longblob":
return castTarget.getLength() == null
? "binary"
: "binary(" + castTarget.getLength() + ")";

View File

@ -13,6 +13,7 @@ import org.hibernate.mapping.AggregateColumn;
import org.hibernate.mapping.Column;
import org.hibernate.metamodel.mapping.SelectableMapping;
import org.hibernate.metamodel.mapping.SqlTypedMapping;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.spi.TypeConfiguration;
public class AggregateSupportImpl implements AggregateSupport {
@ -76,7 +77,10 @@ public class AggregateSupportImpl implements AggregateSupport {
@Override
public int aggregateComponentSqlTypeCode(int aggregateColumnSqlTypeCode, int columnSqlTypeCode) {
return columnSqlTypeCode;
return switch (aggregateColumnSqlTypeCode) {
case SqlTypes.JSON -> columnSqlTypeCode == SqlTypes.ARRAY ? SqlTypes.JSON_ARRAY : columnSqlTypeCode;
default -> columnSqlTypeCode;
};
}
@Override

View File

@ -0,0 +1,320 @@
/*
* SPDX-License-Identifier: LGPL-2.1-or-later
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.dialect.aggregate;
import org.hibernate.dialect.Dialect;
import org.hibernate.internal.util.StringHelper;
import org.hibernate.mapping.Column;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.SelectableMapping;
import org.hibernate.metamodel.mapping.SelectablePath;
import org.hibernate.metamodel.mapping.SqlTypedMapping;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.spi.TypeConfiguration;
import java.util.LinkedHashMap;
import java.util.Map;
import static org.hibernate.type.SqlTypes.BIGINT;
import static org.hibernate.type.SqlTypes.BINARY;
import static org.hibernate.type.SqlTypes.BIT;
import static org.hibernate.type.SqlTypes.BLOB;
import static org.hibernate.type.SqlTypes.BOOLEAN;
import static org.hibernate.type.SqlTypes.CHAR;
import static org.hibernate.type.SqlTypes.CLOB;
import static org.hibernate.type.SqlTypes.ENUM;
import static org.hibernate.type.SqlTypes.INTEGER;
import static org.hibernate.type.SqlTypes.JSON;
import static org.hibernate.type.SqlTypes.JSON_ARRAY;
import static org.hibernate.type.SqlTypes.LONG32NVARCHAR;
import static org.hibernate.type.SqlTypes.LONG32VARBINARY;
import static org.hibernate.type.SqlTypes.LONG32VARCHAR;
import static org.hibernate.type.SqlTypes.NCHAR;
import static org.hibernate.type.SqlTypes.NCLOB;
import static org.hibernate.type.SqlTypes.NVARCHAR;
import static org.hibernate.type.SqlTypes.SMALLINT;
import static org.hibernate.type.SqlTypes.TIMESTAMP;
import static org.hibernate.type.SqlTypes.TIMESTAMP_UTC;
import static org.hibernate.type.SqlTypes.TINYINT;
import static org.hibernate.type.SqlTypes.VARBINARY;
import static org.hibernate.type.SqlTypes.VARCHAR;
public class MySQLAggregateSupport extends AggregateSupportImpl {
private static final AggregateSupport INSTANCE = new MySQLAggregateSupport();
public static AggregateSupport valueOf(Dialect dialect) {
return MySQLAggregateSupport.INSTANCE;
}
@Override
public String aggregateComponentCustomReadExpression(
String template,
String placeholder,
String aggregateParentReadExpression,
String columnExpression,
int aggregateColumnTypeCode,
SqlTypedMapping column) {
switch ( aggregateColumnTypeCode ) {
case JSON_ARRAY:
case JSON:
switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) {
case JSON:
case JSON_ARRAY:
return template.replace(
placeholder,
queryExpression( aggregateParentReadExpression, columnExpression )
);
case BOOLEAN:
return template.replace(
placeholder,
"case " + queryExpression( aggregateParentReadExpression, columnExpression ) + " when 'true' then true when 'false' then false end"
);
case BINARY:
case VARBINARY:
case LONG32VARBINARY:
// We encode binary data as hex, so we have to decode here
return template.replace(
placeholder,
"unhex(json_unquote(" + queryExpression( aggregateParentReadExpression, columnExpression ) + "))"
);
default:
return template.replace(
placeholder,
valueExpression( aggregateParentReadExpression, columnExpression, columnCastType( column ) )
);
}
}
throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode );
}
private static String columnCastType(SqlTypedMapping column) {
return switch (column.getJdbcMapping().getJdbcType().getDdlTypeCode()) {
// special case for casting to Boolean
case BOOLEAN, BIT -> "unsigned";
// MySQL doesn't let you cast to INTEGER/BIGINT/TINYINT
case TINYINT, SMALLINT, INTEGER, BIGINT -> "signed";
// MySQL doesn't let you cast to TEXT/LONGTEXT
case CHAR, VARCHAR, LONG32VARCHAR, CLOB, ENUM -> "char";
case NCHAR, NVARCHAR, LONG32NVARCHAR, NCLOB -> "char character set utf8mb4";
// MySQL doesn't let you cast to BLOB/TINYBLOB/LONGBLOB
case BINARY, VARBINARY, LONG32VARBINARY, BLOB -> "binary";
default -> column.getColumnDefinition();
};
}
private static String valueExpression(String aggregateParentReadExpression, String columnExpression, String columnType) {
return "cast(json_unquote(" + queryExpression( aggregateParentReadExpression, columnExpression ) + ") as " + columnType + ')';
}
private static String queryExpression(String aggregateParentReadExpression, String columnExpression) {
return "nullif(json_extract(" + aggregateParentReadExpression + ",'$." + columnExpression + "'),cast('null' as json))";
}
private static String jsonCustomWriteExpression(String customWriteExpression, JdbcMapping jdbcMapping) {
final int sqlTypeCode = jdbcMapping.getJdbcType().getDefaultSqlTypeCode();
switch ( sqlTypeCode ) {
case BINARY:
case VARBINARY:
case LONG32VARBINARY:
case BLOB:
// We encode binary data as hex
return "hex(" + customWriteExpression + ")";
case BOOLEAN:
return "(" + customWriteExpression + ")=true";
case TIMESTAMP:
return "date_format(" + customWriteExpression + ",'%Y-%m-%dT%T.%f')";
case TIMESTAMP_UTC:
return "date_format(" + customWriteExpression + ",'%Y-%m-%dT%T.%fZ')";
default:
return customWriteExpression;
}
}
@Override
public int aggregateComponentSqlTypeCode(int aggregateColumnSqlTypeCode, int columnSqlTypeCode) {
return super.aggregateComponentSqlTypeCode( aggregateColumnSqlTypeCode, columnSqlTypeCode );
}
@Override
public String aggregateComponentAssignmentExpression(
String aggregateParentAssignmentExpression,
String columnExpression,
int aggregateColumnTypeCode,
Column column) {
switch ( aggregateColumnTypeCode ) {
case JSON:
case JSON_ARRAY:
// For JSON we always have to replace the whole object
return aggregateParentAssignmentExpression;
}
throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode );
}
@Override
public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) {
switch ( aggregateSqlTypeCode ) {
case JSON:
return true;
}
return false;
}
@Override
public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer(
SelectableMapping aggregateColumn,
SelectableMapping[] columnsToUpdate,
TypeConfiguration typeConfiguration) {
final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode();
switch ( aggregateSqlTypeCode ) {
case JSON:
return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate );
}
throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode );
}
private WriteExpressionRenderer jsonAggregateColumnWriter(
SelectableMapping aggregateColumn,
SelectableMapping[] columns) {
return new RootJsonWriteExpression( aggregateColumn, columns );
}
interface JsonWriteExpression {
void append(
SqlAppender sb,
String path,
SqlAstTranslator<?> translator,
AggregateColumnWriteExpression expression);
}
private static class AggregateJsonWriteExpression implements JsonWriteExpression {
private final LinkedHashMap<String, JsonWriteExpression> subExpressions = new LinkedHashMap<>();
protected void initializeSubExpressions(SelectableMapping[] columns) {
for ( SelectableMapping column : columns ) {
final SelectablePath selectablePath = column.getSelectablePath();
final SelectablePath[] parts = selectablePath.getParts();
AggregateJsonWriteExpression currentAggregate = this;
for ( int i = 1; i < parts.length - 1; i++ ) {
currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent(
parts[i].getSelectableName(),
k -> new AggregateJsonWriteExpression()
);
}
final String customWriteExpression = column.getWriteExpression();
currentAggregate.subExpressions.put(
parts[parts.length - 1].getSelectableName(),
new BasicJsonWriteExpression(
column,
jsonCustomWriteExpression( customWriteExpression, column.getJdbcMapping() )
)
);
}
}
@Override
public void append(
SqlAppender sb,
String path,
SqlAstTranslator<?> translator,
AggregateColumnWriteExpression expression) {
for ( Map.Entry<String, JsonWriteExpression> entry : subExpressions.entrySet() ) {
final String column = entry.getKey();
final JsonWriteExpression value = entry.getValue();
final String subPath = queryExpression( path, column );
sb.append( ',' );
if ( value instanceof AggregateJsonWriteExpression ) {
sb.append( "'$." );
sb.append( column );
sb.append( "',json_set(coalesce(" );
sb.append( subPath );
sb.append( ",json_object())" );
value.append( sb, subPath, translator, expression );
sb.append( ')' );
}
else {
value.append( sb, subPath, translator, expression );
}
}
}
}
private static class RootJsonWriteExpression extends AggregateJsonWriteExpression
implements WriteExpressionRenderer {
private final boolean nullable;
private final String path;
RootJsonWriteExpression(SelectableMapping aggregateColumn, SelectableMapping[] columns) {
this.nullable = aggregateColumn.isNullable();
this.path = aggregateColumn.getSelectionExpression();
initializeSubExpressions( columns );
}
@Override
public void render(
SqlAppender sqlAppender,
SqlAstTranslator<?> translator,
AggregateColumnWriteExpression aggregateColumnWriteExpression,
String qualifier) {
final String basePath;
if ( qualifier == null || qualifier.isBlank() ) {
basePath = path;
}
else {
basePath = qualifier + "." + path;
}
sqlAppender.appendSql( "json_set(" );
if ( nullable ) {
sqlAppender.append( "coalesce(" );
sqlAppender.append( basePath );
sqlAppender.append( ",json_object())" );
}
else {
sqlAppender.append( basePath );
}
append( sqlAppender, basePath, translator, aggregateColumnWriteExpression );
sqlAppender.append( ')' );
}
}
private static class BasicJsonWriteExpression implements JsonWriteExpression {
private final SelectableMapping selectableMapping;
private final String customWriteExpressionStart;
private final String customWriteExpressionEnd;
BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) {
this.selectableMapping = selectableMapping;
if ( customWriteExpression.equals( "?" ) ) {
this.customWriteExpressionStart = "";
this.customWriteExpressionEnd = "";
}
else {
final String[] parts = StringHelper.split( "?", customWriteExpression );
assert parts.length == 2;
this.customWriteExpressionStart = parts[0];
this.customWriteExpressionEnd = parts[1];
}
}
@Override
public void append(
SqlAppender sb,
String path,
SqlAstTranslator<?> translator,
AggregateColumnWriteExpression expression) {
sb.append( "'$." );
sb.append( selectableMapping.getSelectableName() );
sb.append( "'," );
sb.append( customWriteExpressionStart );
// We use NO_UNTYPED here so that expressions which require type inference are casted explicitly,
// since we don't know how the custom write expression looks like where this is embedded,
// so we have to be pessimistic and avoid ambiguities
translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED );
sb.append( customWriteExpressionEnd );
}
}
}

View File

@ -218,6 +218,7 @@ public class OracleAggregateSupport extends AggregateSupportImpl {
);
}
case JSON:
case JSON_ARRAY:
return template.replace(
placeholder,
"json_query(" + parentPartExpression + columnExpression + "' returning " + jsonTypeName + ")"

View File

@ -58,6 +58,7 @@ public class PostgreSQLAggregateSupport extends AggregateSupportImpl {
case JSON:
switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) {
case JSON:
case JSON_ARRAY:
return template.replace(
placeholder,
aggregateParentReadExpression + "->'" + columnExpression + "'"

View File

@ -65,8 +65,8 @@ import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.CollectionType;
import org.hibernate.type.CompositeType;
import org.hibernate.type.EntityType;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.Type;
import org.hibernate.type.descriptor.JdbcTypeNameMapper;
import org.hibernate.type.descriptor.java.BasicPluralJavaType;
import org.hibernate.type.descriptor.java.ImmutableMutabilityPlan;
import org.hibernate.type.descriptor.java.JavaType;
@ -319,9 +319,9 @@ public class EmbeddableMappingTypeImpl extends AbstractEmbeddableMapping impleme
}
final BasicType<?> resolvedJdbcMapping;
if ( isArray ) {
final JdbcTypeConstructor arrayConstructor = jdbcTypeRegistry.getConstructor( SqlTypes.ARRAY );
final JdbcTypeConstructor arrayConstructor = jdbcTypeRegistry.getConstructor( aggregateColumnSqlTypeCode );
if ( arrayConstructor == null ) {
throw new IllegalArgumentException( "No JdbcTypeConstructor registered for SqlTypes.ARRAY" );
throw new IllegalArgumentException( "No JdbcTypeConstructor registered for SqlTypes." + JdbcTypeNameMapper.getTypeName( aggregateColumnSqlTypeCode ) );
}
//noinspection rawtypes,unchecked
final BasicType<?> arrayType = ( (BasicPluralJavaType) resolution.getDomainJavaType() ).resolveType(