From 36bdd9d0130ebc1e8a3a7eb5d7e5e2c4d9f1d84f Mon Sep 17 00:00:00 2001 From: Gavin Date: Sun, 28 May 2023 09:57:00 +0200 Subject: [PATCH] fix NPE when ordered set agg function missing 'within group' - also refactor a bit that code in SQB to be more typesafe - and get rid of some warnings --- .../org/hibernate/grammars/hql/HqlParser.g4 | 2 +- .../HypotheticalSetWindowEmulation.java | 1 - .../function/InverseDistributionFunction.java | 43 +- .../InverseDistributionWindowEmulation.java | 1 - .../hql/internal/SemanticQueryBuilder.java | 457 ++++++++---------- .../query/sqm/function/FunctionKind.java | 2 +- ...nderingSqmOrderedSetAggregateFunction.java | 4 +- .../test/query/hql/WindowFunctionTest.java | 1 - 8 files changed, 222 insertions(+), 289 deletions(-) diff --git a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 index c455625c50..b34cb823fa 100644 --- a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 +++ b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 @@ -1154,7 +1154,7 @@ listaggFunction * A 'on overflow' clause: what to do when the text data type used for 'listagg' overflows */ onOverflowClause - : ON OVERFLOW (ERROR | (TRUNCATE expression? (WITH|WITHOUT) COUNT)) + : ON OVERFLOW (ERROR | TRUNCATE expression? (WITH|WITHOUT) COUNT) ; /** diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/HypotheticalSetWindowEmulation.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/HypotheticalSetWindowEmulation.java index bc1d10a58a..c3cd67686d 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/HypotheticalSetWindowEmulation.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/HypotheticalSetWindowEmulation.java @@ -14,7 +14,6 @@ import org.hibernate.query.ReturnableType; import org.hibernate.query.spi.QueryEngine; import org.hibernate.query.sqm.function.SelfRenderingFunctionSqlAstExpression; import org.hibernate.query.sqm.function.SelfRenderingOrderedSetAggregateFunctionSqlAstExpression; -import org.hibernate.query.sqm.function.SelfRenderingSqmAggregateFunction; import org.hibernate.query.sqm.function.SelfRenderingSqmOrderedSetAggregateFunction; import org.hibernate.query.sqm.produce.function.ArgumentsValidator; import org.hibernate.query.sqm.sql.SqmToSqlAstConverter; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionFunction.java index 4e8e498609..32b386a9dd 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionFunction.java @@ -10,9 +10,10 @@ import java.util.Collections; import java.util.List; import org.hibernate.metamodel.mapping.BasicValuedMapping; +import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.metamodel.mapping.MappingModelExpressible; -import org.hibernate.metamodel.spi.MappingMetamodelImplementor; import org.hibernate.query.ReturnableType; +import org.hibernate.query.SemanticException; import org.hibernate.query.spi.QueryEngine; import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; import org.hibernate.query.sqm.function.FunctionKind; @@ -129,13 +130,12 @@ public class InverseDistributionFunction extends AbstractSqmSelfRenderingFunctio protected class SelfRenderingInverseDistributionFunction extends SelfRenderingSqmOrderedSetAggregateFunction { - private final SqmOrderByClause withinGroupClause; - public SelfRenderingInverseDistributionFunction( List> arguments, SqmPredicate filter, SqmOrderByClause withinGroupClause, - ReturnableType impliedResultType, QueryEngine queryEngine) { + ReturnableType impliedResultType, + QueryEngine queryEngine) { super( InverseDistributionFunction.this, InverseDistributionFunction.this, @@ -148,13 +148,18 @@ public class InverseDistributionFunction extends AbstractSqmSelfRenderingFunctio queryEngine.getCriteriaBuilder(), InverseDistributionFunction.this.getName() ); - this.withinGroupClause = withinGroupClause; + if ( withinGroupClause == null ) { + throw new SemanticException("Inverse distribution function '" + getFunctionName() + + "' must specify WITHIN GROUP"); + } } @Override protected ReturnableType resolveResultType(TypeConfiguration typeConfiguration) { - return (ReturnableType) withinGroupClause.getSortSpecifications().get( 0 ).getSortExpression() - .getExpressible(); + return (ReturnableType) + getWithinGroup().getSortSpecifications().get( 0 ) + .getSortExpression() + .getExpressible(); } @Override @@ -171,21 +176,23 @@ public class InverseDistributionFunction extends AbstractSqmSelfRenderingFunctio // here we have something that is not a BasicType, // and we have no way to get a BasicValuedMapping // from it directly - final Expression expression = (Expression) withinGroupClause.getSortSpecifications().get( 0 ) - .getSortExpression() - .accept( walker ); - if ( expression.getExpressionType() instanceof BasicValuedMapping ) { - return (BasicValuedMapping) expression.getExpressionType(); + final Expression expression = (Expression) + getWithinGroup().getSortSpecifications().get( 0 ) + .getSortExpression() + .accept( walker ); + final JdbcMappingContainer expressionType = expression.getExpressionType(); + if ( expressionType instanceof BasicValuedMapping ) { + return (BasicValuedMapping) expressionType; } try { - final MappingMetamodelImplementor domainModel = walker.getCreationContext() + return walker.getCreationContext() .getSessionFactory() .getRuntimeMetamodels() - .getMappingMetamodel(); - return domainModel.resolveMappingExpressible( - getNodeType(), - walker.getFromClauseAccess()::getTableGroup - ); + .getMappingMetamodel() + .resolveMappingExpressible( + getNodeType(), + walker.getFromClauseAccess()::getTableGroup + ); } catch (Exception e) { return null; // this works at least approximately diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionWindowEmulation.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionWindowEmulation.java index 51c15147d0..e1ad3eda13 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionWindowEmulation.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/InverseDistributionWindowEmulation.java @@ -14,7 +14,6 @@ import org.hibernate.query.ReturnableType; import org.hibernate.query.spi.QueryEngine; import org.hibernate.query.sqm.function.SelfRenderingFunctionSqlAstExpression; import org.hibernate.query.sqm.function.SelfRenderingOrderedSetAggregateFunctionSqlAstExpression; -import org.hibernate.query.sqm.function.SelfRenderingSqmAggregateFunction; import org.hibernate.query.sqm.function.SelfRenderingSqmOrderedSetAggregateFunction; import org.hibernate.query.sqm.produce.function.ArgumentsValidator; import org.hibernate.query.sqm.produce.function.FunctionParameterType; diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java index 4ae85b5568..b9b3aa0b7a 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java @@ -24,7 +24,6 @@ import java.time.ZonedDateTime; import java.time.temporal.TemporalAccessor; import java.util.ArrayList; import java.util.Calendar; -import java.util.Collections; import java.util.GregorianCalendar; import java.util.HashSet; import java.util.List; @@ -367,7 +366,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem @Override public SqmStatement visitStatement(HqlParser.StatementContext ctx) { - // parameters allow multi-valued bindings only in very limited cases, so for + // parameters allow multivalued bindings only in very limited cases, so for // the base case here we say false parameterDeclarationContextStack.push( () -> false ); @@ -1070,7 +1069,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem return; } final SqmOrderByClause orderByClause; - final HqlParser.OrderByClauseContext orderByClauseContext = (HqlParser.OrderByClauseContext) ctx.getChild( 0 ); + final HqlParser.OrderByClauseContext orderByClauseContext = ctx.orderByClause(); if ( orderByClauseContext != null ) { if ( creationOptions.useStrictJpaCompliance() && processingStateStack.depth() > 1 ) { throw new StrictJpaComplianceViolation( @@ -1085,29 +1084,10 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem orderByClause = null; } - int currentIndex = 1; - final HqlParser.LimitClauseContext limitClauseContext; - if ( currentIndex < ctx.getChildCount() && ctx.getChild( currentIndex ) instanceof HqlParser.LimitClauseContext ) { - limitClauseContext = (HqlParser.LimitClauseContext) ctx.getChild( currentIndex++ ); - } - else { - limitClauseContext = null; - } - final HqlParser.OffsetClauseContext offsetClauseContext; - if ( currentIndex < ctx.getChildCount() && ctx.getChild( currentIndex ) instanceof HqlParser.OffsetClauseContext ) { - offsetClauseContext = (HqlParser.OffsetClauseContext) ctx.getChild( currentIndex++ ); - } - else { - offsetClauseContext = null; - } - final HqlParser.FetchClauseContext fetchClauseContext; - if ( currentIndex < ctx.getChildCount() && ctx.getChild( currentIndex ) instanceof HqlParser.FetchClauseContext ) { - fetchClauseContext = (HqlParser.FetchClauseContext) ctx.getChild( currentIndex++ ); - } - else { - fetchClauseContext = null; - } - if ( currentIndex != 1 ) { + final HqlParser.LimitClauseContext limitClauseContext = ctx.limitClause(); + final HqlParser.OffsetClauseContext offsetClauseContext = ctx.offsetClause(); + final HqlParser.FetchClauseContext fetchClauseContext = ctx.fetchClause(); + if ( limitClauseContext != null || offsetClauseContext != null || fetchClauseContext != null ) { if ( getCreationOptions().useStrictJpaCompliance() ) { throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.LIMIT_OFFSET_CLAUSE @@ -1122,7 +1102,10 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem sqmQueryPart.setOffsetExpression( visitOffsetClause( offsetClauseContext ) ); if ( limitClauseContext == null ) { - sqmQueryPart.setFetchExpression( visitFetchClause( fetchClauseContext ), visitFetchClauseType( fetchClauseContext ) ); + sqmQueryPart.setFetchExpression( + visitFetchClause( fetchClauseContext ), + visitFetchClauseType( fetchClauseContext ) + ); } else if ( fetchClauseContext == null ) { sqmQueryPart.setFetchExpression( visitLimitClause( limitClauseContext ) ); @@ -1173,7 +1156,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem } protected SqmSelectClause buildInferredSelectClause(SqmFromClause fromClause) { - // for now, this is slightly different than the legacy behavior where + // for now, this is slightly different to the legacy behavior where // the root and each non-fetched-join was selected. For now, here, we simply // select the root final SqmSelectClause selectClause; @@ -1261,11 +1244,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem // if the node is not a dynamic-instantiation, register it with // the path-registry - //noinspection StatementWithEmptyBody - if ( selectableNode instanceof SqmDynamicInstantiation ) { - // nothing else to do (avoid kludgy `! ( instanceof )` syntax - } - else { + if ( !(selectableNode instanceof SqmDynamicInstantiation) ) { getCurrentProcessingState().getPathRegistry().register( selection ); } @@ -1377,11 +1356,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem creationContext.getNodeBuilder() ); - //noinspection StatementWithEmptyBody - if ( argExpression instanceof SqmDynamicInstantiation ) { - // nothing else to do (avoid kludgy `! ( instanceof )` syntax - } - else { + if ( !(argExpression instanceof SqmDynamicInstantiation) ) { getCurrentProcessingState().getPathRegistry().register( argument ); } @@ -1541,7 +1516,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem // This is syntactically disallowed throw new ParsingException( "COLLATE is not allowed for alias based order-by or group-by items" ); } - // this will group-by all of the sub-parts in the from-element's model part + // this will group-by all the sub-parts in the from-element's model part return sqmFrom; } @@ -2638,7 +2613,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem final int estimatedSize = size >> 1; final Class testExpressionJavaType = testExpression.getJavaType(); final boolean isEnum = testExpressionJavaType != null && testExpressionJavaType.isEnum(); - // Multi-valued bindings are only allowed if there is a single list item, hence size 3 (LP, RP and param) + // Multivalued bindings are only allowed if there is a single list item, hence size 3 (LP, RP and param) parameterDeclarationContextStack.push( () -> size == 3 ); try { final List> listExpressions = new ArrayList<>( estimatedSize ); @@ -2681,7 +2656,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem try { return new SqmInListPredicate( testExpression, - Collections.singletonList( tupleExpressionListContext.getChild( 0 ).accept( this ) ), + singletonList( tupleExpressionListContext.getChild( 0 ).accept( this ) ), negated, creationContext.getNodeBuilder() ); @@ -3912,97 +3887,36 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem return functionName.toString(); } - @Override - public Object visitGenericFunction(HqlParser.GenericFunctionContext ctx) { + private String getFunctionName(HqlParser.GenericFunctionContext ctx) { final String originalFunctionName = visitGenericFunctionName( ctx.genericFunctionName() ); final String functionName = originalFunctionName.toLowerCase(); if ( creationOptions.useStrictJpaCompliance() && !JPA_STANDARD_FUNCTIONS.contains( functionName ) ) { throw new StrictJpaComplianceViolation( "Encountered non-compliant non-standard function call [" + originalFunctionName + "], but strict JPA " + - "compliance was requested; use JPA's FUNCTION(functionName[,...]) " + + "compliance was requested; use FUNCTION(functionName[,...]) " + "syntax name instead", StrictJpaComplianceViolation.Type.FUNCTION_CALL ); } + return functionName; + } - //TODO: this fragment of code is extremely fragile and lacking in typesafety! - final ParseTree argumentChild = ctx.getChild( 2 ); - final List> functionArguments; - if ( argumentChild instanceof HqlParser.GenericFunctionArgumentsContext ) { - @SuppressWarnings("unchecked") - List> node = (List>) argumentChild.accept(this); - functionArguments = node; - } - else if ( "*".equals( argumentChild.getText() ) ) { - functionArguments = Collections.singletonList( new SqmStar( getCreationContext().getNodeBuilder() ) ); - } - else { - functionArguments = emptyList(); - } + @Override + public Object visitGenericFunction(HqlParser.GenericFunctionContext ctx) { + final SqmFunctionDescriptor functionTemplate = getFunctionTemplate( ctx ); - final Boolean fromFirst = getFromFirst( ctx ); - final Boolean respectNulls = getRespectNullsClause( ctx ); - final SqmOrderByClause withinGroup = getWithinGroup( ctx ); + final List> functionArguments = getFunctionArguments( ctx ); final SqmPredicate filterExpression = getFilterExpression( ctx ); - final boolean hasOverClause = ctx.getChild( ctx.getChildCount() - 1 ) instanceof HqlParser.OverClauseContext; - SqmFunctionDescriptor functionTemplate = getFunctionDescriptor( functionName ); - if ( functionTemplate == null ) { - FunctionKind functionKind = FunctionKind.NORMAL; - if ( withinGroup != null ) { - functionKind = FunctionKind.ORDERED_SET_AGGREGATE; - } - else if ( hasOverClause ) { - functionKind = FunctionKind.WINDOW; - } - else if ( filterExpression != null ) { - functionKind = FunctionKind.AGGREGATE; - } - functionTemplate = new NamedSqmFunctionDescriptor( - functionName, - true, - null, - StandardFunctionReturnTypeResolvers.invariant( - resolveExpressibleTypeBasic( Object.class ) - ), - null, - functionName, - functionKind, - null, - SqlAstNodeRenderingMode.DEFAULT - ); - } - else { - if ( hasOverClause && functionTemplate.getFunctionKind() == FunctionKind.NORMAL ) { - throw new SemanticException( "OVER clause is illegal for normal function: " + functionName ); - } - else if ( !hasOverClause && functionTemplate.getFunctionKind() == FunctionKind.WINDOW ) { - throw new SemanticException( "OVER clause is mandatory for window-only function: " + functionName ); - } - if ( respectNulls != null ) { - switch ( functionName ) { - case "lag": - case "lead": - case "first_value": - case "last_value": - case "nth_value": - break; - default: - throw new SemanticException( "RESPECT/IGNORE NULLS is illegal for function: " + functionName ); - } - } - if ( fromFirst != null && !"nth_value".equals( functionName ) ) { - throw new SemanticException( "FROM FIRST/LAST is illegal for function: " + functionName ); - } - } - final SqmFunction function; switch ( functionTemplate.getFunctionKind() ) { case ORDERED_SET_AGGREGATE: function = functionTemplate.generateOrderedSetAggregateSqmExpression( functionArguments, filterExpression, - withinGroup, + ctx.withinGroupClause() == null + ? null // this is allowed for e.g. rank(), but not for all + : visitOrderByClause( ctx.withinGroupClause().orderByClause(), false ), null, creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -4029,9 +3943,6 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem ); break; default: - if ( filterExpression != null ) { - throw new ParsingException( "Illegal use of a FILTER clause for non-aggregate function: " + originalFunctionName ); - } function = functionTemplate.generateSqmExpression( functionArguments, null, @@ -4040,7 +3951,95 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem ); break; } - return applyOverClause( ctx, function ); + return applyOverClause( ctx.overClause(), function ); + } + + private SqmFunctionDescriptor getFunctionTemplate(HqlParser.GenericFunctionContext ctx) { + + final String functionName = getFunctionName( ctx ); + + final SqmFunctionDescriptor functionTemplate = getFunctionDescriptor( functionName ); + if ( functionTemplate == null ) { + return new NamedSqmFunctionDescriptor( + functionName, + true, + null, + StandardFunctionReturnTypeResolvers.invariant( + resolveExpressibleTypeBasic( Object.class ) + ), + null, + functionName, + inferFunctionKind( ctx ), + null, + SqlAstNodeRenderingMode.DEFAULT + ); + } + else { + final FunctionKind functionKind = functionTemplate.getFunctionKind(); + if ( ctx.filterClause() != null && functionKind == FunctionKind.NORMAL ) { + throw new ParsingException( "FILTER clause is illegal for non-aggregate function: " + functionName ); + } + if ( ctx.overClause() != null && functionKind == FunctionKind.NORMAL ) { + throw new SemanticException( "OVER clause is illegal for non-aggregate function: " + functionName); + } + if ( ctx.withinGroupClause() != null && functionKind == FunctionKind.NORMAL ) { + throw new SemanticException( "WITHIN GROUP clause is illegal for non-aggregate function: " + functionName); + } + if ( ctx.overClause() == null && functionKind == FunctionKind.WINDOW ) { + throw new SemanticException( "OVER clause is mandatory for window-only function: " + functionName ); + } + if ( ctx.withinGroupClause() == null && ctx.overClause() == null + && functionKind == FunctionKind.ORDERED_SET_AGGREGATE ) { + throw new SemanticException( "WITHIN GROUP or OVER clause is mandatory for ordered set aggregate function: " + functionName ); + } + + if ( ctx.nullsClause() != null ) { + switch ( functionName ) { + case "lag": + case "lead": + case "first_value": + case "last_value": + case "nth_value": + break; + default: + throw new SemanticException( "RESPECT/IGNORE NULLS is illegal for function: " + functionName ); + } + } + if ( ctx.nthSideClause() != null && !"nth_value".equals( functionName ) ) { + throw new SemanticException( "FROM FIRST/LAST is illegal for function: " + functionName ); + } + return functionTemplate; + } + } + + private static FunctionKind inferFunctionKind(HqlParser.GenericFunctionContext ctx) { + if ( ctx.withinGroupClause() != null ) { + return FunctionKind.ORDERED_SET_AGGREGATE; + } + else if ( ctx.overClause() != null ) { + return FunctionKind.WINDOW; + } + else if ( ctx.filterClause() != null ) { + return FunctionKind.AGGREGATE; + } + else { + return FunctionKind.NORMAL; + } + } + + private List> getFunctionArguments(HqlParser.GenericFunctionContext ctx) { + if ( ctx.genericFunctionArguments() != null ) { + @SuppressWarnings("unchecked") + final List> node = (List>) + ctx.genericFunctionArguments().accept(this); + return node; + } + else if ( ctx.ASTERISK() != null ) { + return singletonList( new SqmStar( getCreationContext().getNodeBuilder() ) ); + } + else { + return emptyList(); + } } @Override @@ -4048,46 +4047,53 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem if ( creationOptions.useStrictJpaCompliance() ) { throw new StrictJpaComplianceViolation( "Encountered non-compliant non-standard function call [listagg], but strict JPA " + - "compliance was requested; use JPA's FUNCTION(functionName[,...]) " + + "compliance was requested; use FUNCTION(functionName[,...]) " + "syntax name instead", StrictJpaComplianceViolation.Type.FUNCTION_CALL ); } + final SqmFunctionDescriptor functionTemplate = getFunctionDescriptor( "listagg" ); if ( functionTemplate == null ) { - throw new SemanticException( - "The listagg function was not registered for the dialect" - ); + throw new SemanticException( "The listagg function was not registered for the dialect" ); } - final int argumentStartIndex; - final ParseTree thirdChild = ctx.getChild( 2 ); - final boolean distinct; - if ( thirdChild instanceof TerminalNode ) { - distinct = true; - argumentStartIndex = 3; - } - else { - distinct = false; - argumentStartIndex = 2; - } - final SqmExpression firstArgument = (SqmExpression) ctx.getChild( argumentStartIndex ).accept( this ); - final SqmExpression secondArgument = (SqmExpression) ctx.getChild( argumentStartIndex + 2 ).accept( this ); - final ParseTree overflowCtx = ctx.getChild( argumentStartIndex + 3 ); + + return applyOverClause( + ctx.overClause(), + functionTemplate.generateOrderedSetAggregateSqmExpression( + getListaggArguments( ctx ), + getFilterExpression( ctx ), + ctx.withinGroupClause() == null + ? null // this is allowed + : visitOrderByClause( ctx.withinGroupClause().orderByClause(), false ), + null, + creationContext.getQueryEngine(), + creationContext.getJpaMetamodel().getTypeConfiguration() + ) + ); + } + + private List> getListaggArguments(ListaggFunctionContext ctx) { + final SqmExpression firstArgument = (SqmExpression) ctx.expressionOrPredicate(0).accept( this ); + final SqmExpression secondArgument = (SqmExpression) ctx.expressionOrPredicate(1).accept( this ); + final OnOverflowClauseContext overflowCtx = ctx.onOverflowClause(); final List> functionArguments = new ArrayList<>( 3 ); - if ( distinct ) { + if ( ctx.DISTINCT() != null ) { functionArguments.add( new SqmDistinct<>( firstArgument, creationContext.getNodeBuilder() ) ); } else { functionArguments.add( firstArgument ); } - if ( overflowCtx instanceof OnOverflowClauseContext ) { - if ( overflowCtx.getChildCount() > 3 ) { + if ( overflowCtx != null ) { + if ( overflowCtx.ERROR() != null ) { + // ON OVERFLOW ERROR + functionArguments.add( new SqmOverflow<>( secondArgument, null, false ) ); + } + else { // ON OVERFLOW TRUNCATE - final TerminalNode countNode = (TerminalNode) overflowCtx.getChild( overflowCtx.getChildCount() - 2 ); - final boolean withCount = countNode.getSymbol().getType() == HqlParser.WITH; final SqmExpression fillerExpression; - if ( overflowCtx.getChildCount() == 6 ) { - fillerExpression = (SqmExpression) overflowCtx.getChild( 3 ).accept( this ); + if ( overflowCtx.expression() != null ) { + fillerExpression = (SqmExpression) overflowCtx.expression().accept( this ); } else { // The SQL standard says the default is three periods `...` @@ -4097,30 +4103,15 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem secondArgument.nodeBuilder() ); } + final boolean withCount = overflowCtx.WITH() != null; //noinspection unchecked,rawtypes functionArguments.add( new SqmOverflow( secondArgument, fillerExpression, withCount ) ); } - else { - // ON OVERFLOW ERROR - functionArguments.add( new SqmOverflow<>( secondArgument, null, false ) ); - } } else { functionArguments.add( secondArgument ); } - final SqmOrderByClause withinGroup = getWithinGroup( ctx ); - final SqmPredicate filterExpression = getFilterExpression( ctx ); - return applyOverClause( - ctx, - functionTemplate.generateOrderedSetAggregateSqmExpression( - functionArguments, - filterExpression, - withinGroup, - null, - creationContext.getQueryEngine(), - creationContext.getJpaMetamodel().getTypeConfiguration() - ) - ); + return functionArguments; } @Override @@ -4481,23 +4472,20 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem @Override public SqmExpression visitEveryFunction(HqlParser.EveryFunctionContext ctx) { - final SqmPredicate filterExpression = getFilterExpression( ctx ); - final ParseTree argumentChild = ctx.getChild( 2 ); - if ( argumentChild instanceof HqlParser.SubqueryContext ) { - final SqmSubQuery subquery = (SqmSubQuery) argumentChild.accept( this ); + if ( ctx.subquery() != null ) { + final SqmSubQuery subquery = (SqmSubQuery) ctx.subquery().accept( this ); return new SqmEvery<>( subquery, creationContext.getNodeBuilder() ); } - else if ( argumentChild instanceof HqlParser.PredicateContext ) { + else if ( ctx.predicate() != null ) { if ( getCreationOptions().useStrictJpaCompliance() ) { throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.FUNCTION_CALL ); } - final SqmExpression argument = (SqmExpression) argumentChild.accept( this ); - + final SqmExpression argument = (SqmExpression) ctx.predicate().accept( this ); return applyOverClause( - ctx, + ctx.overClause(), getFunctionDescriptor( "every" ).generateAggregateSqmExpression( singletonList( argument ), - filterExpression, + getFilterExpression( ctx ), resolveExpressibleTypeBasic( Boolean.class ), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -4509,10 +4497,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION ); } return new SqmEvery<>( - createCollectionReferenceSubQuery( - (HqlParser.SimplePathContext) ctx.getChild( 3 ), - (TerminalNode) ctx.getChild( 1 ) - ), + createCollectionReferenceSubQuery( ctx.simplePath(), (TerminalNode) ctx.getChild( 1 ) ), creationContext.getNodeBuilder() ); } @@ -4520,23 +4505,20 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem @Override public SqmExpression visitAnyFunction(HqlParser.AnyFunctionContext ctx) { - final SqmPredicate filterExpression = getFilterExpression( ctx ); - final ParseTree argumentChild = ctx.getChild( 2 ); - if ( argumentChild instanceof HqlParser.SubqueryContext ) { - final SqmSubQuery subquery = (SqmSubQuery) argumentChild.accept( this ); + if ( ctx.subquery() != null ) { + final SqmSubQuery subquery = (SqmSubQuery) ctx.subquery().accept( this ); return new SqmAny<>( subquery, creationContext.getNodeBuilder() ); } - else if ( argumentChild instanceof HqlParser.PredicateContext ) { + else if ( ctx.predicate() != null ) { if ( getCreationOptions().useStrictJpaCompliance() ) { throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.FUNCTION_CALL ); } - final SqmExpression argument = (SqmExpression) argumentChild.accept( this ); - + final SqmExpression argument = (SqmExpression) ctx.predicate().accept( this ); return applyOverClause( - ctx, + ctx.overClause(), getFunctionDescriptor( "any" ).generateAggregateSqmExpression( singletonList( argument ), - filterExpression, + getFilterExpression( ctx ), resolveExpressibleTypeBasic( Boolean.class ), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -4548,10 +4530,7 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION ); } return new SqmAny<>( - createCollectionReferenceSubQuery( - (HqlParser.SimplePathContext) ctx.getChild( 3 ), - (TerminalNode) ctx.getChild( 1 ) - ), + createCollectionReferenceSubQuery( ctx.simplePath(), (TerminalNode) ctx.getChild( 1 ) ), creationContext.getNodeBuilder() ); } @@ -4620,44 +4599,6 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem return (SqmSubQuery) subQuery; } - private SqmOrderByClause getWithinGroup(ParseTree functionCtx) { - HqlParser.WithinGroupClauseContext ctx = null; - for ( int i = functionCtx.getChildCount() - 3; i < functionCtx.getChildCount(); i++ ) { - final ParseTree child = functionCtx.getChild( i ); - if ( child instanceof HqlParser.WithinGroupClauseContext ) { - ctx = (HqlParser.WithinGroupClauseContext) child; - break; - } - } - if ( ctx != null ) { - return visitOrderByClause( (HqlParser.OrderByClauseContext) ctx.getChild( 3 ), false ); - } - return null; - } - - private Boolean getFromFirst(ParseTree functionCtx) { - // The clause is either on index 3 or 4 is where the - final int end = Math.min( functionCtx.getChildCount(), 5 ); - for ( int i = 3; i < end; i++ ) { - final ParseTree child = functionCtx.getChild( i ); - if ( child instanceof HqlParser.NthSideClauseContext ) { - final HqlParser.NthSideClauseContext subCtx = (HqlParser.NthSideClauseContext) child.getChild( 6 ); - return ( (TerminalNode) subCtx.getChild( 1 ) ).getSymbol().getType() == HqlParser.FIRST; - } - } - return null; - } - - private Boolean getRespectNullsClause(ParseTree functionCtx) { - for ( int i = functionCtx.getChildCount() - 3; i < functionCtx.getChildCount(); i++ ) { - final ParseTree child = functionCtx.getChild( i ); - if ( child instanceof HqlParser.NullsClauseContext ) { - return ( (TerminalNode) child.getChild( 0 ) ).getSymbol().getType() == HqlParser.RESPECT; - } - } - return null; - } - private SqmPredicate getFilterExpression(ParseTree functionCtx) { for ( int i = functionCtx.getChildCount() - 2; i < functionCtx.getChildCount(); i++ ) { final ParseTree child = functionCtx.getChild( i ); @@ -4668,48 +4609,36 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem return null; } - private SqmExpression applyOverClause(ParseTree functionCtx, SqmFunction function) { - final ParseTree lastChild = functionCtx.getChild( functionCtx.getChildCount() - 1 ); - if ( lastChild instanceof HqlParser.OverClauseContext ) { - return applyOverClause( (HqlParser.OverClauseContext) lastChild, function ); - } - return function; - } - private SqmExpression applyOverClause(HqlParser.OverClauseContext ctx, SqmFunction function) { + if ( ctx == null) { + return function; + } + final List> partitions; - final List orderList; + if ( ctx.partitionClause() != null ) { + final HqlParser.PartitionClauseContext partitionClause = ctx.partitionClause(); + partitions = new ArrayList<>( ( partitionClause.getChildCount() >> 1 ) - 1 ); + for ( int i = 2; i < partitionClause.getChildCount(); i += 2 ) { + partitions.add( (SqmExpression) partitionClause.getChild( i ).accept( this ) ); + } + } + else { + partitions = emptyList(); + } + + final List orderList = ctx.orderByClause() != null + ? visitOrderByClause( ctx.orderByClause(), false ).getSortSpecifications() + : emptyList(); + final FrameMode mode; final FrameKind startKind; final SqmExpression startExpression; final FrameKind endKind; final SqmExpression endExpression; final FrameExclusion exclusion; - int index = 2; - if ( ctx.getChild( index ) instanceof HqlParser.PartitionClauseContext ) { - final ParseTree subCtx = ctx.getChild( index ); - partitions = new ArrayList<>( ( subCtx.getChildCount() >> 1 ) - 1 ); - for ( int i = 2; i < subCtx.getChildCount(); i += 2 ) { - partitions.add( (SqmExpression) subCtx.getChild( i ).accept( this ) ); - } - index++; - } - else { - partitions = Collections.emptyList(); - } - if ( index < ctx.getChildCount() && ctx.getChild( index ) instanceof HqlParser.OrderByClauseContext ) { - orderList = visitOrderByClause( - (HqlParser.OrderByClauseContext) ctx.getChild( index ), - false - ).getSortSpecifications(); - index++; - } - else { - orderList = Collections.emptyList(); - } - if ( index < ctx.getChildCount() && ctx.getChild( index ) instanceof HqlParser.FrameClauseContext ) { - final ParseTree frameCtx = ctx.getChild( index ); - switch ( ( (TerminalNode) frameCtx.getChild( 0 ) ).getSymbol().getType() ) { + final HqlParser.FrameClauseContext frameClause = ctx.frameClause(); + if ( frameClause != null ) { + switch ( ( (TerminalNode) frameClause.getChild( 0 ) ).getSymbol().getType() ) { case HqlParser.RANGE: mode = FrameMode.RANGE; break; @@ -4720,14 +4649,14 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem mode = FrameMode.GROUPS; break; default: - throw new IllegalArgumentException( "Unexpected frame mode: " + frameCtx.getChild( 0 ) ); + throw new IllegalArgumentException( "Unexpected frame mode: " + frameClause.getChild( 0 ) ); } final int frameStartIndex; - if ( frameCtx.getChild( 1 ) instanceof TerminalNode ) { + if ( frameClause.getChild( 1 ) instanceof TerminalNode ) { frameStartIndex = 2; - endKind = getFrameKind( frameCtx.getChild( 4 ) ); + endKind = getFrameKind( frameClause.getChild( 4 ) ); endExpression = endKind == FrameKind.OFFSET_FOLLOWING || endKind == FrameKind.OFFSET_PRECEDING - ? (SqmExpression) frameCtx.getChild( 4 ).getChild( 0 ).accept( this ) + ? (SqmExpression) frameClause.getChild( 4 ).getChild( 0 ).accept( this ) : null; } else { @@ -4735,11 +4664,11 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem endKind = FrameKind.CURRENT_ROW; endExpression = null; } - startKind = getFrameKind( frameCtx.getChild( frameStartIndex ) ); + startKind = getFrameKind( frameClause.getChild( frameStartIndex ) ); startExpression = startKind == FrameKind.OFFSET_FOLLOWING || startKind == FrameKind.OFFSET_PRECEDING - ? (SqmExpression) frameCtx.getChild( frameStartIndex ).getChild( 0 ).accept( this ) + ? (SqmExpression) frameClause.getChild( frameStartIndex ).getChild( 0 ).accept( this ) : null; - final ParseTree lastChild = frameCtx.getChild( frameCtx.getChildCount() - 1 ); + final ParseTree lastChild = frameClause.getChild( frameClause.getChildCount() - 1 ); if ( lastChild instanceof HqlParser.FrameExclusionContext ) { switch ( ( (TerminalNode) lastChild.getChild( 1 ) ).getSymbol().getType() ) { case HqlParser.CURRENT: diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/FunctionKind.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/FunctionKind.java index a16cf0f385..13a72c2824 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/FunctionKind.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/FunctionKind.java @@ -15,5 +15,5 @@ public enum FunctionKind { NORMAL, AGGREGATE, ORDERED_SET_AGGREGATE, - WINDOW; + WINDOW } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/SelfRenderingSqmOrderedSetAggregateFunction.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/SelfRenderingSqmOrderedSetAggregateFunction.java index fc9f798412..1c956cc040 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/SelfRenderingSqmOrderedSetAggregateFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/SelfRenderingSqmOrderedSetAggregateFunction.java @@ -40,7 +40,7 @@ public class SelfRenderingSqmOrderedSetAggregateFunction extends SelfRenderin FunctionRenderingSupport renderingSupport, List> arguments, SqmPredicate filter, - SqmOrderByClause withinGroup, + SqmOrderByClause withinGroupClause, ReturnableType impliedResultType, ArgumentsValidator argumentsValidator, FunctionReturnTypeResolver returnTypeResolver, @@ -57,7 +57,7 @@ public class SelfRenderingSqmOrderedSetAggregateFunction extends SelfRenderin nodeBuilder, name ); - this.withinGroup = withinGroup; + this.withinGroup = withinGroupClause; } @Override diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/WindowFunctionTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/WindowFunctionTest.java index 795ba78d27..4a854910f5 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/WindowFunctionTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/WindowFunctionTest.java @@ -30,7 +30,6 @@ import jakarta.persistence.TypedQuery; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; /**