From 842ebd0e7e9769c54426c0938209995e5606c8be Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Tue, 8 Feb 2022 15:26:57 +0100 Subject: [PATCH] Make use of function return type resolver for elements/indices functions --- .../AbstractSqmFunctionDescriptor.java | 8 +- .../sqm/sql/BaseSqmToSqlAstConverter.java | 171 ++++++++++++------ .../orm/test/query/hql/FunctionTests.java | 11 +- 3 files changed, 128 insertions(+), 62 deletions(-) diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java index aa64f2445b..8385ea210a 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java @@ -33,10 +33,6 @@ public abstract class AbstractSqmFunctionDescriptor implements SqmFunctionDescri private final FunctionReturnTypeResolver returnTypeResolver; private final String name; - protected FunctionReturnTypeResolver getReturnTypeResolver() { - return returnTypeResolver; - } - public AbstractSqmFunctionDescriptor(String name) { this( name, null, null ); } @@ -71,6 +67,10 @@ public abstract class AbstractSqmFunctionDescriptor implements SqmFunctionDescri return argumentsValidator; } + public FunctionReturnTypeResolver getReturnTypeResolver() { + return returnTypeResolver; + } + public String getReturnSignature() { String result = returnTypeResolver.getReturnType(); return result.isEmpty() ? "" : result + " "; diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java index 73b43bde6f..52bc0a674d 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java @@ -259,6 +259,7 @@ import org.hibernate.sql.ast.spi.SqlAstQueryPartProcessingState; import org.hibernate.sql.ast.spi.SqlAstTreeHelper; import org.hibernate.sql.ast.spi.SqlExpressionResolver; import org.hibernate.sql.ast.spi.SqlSelection; +import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.Statement; import org.hibernate.sql.ast.tree.cte.CteColumn; import org.hibernate.sql.ast.tree.cte.CteStatement; @@ -3341,10 +3342,8 @@ public abstract class BaseSqmToSqlAstConverter extends Base AbstractSqmSpecificPluralPartPath pluralPartPath, boolean index, String functionName) { - boolean isMinOrMax = functionName.equalsIgnoreCase("min") - || functionName.equalsIgnoreCase("max"); // Try to create a lateral sub-query join if possible which allows the re-use of the expression - if ( isMinOrMax && creationContext.getSessionFactory().getJdbcServices().getDialect().supportsLateral() ) { + if ( creationContext.getSessionFactory().getJdbcServices().getDialect().supportsLateral() ) { return createLateralJoinExpression( pluralPartPath, index, functionName ); } else { @@ -3627,7 +3626,10 @@ public abstract class BaseSqmToSqlAstConverter extends Base functionDescriptor, arguments, null, - (ReturnableType) modelPart.getJdbcMappings().get( 0 ), + (ReturnableType) functionDescriptor.getReturnTypeResolver().resolveFunctionReturnType( + () -> null, + arguments + ).getJdbcMapping(), modelPart ); subQuerySpec.getSelectClause().addSqlSelection( new SqlSelectionImpl( 1, 0, expression ) ); @@ -3713,6 +3715,15 @@ public abstract class BaseSqmToSqlAstConverter extends Base final List columnNames = new ArrayList<>( jdbcTypeCount ); final List resultColumnReferences = new ArrayList<>( jdbcTypeCount ); final NavigablePath navigablePath = pluralPartPath.getNavigablePath(); + final Boolean max = functionName.equalsIgnoreCase( "max" ) ? Boolean.TRUE + : ( functionName.equalsIgnoreCase( "min" ) ? Boolean.FALSE : null ); + final AbstractSqmSelfRenderingFunctionDescriptor functionDescriptor = + (AbstractSqmSelfRenderingFunctionDescriptor) creationContext + .getSessionFactory() + .getQueryEngine() + .getSqmFunctionRegistry() + .findFunctionDescriptor( functionName ); + final List subQueryColumns = new ArrayList<>( jdbcTypeCount ); modelPart.forEachSelectable( (selectionIndex, selectionMapping) -> { final ColumnReference columnReference = new ColumnReference( @@ -3731,39 +3742,83 @@ public abstract class BaseSqmToSqlAstConverter extends Base columnName = selectionMapping.getSelectionExpression(); } columnNames.add( columnName ); - subQuerySpec.getSelectClause().addSqlSelection( - new SqlSelectionImpl( - selectionIndex - 1, - selectionIndex, - columnReference - ) - ); - subQuerySpec.addSortSpecification( - new SortSpecification( - columnReference, - functionName.equalsIgnoreCase("max") - ? SortOrder.DESCENDING - : SortOrder.ASCENDING - ) - ); - resultColumnReferences.add( - new ColumnReference( - identifierVariable, - columnName, - false, - null, - null, - selectionMapping.getJdbcMapping(), - creationContext.getSessionFactory() - ) - ); + subQueryColumns.add( columnReference ); + if ( max != null ) { + subQuerySpec.addSortSpecification( + new SortSpecification( + columnReference, + max ? SortOrder.DESCENDING : SortOrder.ASCENDING + ) + ); + } } ); - subQuerySpec.setFetchClauseExpression( - new QueryLiteral<>( 1, basicType( Integer.class ) ), - FetchClauseType.ROWS_ONLY - ); + if ( max != null ) { + for ( int i = 0; i < subQueryColumns.size(); i++ ) { + subQuerySpec.getSelectClause().addSqlSelection( + new SqlSelectionImpl( + i + 1, + i, + subQueryColumns.get( i ) + ) + ); + resultColumnReferences.add( + new ColumnReference( + identifierVariable, + columnNames.get( i ), + false, + null, + null, + subQueryColumns.get( i ).getJdbcMapping(), + creationContext.getSessionFactory() + ) + ); + } + subQuerySpec.setFetchClauseExpression( + new QueryLiteral<>( 1, basicType( Integer.class ) ), + FetchClauseType.ROWS_ONLY + ); + } + else { + final List arguments; + if ( jdbcTypeCount == 1 ) { + arguments = subQueryColumns; + } + else { + arguments = Collections.singletonList( new SqlTuple( subQueryColumns, modelPart ) ); + } + final Expression expression = new SelfRenderingAggregateFunctionSqlAstExpression( + functionDescriptor.getName(), + functionDescriptor, + arguments, + null, + (ReturnableType) functionDescriptor.getReturnTypeResolver().resolveFunctionReturnType( + () -> null, + arguments + ).getJdbcMapping(), + modelPart + ); + + subQuerySpec.getSelectClause().addSqlSelection( + new SqlSelectionImpl( + 1, + 0, + expression + ) + ); + resultColumnReferences.add( + new ColumnReference( + identifierVariable, + columnNames.get( 0 ), + false, + null, + null, + expression.getExpressionType().getJdbcMappings().get( 0 ), + creationContext.getSessionFactory() + ) + ); + } subQuerySpec.applyPredicate( pluralAttributeMapping.getKeyDescriptor().generateJoinPredicate( parentFromClauseAccess.findTableGroup( @@ -3831,11 +3886,14 @@ public abstract class BaseSqmToSqlAstConverter extends Base } parentFromClauseAccess.registerTableGroup( lateralTableGroup.getNavigablePath(), lateralTableGroup ); if ( jdbcTypeCount == 1 ) { - return new BasicValuedPathInterpretation<>( - resultColumnReferences.get( 0 ), - queryPath, - (BasicValuedModelPart) modelPart, - lateralTableGroup + return new SelfRenderingFunctionSqlAstExpression( + pathName, + (sqlAppender, sqlAstArguments, walker) -> { + sqlAstArguments.get( 0 ).accept( walker ); + }, + resultColumnReferences, + (ReturnableType) resultColumnReferences.get( 0 ).getJdbcMapping(), + resultColumnReferences.get( 0 ).getJdbcMapping() ); } else { @@ -3848,19 +3906,28 @@ public abstract class BaseSqmToSqlAstConverter extends Base } final QueryPartTableReference tableReference = (QueryPartTableReference) lateralTableGroup.getPrimaryTableReference(); if ( jdbcTypeCount == 1 ) { - return new BasicValuedPathInterpretation<>( - new ColumnReference( - identifierVariable, - tableReference.getColumnNames().get( 0 ), - false, - null, - null, - modelPart.getJdbcMappings().get( 0 ), - creationContext.getSessionFactory() - ), - queryPath, - (BasicValuedModelPart) modelPart, - lateralTableGroup + final List sqlSelections = tableReference.getQueryPart() + .getFirstQuerySpec() + .getSelectClause() + .getSqlSelections(); + return new SelfRenderingFunctionSqlAstExpression( + pathName, + (sqlAppender, sqlAstArguments, walker) -> { + sqlAstArguments.get( 0 ).accept( walker ); + }, + Collections.singletonList( + new ColumnReference( + identifierVariable, + tableReference.getColumnNames().get( 0 ), + false, + null, + null, + sqlSelections.get( 0 ).getExpressionType().getJdbcMappings().get( 0 ), + creationContext.getSessionFactory() + ) + ), + (ReturnableType) sqlSelections.get( 0 ).getExpressionType().getJdbcMappings().get( 0 ), + sqlSelections.get( 0 ).getExpressionType() ); } else { diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/FunctionTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/FunctionTests.java index 939b2d41d2..aee305ec25 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/FunctionTests.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/FunctionTests.java @@ -100,13 +100,12 @@ public class FunctionTests { .getSingleResult(), is(2.0) ); assertThat( session.createQuery("select sum(index eol.listOfNumbers) from EntityOfLists eol") - .getSingleResult(), is(1) ); + .getSingleResult(), is(1L) ); assertThat( session.createQuery("select sum(element eol.listOfNumbers) from EntityOfLists eol") .getSingleResult(), is(3.0) ); - //TODO: why does this fail?? -// assertThat( session.createQuery("select avg(index eol.listOfNumbers) from EntityOfLists eol") -// .getSingleResult(), is(0.5) ); + assertThat( session.createQuery("select avg(index eol.listOfNumbers) from EntityOfLists eol") + .getSingleResult(), is(0.5) ); assertThat( session.createQuery("select avg(element eol.listOfNumbers) from EntityOfLists eol") .getSingleResult(), is(1.5) ); @@ -116,12 +115,12 @@ public class FunctionTests { .getSingleResult(), is(1.0) ); assertThat( session.createQuery("select sum(index eom.numberByNumber) from EntityOfMaps eom") - .getSingleResult(), is(1) ); + .getSingleResult(), is(1L) ); assertThat( session.createQuery("select sum(element eom.numberByNumber) from EntityOfMaps eom") .getSingleResult(), is(1.0) ); assertThat( session.createQuery("select avg(index eom.numberByNumber) from EntityOfMaps eom") - .getSingleResult(), is(1) ); + .getSingleResult(), is(1.0) ); assertThat( session.createQuery("select avg(element eom.numberByNumber) from EntityOfMaps eom") .getSingleResult(), is(1.0) ); }