Make use of function return type resolver for elements/indices functions

This commit is contained in:
Christian Beikov 2022-02-08 15:26:57 +01:00
parent 9b53ca8559
commit 842ebd0e7e
3 changed files with 128 additions and 62 deletions

View File

@ -33,10 +33,6 @@ public abstract class AbstractSqmFunctionDescriptor implements SqmFunctionDescri
private final FunctionReturnTypeResolver returnTypeResolver;
private final String name;
protected FunctionReturnTypeResolver getReturnTypeResolver() {
return returnTypeResolver;
}
public AbstractSqmFunctionDescriptor(String name) {
this( name, null, null );
}
@ -71,6 +67,10 @@ public abstract class AbstractSqmFunctionDescriptor implements SqmFunctionDescri
return argumentsValidator;
}
public FunctionReturnTypeResolver getReturnTypeResolver() {
return returnTypeResolver;
}
public String getReturnSignature() {
String result = returnTypeResolver.getReturnType();
return result.isEmpty() ? "" : result + " ";

View File

@ -259,6 +259,7 @@ import org.hibernate.sql.ast.spi.SqlAstQueryPartProcessingState;
import org.hibernate.sql.ast.spi.SqlAstTreeHelper;
import org.hibernate.sql.ast.spi.SqlExpressionResolver;
import org.hibernate.sql.ast.spi.SqlSelection;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.Statement;
import org.hibernate.sql.ast.tree.cte.CteColumn;
import org.hibernate.sql.ast.tree.cte.CteStatement;
@ -3341,10 +3342,8 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
AbstractSqmSpecificPluralPartPath<?> pluralPartPath,
boolean index,
String functionName) {
boolean isMinOrMax = functionName.equalsIgnoreCase("min")
|| functionName.equalsIgnoreCase("max");
// Try to create a lateral sub-query join if possible which allows the re-use of the expression
if ( isMinOrMax && creationContext.getSessionFactory().getJdbcServices().getDialect().supportsLateral() ) {
if ( creationContext.getSessionFactory().getJdbcServices().getDialect().supportsLateral() ) {
return createLateralJoinExpression( pluralPartPath, index, functionName );
}
else {
@ -3627,7 +3626,10 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
functionDescriptor,
arguments,
null,
(ReturnableType<?>) modelPart.getJdbcMappings().get( 0 ),
(ReturnableType<?>) functionDescriptor.getReturnTypeResolver().resolveFunctionReturnType(
() -> null,
arguments
).getJdbcMapping(),
modelPart
);
subQuerySpec.getSelectClause().addSqlSelection( new SqlSelectionImpl( 1, 0, expression ) );
@ -3713,6 +3715,15 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
final List<String> columnNames = new ArrayList<>( jdbcTypeCount );
final List<ColumnReference> resultColumnReferences = new ArrayList<>( jdbcTypeCount );
final NavigablePath navigablePath = pluralPartPath.getNavigablePath();
final Boolean max = functionName.equalsIgnoreCase( "max" ) ? Boolean.TRUE
: ( functionName.equalsIgnoreCase( "min" ) ? Boolean.FALSE : null );
final AbstractSqmSelfRenderingFunctionDescriptor functionDescriptor =
(AbstractSqmSelfRenderingFunctionDescriptor) creationContext
.getSessionFactory()
.getQueryEngine()
.getSqmFunctionRegistry()
.findFunctionDescriptor( functionName );
final List<ColumnReference> subQueryColumns = new ArrayList<>( jdbcTypeCount );
modelPart.forEachSelectable(
(selectionIndex, selectionMapping) -> {
final ColumnReference columnReference = new ColumnReference(
@ -3731,39 +3742,83 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
columnName = selectionMapping.getSelectionExpression();
}
columnNames.add( columnName );
subQuerySpec.getSelectClause().addSqlSelection(
new SqlSelectionImpl(
selectionIndex - 1,
selectionIndex,
columnReference
)
);
subQuerySpec.addSortSpecification(
new SortSpecification(
columnReference,
functionName.equalsIgnoreCase("max")
? SortOrder.DESCENDING
: SortOrder.ASCENDING
)
);
resultColumnReferences.add(
new ColumnReference(
identifierVariable,
columnName,
false,
null,
null,
selectionMapping.getJdbcMapping(),
creationContext.getSessionFactory()
)
);
subQueryColumns.add( columnReference );
if ( max != null ) {
subQuerySpec.addSortSpecification(
new SortSpecification(
columnReference,
max ? SortOrder.DESCENDING : SortOrder.ASCENDING
)
);
}
}
);
subQuerySpec.setFetchClauseExpression(
new QueryLiteral<>( 1, basicType( Integer.class ) ),
FetchClauseType.ROWS_ONLY
);
if ( max != null ) {
for ( int i = 0; i < subQueryColumns.size(); i++ ) {
subQuerySpec.getSelectClause().addSqlSelection(
new SqlSelectionImpl(
i + 1,
i,
subQueryColumns.get( i )
)
);
resultColumnReferences.add(
new ColumnReference(
identifierVariable,
columnNames.get( i ),
false,
null,
null,
subQueryColumns.get( i ).getJdbcMapping(),
creationContext.getSessionFactory()
)
);
}
subQuerySpec.setFetchClauseExpression(
new QueryLiteral<>( 1, basicType( Integer.class ) ),
FetchClauseType.ROWS_ONLY
);
}
else {
final List<? extends SqlAstNode> arguments;
if ( jdbcTypeCount == 1 ) {
arguments = subQueryColumns;
}
else {
arguments = Collections.singletonList( new SqlTuple( subQueryColumns, modelPart ) );
}
final Expression expression = new SelfRenderingAggregateFunctionSqlAstExpression(
functionDescriptor.getName(),
functionDescriptor,
arguments,
null,
(ReturnableType<?>) functionDescriptor.getReturnTypeResolver().resolveFunctionReturnType(
() -> null,
arguments
).getJdbcMapping(),
modelPart
);
subQuerySpec.getSelectClause().addSqlSelection(
new SqlSelectionImpl(
1,
0,
expression
)
);
resultColumnReferences.add(
new ColumnReference(
identifierVariable,
columnNames.get( 0 ),
false,
null,
null,
expression.getExpressionType().getJdbcMappings().get( 0 ),
creationContext.getSessionFactory()
)
);
}
subQuerySpec.applyPredicate(
pluralAttributeMapping.getKeyDescriptor().generateJoinPredicate(
parentFromClauseAccess.findTableGroup(
@ -3831,11 +3886,14 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
}
parentFromClauseAccess.registerTableGroup( lateralTableGroup.getNavigablePath(), lateralTableGroup );
if ( jdbcTypeCount == 1 ) {
return new BasicValuedPathInterpretation<>(
resultColumnReferences.get( 0 ),
queryPath,
(BasicValuedModelPart) modelPart,
lateralTableGroup
return new SelfRenderingFunctionSqlAstExpression(
pathName,
(sqlAppender, sqlAstArguments, walker) -> {
sqlAstArguments.get( 0 ).accept( walker );
},
resultColumnReferences,
(ReturnableType<?>) resultColumnReferences.get( 0 ).getJdbcMapping(),
resultColumnReferences.get( 0 ).getJdbcMapping()
);
}
else {
@ -3848,19 +3906,28 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
}
final QueryPartTableReference tableReference = (QueryPartTableReference) lateralTableGroup.getPrimaryTableReference();
if ( jdbcTypeCount == 1 ) {
return new BasicValuedPathInterpretation<>(
new ColumnReference(
identifierVariable,
tableReference.getColumnNames().get( 0 ),
false,
null,
null,
modelPart.getJdbcMappings().get( 0 ),
creationContext.getSessionFactory()
),
queryPath,
(BasicValuedModelPart) modelPart,
lateralTableGroup
final List<SqlSelection> sqlSelections = tableReference.getQueryPart()
.getFirstQuerySpec()
.getSelectClause()
.getSqlSelections();
return new SelfRenderingFunctionSqlAstExpression(
pathName,
(sqlAppender, sqlAstArguments, walker) -> {
sqlAstArguments.get( 0 ).accept( walker );
},
Collections.singletonList(
new ColumnReference(
identifierVariable,
tableReference.getColumnNames().get( 0 ),
false,
null,
null,
sqlSelections.get( 0 ).getExpressionType().getJdbcMappings().get( 0 ),
creationContext.getSessionFactory()
)
),
(ReturnableType<?>) sqlSelections.get( 0 ).getExpressionType().getJdbcMappings().get( 0 ),
sqlSelections.get( 0 ).getExpressionType()
);
}
else {

View File

@ -100,13 +100,12 @@ public class FunctionTests {
.getSingleResult(), is(2.0) );
assertThat( session.createQuery("select sum(index eol.listOfNumbers) from EntityOfLists eol")
.getSingleResult(), is(1) );
.getSingleResult(), is(1L) );
assertThat( session.createQuery("select sum(element eol.listOfNumbers) from EntityOfLists eol")
.getSingleResult(), is(3.0) );
//TODO: why does this fail??
// assertThat( session.createQuery("select avg(index eol.listOfNumbers) from EntityOfLists eol")
// .getSingleResult(), is(0.5) );
assertThat( session.createQuery("select avg(index eol.listOfNumbers) from EntityOfLists eol")
.getSingleResult(), is(0.5) );
assertThat( session.createQuery("select avg(element eol.listOfNumbers) from EntityOfLists eol")
.getSingleResult(), is(1.5) );
@ -116,12 +115,12 @@ public class FunctionTests {
.getSingleResult(), is(1.0) );
assertThat( session.createQuery("select sum(index eom.numberByNumber) from EntityOfMaps eom")
.getSingleResult(), is(1) );
.getSingleResult(), is(1L) );
assertThat( session.createQuery("select sum(element eom.numberByNumber) from EntityOfMaps eom")
.getSingleResult(), is(1.0) );
assertThat( session.createQuery("select avg(index eom.numberByNumber) from EntityOfMaps eom")
.getSingleResult(), is(1) );
.getSingleResult(), is(1.0) );
assertThat( session.createQuery("select avg(element eom.numberByNumber) from EntityOfMaps eom")
.getSingleResult(), is(1.0) );
}