diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java index 2d98bb6aac..978cc6a747 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java @@ -19,6 +19,8 @@ import org.hibernate.dialect.Replacer; import org.hibernate.dialect.SQLServerCastingXmlArrayJdbcTypeConstructor; import org.hibernate.dialect.SQLServerCastingXmlJdbcType; import org.hibernate.dialect.TimeZoneSupport; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.SQLServerAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.CountFunction; import org.hibernate.dialect.function.SQLServerFormatEmulation; @@ -504,6 +506,11 @@ public class SQLServerLegacyDialect extends AbstractTransactSQLDialect { }; } + @Override + public AggregateSupport getAggregateSupport() { + return SQLServerAggregateSupport.valueOf( this ); + } + @Override public SizeStrategy getSizeStrategy() { return sizeStrategy; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java index 16019bb4e8..8d4a950b62 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java @@ -25,6 +25,8 @@ import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.relational.QualifiedSequenceName; import org.hibernate.boot.model.relational.Sequence; import org.hibernate.boot.model.relational.SqlStringGenerationContext; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.SQLServerAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.CountFunction; import org.hibernate.dialect.function.SQLServerFormatEmulation; @@ -511,6 +513,11 @@ public class SQLServerDialect extends AbstractTransactSQLDialect { }; } + @Override + public AggregateSupport getAggregateSupport() { + return SQLServerAggregateSupport.valueOf( this ); + } + @Override public SizeStrategy getSizeStrategy() { return sizeStrategy; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/SQLServerAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/SQLServerAggregateSupport.java new file mode 100644 index 0000000000..ab065459ec --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/SQLServerAggregateSupport.java @@ -0,0 +1,401 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect.aggregate; + +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.internal.util.StringHelper; +import org.hibernate.mapping.Column; +import org.hibernate.metamodel.mapping.EmbeddableMappingType; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.SelectableMapping; +import org.hibernate.metamodel.mapping.SelectablePath; +import org.hibernate.metamodel.mapping.SqlTypedMapping; +import org.hibernate.sql.ast.SqlAstNodeRenderingMode; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.BasicPluralType; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.jdbc.AggregateJdbcType; +import org.hibernate.type.descriptor.sql.DdlType; +import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.hibernate.type.SqlTypes.*; + +public class SQLServerAggregateSupport extends AggregateSupportImpl { + + private static final AggregateSupport INSTANCE = new SQLServerAggregateSupport(); + + private static final String JSON_QUERY_START = "json_query("; + private static final String JSON_QUERY_JSON_END = "')"; + private static final int JSON_VALUE_MAX_LENGTH = 4000; + + private SQLServerAggregateSupport() { + } + + public static AggregateSupport valueOf(Dialect dialect) { + return dialect.getVersion().isSameOrAfter( 13 ) + ? SQLServerAggregateSupport.INSTANCE + : AggregateSupportImpl.INSTANCE; + } + + @Override + public String aggregateComponentCustomReadExpression( + String template, + String placeholder, + String aggregateParentReadExpression, + String columnExpression, + int aggregateColumnTypeCode, + SqlTypedMapping column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + final String parentPartExpression; + if ( aggregateParentReadExpression.startsWith( JSON_QUERY_START ) + && aggregateParentReadExpression.endsWith( JSON_QUERY_JSON_END ) ) { + parentPartExpression = aggregateParentReadExpression.substring( JSON_QUERY_START.length(), aggregateParentReadExpression.length() - JSON_QUERY_JSON_END.length() ) + "."; + } + else { + parentPartExpression = aggregateParentReadExpression + ",'$."; + } + switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case JSON: + case JSON_ARRAY: + return template.replace( + placeholder, + "json_query(" + parentPartExpression + columnExpression + "')" + ); + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + // We encode binary data as hex, so we have to decode here + if ( determineLength( column ) * 2 > JSON_VALUE_MAX_LENGTH ) { + // Since data is HEX encoded, multiply the max length by 2 since we need 2 hex chars per byte + return template.replace( + placeholder, + "(select convert(" + column.getColumnDefinition() + ",v,2) from openjson(" + aggregateParentReadExpression + ") with (v varchar(max) '$." + columnExpression + "'))" + ); + } + else { + return template.replace( + placeholder, + "convert(" + column.getColumnDefinition() + ",json_value(" + parentPartExpression + columnExpression + "'),2)" + ); + } + case CHAR: + case NCHAR: + case VARCHAR: + case NVARCHAR: + case LONG32VARCHAR: + case LONG32NVARCHAR: + case CLOB: + case NCLOB: + if ( determineLength( column ) > JSON_VALUE_MAX_LENGTH ) { + return template.replace( + placeholder, + "(select * from openjson(" + aggregateParentReadExpression + ") with (v " + column.getColumnDefinition() + " '$." + columnExpression + "'))" + ); + } + // Fall-through intended + case BIT: + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case REAL: + case FLOAT: + case DOUBLE: + case NUMERIC: + case DECIMAL: + case TIME: + case TIME_UTC: + case TIME_WITH_TIMEZONE: + case DATE: + case TIMESTAMP: + case TIMESTAMP_UTC: + case TIMESTAMP_WITH_TIMEZONE: + return template.replace( + placeholder, + "cast(json_value(" + parentPartExpression + columnExpression + "') as " + column.getColumnDefinition() + ")" + ); + default: + return template.replace( + placeholder, + "(select * from openjson(" + aggregateParentReadExpression + ") with (v " + column.getColumnDefinition() + " '$." + columnExpression + "'))" + ); + } + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + private static Long determineLength(SqlTypedMapping column) { + final Long length = column.getLength(); + if ( length != null ) { + return length; + } + else { + final String columnDefinition = column.getColumnDefinition(); + assert columnDefinition != null; + final int parenthesisIndex = columnDefinition.indexOf( '(' ); + if ( parenthesisIndex != -1 ) { + int end; + for ( end = parenthesisIndex + 1; end < columnDefinition.length(); end++ ) { + if ( !Character.isDigit( columnDefinition.charAt( end ) ) ) { + break; + } + } + return Long.parseLong( columnDefinition.substring( parenthesisIndex + 1, end ) ); + } + // Default to the max varchar length + return 8000L; + } + } + + @Override + public String aggregateComponentAssignmentExpression( + String aggregateParentAssignmentExpression, + String columnExpression, + int aggregateColumnTypeCode, + Column column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + // For JSON we always have to replace the whole object + return aggregateParentAssignmentExpression; + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + private String jsonCustomWriteExpression( + String customWriteExpression, + JdbcMapping jdbcMapping, + SelectableMapping column, + TypeConfiguration typeConfiguration) { + switch ( jdbcMapping.getJdbcType().getDefaultSqlTypeCode() ) { + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + return "convert(nvarchar(max)," + customWriteExpression + ",2)"; + case TIME: + return "left(" + customWriteExpression + ",8)"; + case DATE: + return "format(" + customWriteExpression + ",'yyyy-MM-dd')"; + case TIMESTAMP: + return "format(" + customWriteExpression + ",'yyyy-MM-ddTHH:mm:ss.fffffff')"; + case TIMESTAMP_UTC: + case TIMESTAMP_WITH_TIMEZONE: + return "format(" + customWriteExpression + ",'yyyy-MM-ddTHH:mm:ss.fffffffzzz')"; + case UUID: + return "cast(" + customWriteExpression + " as nvarchar(36))"; + case JSON: + case JSON_ARRAY: + return "json_query(" + customWriteExpression + ")"; + default: + return customWriteExpression; + } + } + + private static String determineElementTypeName( + Size castTargetSize, + BasicPluralType pluralType, + TypeConfiguration typeConfiguration) { + final DdlTypeRegistry ddlTypeRegistry = typeConfiguration.getDdlTypeRegistry(); + final BasicType expressionType = pluralType.getElementType(); + DdlType ddlType = ddlTypeRegistry.getDescriptor( expressionType.getJdbcType().getDdlTypeCode() ); + if ( ddlType == null ) { + // this may happen when selecting a null value like `SELECT null from ...` + // some dbs need the value to be cast so not knowing the real type we fall back to INTEGER + ddlType = ddlTypeRegistry.getDescriptor( SqlTypes.INTEGER ); + } + + return ddlType.getTypeName( castTargetSize, expressionType, ddlTypeRegistry ); + } + + @Override + public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) { + return aggregateSqlTypeCode == JSON; + } + + @Override + public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer( + SelectableMapping aggregateColumn, + SelectableMapping[] columnsToUpdate, + TypeConfiguration typeConfiguration) { + final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); + switch ( aggregateSqlTypeCode ) { + case JSON: + return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate, typeConfiguration ); + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode ); + } + + private WriteExpressionRenderer jsonAggregateColumnWriter( + SelectableMapping aggregateColumn, + SelectableMapping[] columns, + TypeConfiguration typeConfiguration) { + return new RootJsonWriteExpression( aggregateColumn, columns, this, typeConfiguration ); + } + + interface JsonWriteExpression { + void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression); + } + private static class AggregateJsonWriteExpression implements JsonWriteExpression { + private final LinkedHashMap subExpressions = new LinkedHashMap<>(); + protected final EmbeddableMappingType embeddableMappingType; + + public AggregateJsonWriteExpression(SelectableMapping selectableMapping, SQLServerAggregateSupport aggregateSupport) { + this.embeddableMappingType = ( (AggregateJdbcType) selectableMapping.getJdbcMapping().getJdbcType() ) + .getEmbeddableMappingType(); + } + + protected void initializeSubExpressions( + SelectableMapping[] columns, + SQLServerAggregateSupport aggregateSupport, + TypeConfiguration typeConfiguration) { + for ( SelectableMapping column : columns ) { + final SelectablePath selectablePath = column.getSelectablePath(); + final SelectablePath[] parts = selectablePath.getParts(); + AggregateJsonWriteExpression currentAggregate = this; + EmbeddableMappingType currentMappingType = embeddableMappingType; + for ( int i = 1; i < parts.length - 1; i++ ) { + final SelectableMapping selectableMapping = currentMappingType.getJdbcValueSelectable( + currentMappingType.getSelectableIndex( parts[i].getSelectableName() ) + ); + currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent( + parts[i].getSelectableName(), + k -> new AggregateJsonWriteExpression( selectableMapping, aggregateSupport ) + ); + currentMappingType = currentAggregate.embeddableMappingType; + } + final String customWriteExpression = column.getWriteExpression(); + currentAggregate.subExpressions.put( + parts[parts.length - 1].getSelectableName(), + new BasicJsonWriteExpression( + column, + aggregateSupport.jsonCustomWriteExpression( + customWriteExpression, + column.getJdbcMapping(), + column, + typeConfiguration + ) + ) + ); + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + for ( int i = 0; i < subExpressions.size() - 1; i++ ) { + sb.append( "json_modify(" ); + } + sb.append( "json_modify(" ); + sb.append( path ); + for ( Map.Entry entry : subExpressions.entrySet() ) { + final String column = entry.getKey(); + final JsonWriteExpression value = entry.getValue(); + final String subPath = "json_query(" + path + ",'$." + column + "')"; + sb.append( ",'$." ); + sb.append( column ); + sb.append( "'," ); + if ( value instanceof AggregateJsonWriteExpression ) { + value.append( sb, subPath, translator, expression ); + } + else { + value.append( sb, subPath, translator, expression ); + } + sb.append( ')' ); + } + } + } + + private static class RootJsonWriteExpression extends AggregateJsonWriteExpression + implements WriteExpressionRenderer { + private final boolean nullable; + private final String path; + + RootJsonWriteExpression( + SelectableMapping aggregateColumn, + SelectableMapping[] columns, + SQLServerAggregateSupport aggregateSupport, + TypeConfiguration typeConfiguration) { + super( aggregateColumn, aggregateSupport ); + this.nullable = aggregateColumn.isNullable(); + this.path = aggregateColumn.getSelectionExpression(); + initializeSubExpressions( columns, aggregateSupport, typeConfiguration ); + } + + @Override + public void render( + SqlAppender sqlAppender, + SqlAstTranslator translator, + AggregateColumnWriteExpression aggregateColumnWriteExpression, + String qualifier) { + final String basePath; + if ( qualifier == null || qualifier.isBlank() ) { + basePath = path; + } + else { + basePath = qualifier + "." + path; + } + append( + sqlAppender, + nullable ? "coalesce(" + basePath + ",'{}')" : basePath, + translator, + aggregateColumnWriteExpression + ); + } + } + + private static class BasicJsonWriteExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + private final String customWriteExpressionStart; + private final String customWriteExpressionEnd; + + BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) { + this.selectableMapping = selectableMapping; + if ( customWriteExpression.equals( "?" ) ) { + this.customWriteExpressionStart = ""; + this.customWriteExpressionEnd = ""; + } + else { + final String[] parts = StringHelper.split( "?", customWriteExpression ); + assert parts.length == 2; + this.customWriteExpressionStart = parts[0]; + this.customWriteExpressionEnd = parts[1]; + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( customWriteExpressionStart ); + // We use NO_UNTYPED here so that expressions which require type inference are casted explicitly, + // since we don't know how the custom write expression looks like where this is embedded, + // so we have to be pessimistic and avoid ambiguities + translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED ); + sb.append( customWriteExpressionEnd ); + } + } + +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java index fbd0d6bdad..989d88ac9d 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java @@ -39,14 +39,14 @@ public class SQLServerUnnestFunction extends UnnestFunction { final ModelPart ordinalityPart = tupleType.findSubPart( CollectionPart.Nature.INDEX.getName(), null ); if ( ordinalityPart != null ) { sqlAppender.appendSql( "(select t.*,row_number() over (order by (select null)) " ); - sqlAppender.appendSql( ordinalityPart.asBasicValuedModelPart().getSelectableName() ); + sqlAppender.appendSql( ordinalityPart.asBasicValuedModelPart().getSelectionExpression() ); sqlAppender.appendSql( " from openjson(" ); } else { sqlAppender.appendSql( "openjson(" ); } array.accept( walker ); - sqlAppender.appendSql( ",'$[*]') with (" ); + sqlAppender.appendSql( ") with (" ); boolean[] comma = new boolean[1]; if ( tupleType.findSubPart( CollectionPart.Nature.ELEMENT.getName(), null ) == null ) { @@ -62,7 +62,7 @@ public class SQLServerUnnestFunction extends UnnestFunction { sqlAppender.append( selectableMapping.getSelectionExpression() ); sqlAppender.append( ' ' ); sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); - sqlAppender.appendSql( " path '$." ); + sqlAppender.appendSql( " '$." ); sqlAppender.append( selectableMapping.getSelectableName() ); sqlAppender.appendSql( '\'' ); } @@ -81,7 +81,7 @@ public class SQLServerUnnestFunction extends UnnestFunction { sqlAppender.append( selectableMapping.getSelectionExpression() ); sqlAppender.append( ' ' ); sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); - sqlAppender.appendSql( " path '$'" ); + sqlAppender.appendSql( " '$'" ); } } ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java index 9ab5f39111..afd06c189e 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java @@ -10,7 +10,6 @@ import org.hibernate.query.derived.AnonymousTupleTableGroupProducer; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.expression.JsonExistsErrorBehavior; -import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.sql.ast.tree.expression.JsonQueryEmptyBehavior; import org.hibernate.sql.ast.tree.expression.JsonQueryErrorBehavior; import org.hibernate.sql.ast.tree.expression.JsonQueryWrapMode; @@ -44,23 +43,21 @@ public class SQLServerJsonTableFunction extends JsonTableFunction { arguments.jsonDocument().accept( walker ); if ( arguments.jsonPath() != null ) { sqlAppender.appendSql( ',' ); - final JsonPathPassingClause passingClause = arguments.passingClause(); - if ( passingClause != null ) { - JsonPathHelper.appendInlinedJsonPathIncludingPassingClause( - sqlAppender, - // Default behavior is NULL ON ERROR - arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ? "strict " : "", - arguments.jsonPath(), - passingClause, - walker - ); + // Default behavior is NULL ON ERROR + final String prefix = arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ? "strict " : ""; + final String jsonPathString; + if ( arguments.passingClause() != null ) { + jsonPathString = prefix + JsonPathHelper.inlinedJsonPathIncludingPassingClause( arguments.jsonPath(), + arguments.passingClause(), walker ); } else { - if ( arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ) { - // Default behavior is NULL ON ERROR - sqlAppender.appendSql( "'strict '+" ); - } - arguments.jsonPath().accept( walker ); + jsonPathString = prefix + walker.getLiteralValue( arguments.jsonPath() ); + } + if ( jsonPathString.endsWith( "[*]" ) ) { + sqlAppender.appendSingleQuoteEscapedString( jsonPathString.substring( 0, jsonPathString.length() - 3 ) ); + } + else { + sqlAppender.appendSingleQuoteEscapedString( jsonPathString ); } } else if ( arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ) {