fix NPE when ordered set agg function missing 'within group'

- also refactor a bit that code in SQB to be more typesafe
- and get rid of some warnings
This commit is contained in:
Gavin 2023-05-28 09:57:00 +02:00 committed by Gavin King
parent a6f036d320
commit 36bdd9d013
8 changed files with 222 additions and 289 deletions

View File

@ -1154,7 +1154,7 @@ listaggFunction
* A 'on overflow' clause: what to do when the text data type used for 'listagg' overflows
*/
onOverflowClause
: ON OVERFLOW (ERROR | (TRUNCATE expression? (WITH|WITHOUT) COUNT))
: ON OVERFLOW (ERROR | TRUNCATE expression? (WITH|WITHOUT) COUNT)
;
/**

View File

@ -14,7 +14,6 @@ import org.hibernate.query.ReturnableType;
import org.hibernate.query.spi.QueryEngine;
import org.hibernate.query.sqm.function.SelfRenderingFunctionSqlAstExpression;
import org.hibernate.query.sqm.function.SelfRenderingOrderedSetAggregateFunctionSqlAstExpression;
import org.hibernate.query.sqm.function.SelfRenderingSqmAggregateFunction;
import org.hibernate.query.sqm.function.SelfRenderingSqmOrderedSetAggregateFunction;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.sql.SqmToSqlAstConverter;

View File

@ -10,9 +10,10 @@ import java.util.Collections;
import java.util.List;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.metamodel.spi.MappingMetamodelImplementor;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.SemanticException;
import org.hibernate.query.spi.QueryEngine;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.function.FunctionKind;
@ -129,13 +130,12 @@ public class InverseDistributionFunction extends AbstractSqmSelfRenderingFunctio
protected class SelfRenderingInverseDistributionFunction<T> extends SelfRenderingSqmOrderedSetAggregateFunction<T> {
private final SqmOrderByClause withinGroupClause;
public SelfRenderingInverseDistributionFunction(
List<? extends SqmTypedNode<?>> arguments,
SqmPredicate filter,
SqmOrderByClause withinGroupClause,
ReturnableType<T> impliedResultType, QueryEngine queryEngine) {
ReturnableType<T> impliedResultType,
QueryEngine queryEngine) {
super(
InverseDistributionFunction.this,
InverseDistributionFunction.this,
@ -148,12 +148,17 @@ public class InverseDistributionFunction extends AbstractSqmSelfRenderingFunctio
queryEngine.getCriteriaBuilder(),
InverseDistributionFunction.this.getName()
);
this.withinGroupClause = withinGroupClause;
if ( withinGroupClause == null ) {
throw new SemanticException("Inverse distribution function '" + getFunctionName()
+ "' must specify WITHIN GROUP");
}
}
@Override
protected ReturnableType<?> resolveResultType(TypeConfiguration typeConfiguration) {
return (ReturnableType<?>) withinGroupClause.getSortSpecifications().get( 0 ).getSortExpression()
return (ReturnableType<?>)
getWithinGroup().getSortSpecifications().get( 0 )
.getSortExpression()
.getExpressible();
}
@ -171,18 +176,20 @@ public class InverseDistributionFunction extends AbstractSqmSelfRenderingFunctio
// here we have something that is not a BasicType,
// and we have no way to get a BasicValuedMapping
// from it directly
final Expression expression = (Expression) withinGroupClause.getSortSpecifications().get( 0 )
final Expression expression = (Expression)
getWithinGroup().getSortSpecifications().get( 0 )
.getSortExpression()
.accept( walker );
if ( expression.getExpressionType() instanceof BasicValuedMapping ) {
return (BasicValuedMapping) expression.getExpressionType();
final JdbcMappingContainer expressionType = expression.getExpressionType();
if ( expressionType instanceof BasicValuedMapping ) {
return (BasicValuedMapping) expressionType;
}
try {
final MappingMetamodelImplementor domainModel = walker.getCreationContext()
return walker.getCreationContext()
.getSessionFactory()
.getRuntimeMetamodels()
.getMappingMetamodel();
return domainModel.resolveMappingExpressible(
.getMappingMetamodel()
.resolveMappingExpressible(
getNodeType(),
walker.getFromClauseAccess()::getTableGroup
);

View File

@ -14,7 +14,6 @@ import org.hibernate.query.ReturnableType;
import org.hibernate.query.spi.QueryEngine;
import org.hibernate.query.sqm.function.SelfRenderingFunctionSqlAstExpression;
import org.hibernate.query.sqm.function.SelfRenderingOrderedSetAggregateFunctionSqlAstExpression;
import org.hibernate.query.sqm.function.SelfRenderingSqmAggregateFunction;
import org.hibernate.query.sqm.function.SelfRenderingSqmOrderedSetAggregateFunction;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.FunctionParameterType;

View File

@ -24,7 +24,6 @@ import java.time.ZonedDateTime;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.GregorianCalendar;
import java.util.HashSet;
import java.util.List;
@ -367,7 +366,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override
public SqmStatement<R> visitStatement(HqlParser.StatementContext ctx) {
// parameters allow multi-valued bindings only in very limited cases, so for
// parameters allow multivalued bindings only in very limited cases, so for
// the base case here we say false
parameterDeclarationContextStack.push( () -> false );
@ -1070,7 +1069,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
return;
}
final SqmOrderByClause orderByClause;
final HqlParser.OrderByClauseContext orderByClauseContext = (HqlParser.OrderByClauseContext) ctx.getChild( 0 );
final HqlParser.OrderByClauseContext orderByClauseContext = ctx.orderByClause();
if ( orderByClauseContext != null ) {
if ( creationOptions.useStrictJpaCompliance() && processingStateStack.depth() > 1 ) {
throw new StrictJpaComplianceViolation(
@ -1085,29 +1084,10 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
orderByClause = null;
}
int currentIndex = 1;
final HqlParser.LimitClauseContext limitClauseContext;
if ( currentIndex < ctx.getChildCount() && ctx.getChild( currentIndex ) instanceof HqlParser.LimitClauseContext ) {
limitClauseContext = (HqlParser.LimitClauseContext) ctx.getChild( currentIndex++ );
}
else {
limitClauseContext = null;
}
final HqlParser.OffsetClauseContext offsetClauseContext;
if ( currentIndex < ctx.getChildCount() && ctx.getChild( currentIndex ) instanceof HqlParser.OffsetClauseContext ) {
offsetClauseContext = (HqlParser.OffsetClauseContext) ctx.getChild( currentIndex++ );
}
else {
offsetClauseContext = null;
}
final HqlParser.FetchClauseContext fetchClauseContext;
if ( currentIndex < ctx.getChildCount() && ctx.getChild( currentIndex ) instanceof HqlParser.FetchClauseContext ) {
fetchClauseContext = (HqlParser.FetchClauseContext) ctx.getChild( currentIndex++ );
}
else {
fetchClauseContext = null;
}
if ( currentIndex != 1 ) {
final HqlParser.LimitClauseContext limitClauseContext = ctx.limitClause();
final HqlParser.OffsetClauseContext offsetClauseContext = ctx.offsetClause();
final HqlParser.FetchClauseContext fetchClauseContext = ctx.fetchClause();
if ( limitClauseContext != null || offsetClauseContext != null || fetchClauseContext != null ) {
if ( getCreationOptions().useStrictJpaCompliance() ) {
throw new StrictJpaComplianceViolation(
StrictJpaComplianceViolation.Type.LIMIT_OFFSET_CLAUSE
@ -1122,7 +1102,10 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
sqmQueryPart.setOffsetExpression( visitOffsetClause( offsetClauseContext ) );
if ( limitClauseContext == null ) {
sqmQueryPart.setFetchExpression( visitFetchClause( fetchClauseContext ), visitFetchClauseType( fetchClauseContext ) );
sqmQueryPart.setFetchExpression(
visitFetchClause( fetchClauseContext ),
visitFetchClauseType( fetchClauseContext )
);
}
else if ( fetchClauseContext == null ) {
sqmQueryPart.setFetchExpression( visitLimitClause( limitClauseContext ) );
@ -1173,7 +1156,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
}
protected SqmSelectClause buildInferredSelectClause(SqmFromClause fromClause) {
// for now, this is slightly different than the legacy behavior where
// for now, this is slightly different to the legacy behavior where
// the root and each non-fetched-join was selected. For now, here, we simply
// select the root
final SqmSelectClause selectClause;
@ -1261,11 +1244,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
// if the node is not a dynamic-instantiation, register it with
// the path-registry
//noinspection StatementWithEmptyBody
if ( selectableNode instanceof SqmDynamicInstantiation ) {
// nothing else to do (avoid kludgy `! ( instanceof )` syntax
}
else {
if ( !(selectableNode instanceof SqmDynamicInstantiation) ) {
getCurrentProcessingState().getPathRegistry().register( selection );
}
@ -1377,11 +1356,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
creationContext.getNodeBuilder()
);
//noinspection StatementWithEmptyBody
if ( argExpression instanceof SqmDynamicInstantiation ) {
// nothing else to do (avoid kludgy `! ( instanceof )` syntax
}
else {
if ( !(argExpression instanceof SqmDynamicInstantiation) ) {
getCurrentProcessingState().getPathRegistry().register( argument );
}
@ -1541,7 +1516,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
// This is syntactically disallowed
throw new ParsingException( "COLLATE is not allowed for alias based order-by or group-by items" );
}
// this will group-by all of the sub-parts in the from-element's model part
// this will group-by all the sub-parts in the from-element's model part
return sqmFrom;
}
@ -2638,7 +2613,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
final int estimatedSize = size >> 1;
final Class<?> testExpressionJavaType = testExpression.getJavaType();
final boolean isEnum = testExpressionJavaType != null && testExpressionJavaType.isEnum();
// Multi-valued bindings are only allowed if there is a single list item, hence size 3 (LP, RP and param)
// Multivalued bindings are only allowed if there is a single list item, hence size 3 (LP, RP and param)
parameterDeclarationContextStack.push( () -> size == 3 );
try {
final List<SqmExpression<?>> listExpressions = new ArrayList<>( estimatedSize );
@ -2681,7 +2656,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
try {
return new SqmInListPredicate(
testExpression,
Collections.singletonList( tupleExpressionListContext.getChild( 0 ).accept( this ) ),
singletonList( tupleExpressionListContext.getChild( 0 ).accept( this ) ),
negated,
creationContext.getNodeBuilder()
);
@ -3912,97 +3887,36 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
return functionName.toString();
}
@Override
public Object visitGenericFunction(HqlParser.GenericFunctionContext ctx) {
private String getFunctionName(HqlParser.GenericFunctionContext ctx) {
final String originalFunctionName = visitGenericFunctionName( ctx.genericFunctionName() );
final String functionName = originalFunctionName.toLowerCase();
if ( creationOptions.useStrictJpaCompliance() && !JPA_STANDARD_FUNCTIONS.contains( functionName ) ) {
throw new StrictJpaComplianceViolation(
"Encountered non-compliant non-standard function call [" +
originalFunctionName + "], but strict JPA " +
"compliance was requested; use JPA's FUNCTION(functionName[,...]) " +
"compliance was requested; use FUNCTION(functionName[,...]) " +
"syntax name instead",
StrictJpaComplianceViolation.Type.FUNCTION_CALL
);
}
//TODO: this fragment of code is extremely fragile and lacking in typesafety!
final ParseTree argumentChild = ctx.getChild( 2 );
final List<SqmTypedNode<?>> functionArguments;
if ( argumentChild instanceof HqlParser.GenericFunctionArgumentsContext ) {
@SuppressWarnings("unchecked")
List<SqmTypedNode<?>> node = (List<SqmTypedNode<?>>) argumentChild.accept(this);
functionArguments = node;
}
else if ( "*".equals( argumentChild.getText() ) ) {
functionArguments = Collections.singletonList( new SqmStar( getCreationContext().getNodeBuilder() ) );
}
else {
functionArguments = emptyList();
return functionName;
}
final Boolean fromFirst = getFromFirst( ctx );
final Boolean respectNulls = getRespectNullsClause( ctx );
final SqmOrderByClause withinGroup = getWithinGroup( ctx );
@Override
public Object visitGenericFunction(HqlParser.GenericFunctionContext ctx) {
final SqmFunctionDescriptor functionTemplate = getFunctionTemplate( ctx );
final List<SqmTypedNode<?>> functionArguments = getFunctionArguments( ctx );
final SqmPredicate filterExpression = getFilterExpression( ctx );
final boolean hasOverClause = ctx.getChild( ctx.getChildCount() - 1 ) instanceof HqlParser.OverClauseContext;
SqmFunctionDescriptor functionTemplate = getFunctionDescriptor( functionName );
if ( functionTemplate == null ) {
FunctionKind functionKind = FunctionKind.NORMAL;
if ( withinGroup != null ) {
functionKind = FunctionKind.ORDERED_SET_AGGREGATE;
}
else if ( hasOverClause ) {
functionKind = FunctionKind.WINDOW;
}
else if ( filterExpression != null ) {
functionKind = FunctionKind.AGGREGATE;
}
functionTemplate = new NamedSqmFunctionDescriptor(
functionName,
true,
null,
StandardFunctionReturnTypeResolvers.invariant(
resolveExpressibleTypeBasic( Object.class )
),
null,
functionName,
functionKind,
null,
SqlAstNodeRenderingMode.DEFAULT
);
}
else {
if ( hasOverClause && functionTemplate.getFunctionKind() == FunctionKind.NORMAL ) {
throw new SemanticException( "OVER clause is illegal for normal function: " + functionName );
}
else if ( !hasOverClause && functionTemplate.getFunctionKind() == FunctionKind.WINDOW ) {
throw new SemanticException( "OVER clause is mandatory for window-only function: " + functionName );
}
if ( respectNulls != null ) {
switch ( functionName ) {
case "lag":
case "lead":
case "first_value":
case "last_value":
case "nth_value":
break;
default:
throw new SemanticException( "RESPECT/IGNORE NULLS is illegal for function: " + functionName );
}
}
if ( fromFirst != null && !"nth_value".equals( functionName ) ) {
throw new SemanticException( "FROM FIRST/LAST is illegal for function: " + functionName );
}
}
final SqmFunction<?> function;
switch ( functionTemplate.getFunctionKind() ) {
case ORDERED_SET_AGGREGATE:
function = functionTemplate.generateOrderedSetAggregateSqmExpression(
functionArguments,
filterExpression,
withinGroup,
ctx.withinGroupClause() == null
? null // this is allowed for e.g. rank(), but not for all
: visitOrderByClause( ctx.withinGroupClause().orderByClause(), false ),
null,
creationContext.getQueryEngine(),
creationContext.getJpaMetamodel().getTypeConfiguration()
@ -4029,9 +3943,6 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
);
break;
default:
if ( filterExpression != null ) {
throw new ParsingException( "Illegal use of a FILTER clause for non-aggregate function: " + originalFunctionName );
}
function = functionTemplate.generateSqmExpression(
functionArguments,
null,
@ -4040,7 +3951,95 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
);
break;
}
return applyOverClause( ctx, function );
return applyOverClause( ctx.overClause(), function );
}
private SqmFunctionDescriptor getFunctionTemplate(HqlParser.GenericFunctionContext ctx) {
final String functionName = getFunctionName( ctx );
final SqmFunctionDescriptor functionTemplate = getFunctionDescriptor( functionName );
if ( functionTemplate == null ) {
return new NamedSqmFunctionDescriptor(
functionName,
true,
null,
StandardFunctionReturnTypeResolvers.invariant(
resolveExpressibleTypeBasic( Object.class )
),
null,
functionName,
inferFunctionKind( ctx ),
null,
SqlAstNodeRenderingMode.DEFAULT
);
}
else {
final FunctionKind functionKind = functionTemplate.getFunctionKind();
if ( ctx.filterClause() != null && functionKind == FunctionKind.NORMAL ) {
throw new ParsingException( "FILTER clause is illegal for non-aggregate function: " + functionName );
}
if ( ctx.overClause() != null && functionKind == FunctionKind.NORMAL ) {
throw new SemanticException( "OVER clause is illegal for non-aggregate function: " + functionName);
}
if ( ctx.withinGroupClause() != null && functionKind == FunctionKind.NORMAL ) {
throw new SemanticException( "WITHIN GROUP clause is illegal for non-aggregate function: " + functionName);
}
if ( ctx.overClause() == null && functionKind == FunctionKind.WINDOW ) {
throw new SemanticException( "OVER clause is mandatory for window-only function: " + functionName );
}
if ( ctx.withinGroupClause() == null && ctx.overClause() == null
&& functionKind == FunctionKind.ORDERED_SET_AGGREGATE ) {
throw new SemanticException( "WITHIN GROUP or OVER clause is mandatory for ordered set aggregate function: " + functionName );
}
if ( ctx.nullsClause() != null ) {
switch ( functionName ) {
case "lag":
case "lead":
case "first_value":
case "last_value":
case "nth_value":
break;
default:
throw new SemanticException( "RESPECT/IGNORE NULLS is illegal for function: " + functionName );
}
}
if ( ctx.nthSideClause() != null && !"nth_value".equals( functionName ) ) {
throw new SemanticException( "FROM FIRST/LAST is illegal for function: " + functionName );
}
return functionTemplate;
}
}
private static FunctionKind inferFunctionKind(HqlParser.GenericFunctionContext ctx) {
if ( ctx.withinGroupClause() != null ) {
return FunctionKind.ORDERED_SET_AGGREGATE;
}
else if ( ctx.overClause() != null ) {
return FunctionKind.WINDOW;
}
else if ( ctx.filterClause() != null ) {
return FunctionKind.AGGREGATE;
}
else {
return FunctionKind.NORMAL;
}
}
private List<SqmTypedNode<?>> getFunctionArguments(HqlParser.GenericFunctionContext ctx) {
if ( ctx.genericFunctionArguments() != null ) {
@SuppressWarnings("unchecked")
final List<SqmTypedNode<?>> node = (List<SqmTypedNode<?>>)
ctx.genericFunctionArguments().accept(this);
return node;
}
else if ( ctx.ASTERISK() != null ) {
return singletonList( new SqmStar( getCreationContext().getNodeBuilder() ) );
}
else {
return emptyList();
}
}
@Override
@ -4048,46 +4047,53 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
if ( creationOptions.useStrictJpaCompliance() ) {
throw new StrictJpaComplianceViolation(
"Encountered non-compliant non-standard function call [listagg], but strict JPA " +
"compliance was requested; use JPA's FUNCTION(functionName[,...]) " +
"compliance was requested; use FUNCTION(functionName[,...]) " +
"syntax name instead",
StrictJpaComplianceViolation.Type.FUNCTION_CALL
);
}
final SqmFunctionDescriptor functionTemplate = getFunctionDescriptor( "listagg" );
if ( functionTemplate == null ) {
throw new SemanticException(
"The listagg function was not registered for the dialect"
throw new SemanticException( "The listagg function was not registered for the dialect" );
}
return applyOverClause(
ctx.overClause(),
functionTemplate.generateOrderedSetAggregateSqmExpression(
getListaggArguments( ctx ),
getFilterExpression( ctx ),
ctx.withinGroupClause() == null
? null // this is allowed
: visitOrderByClause( ctx.withinGroupClause().orderByClause(), false ),
null,
creationContext.getQueryEngine(),
creationContext.getJpaMetamodel().getTypeConfiguration()
)
);
}
final int argumentStartIndex;
final ParseTree thirdChild = ctx.getChild( 2 );
final boolean distinct;
if ( thirdChild instanceof TerminalNode ) {
distinct = true;
argumentStartIndex = 3;
}
else {
distinct = false;
argumentStartIndex = 2;
}
final SqmExpression<?> firstArgument = (SqmExpression<?>) ctx.getChild( argumentStartIndex ).accept( this );
final SqmExpression<?> secondArgument = (SqmExpression<?>) ctx.getChild( argumentStartIndex + 2 ).accept( this );
final ParseTree overflowCtx = ctx.getChild( argumentStartIndex + 3 );
private List<SqmTypedNode<?>> getListaggArguments(ListaggFunctionContext ctx) {
final SqmExpression<?> firstArgument = (SqmExpression<?>) ctx.expressionOrPredicate(0).accept( this );
final SqmExpression<?> secondArgument = (SqmExpression<?>) ctx.expressionOrPredicate(1).accept( this );
final OnOverflowClauseContext overflowCtx = ctx.onOverflowClause();
final List<SqmTypedNode<?>> functionArguments = new ArrayList<>( 3 );
if ( distinct ) {
if ( ctx.DISTINCT() != null ) {
functionArguments.add( new SqmDistinct<>( firstArgument, creationContext.getNodeBuilder() ) );
}
else {
functionArguments.add( firstArgument );
}
if ( overflowCtx instanceof OnOverflowClauseContext ) {
if ( overflowCtx.getChildCount() > 3 ) {
if ( overflowCtx != null ) {
if ( overflowCtx.ERROR() != null ) {
// ON OVERFLOW ERROR
functionArguments.add( new SqmOverflow<>( secondArgument, null, false ) );
}
else {
// ON OVERFLOW TRUNCATE
final TerminalNode countNode = (TerminalNode) overflowCtx.getChild( overflowCtx.getChildCount() - 2 );
final boolean withCount = countNode.getSymbol().getType() == HqlParser.WITH;
final SqmExpression<?> fillerExpression;
if ( overflowCtx.getChildCount() == 6 ) {
fillerExpression = (SqmExpression<?>) overflowCtx.getChild( 3 ).accept( this );
if ( overflowCtx.expression() != null ) {
fillerExpression = (SqmExpression<?>) overflowCtx.expression().accept( this );
}
else {
// The SQL standard says the default is three periods `...`
@ -4097,30 +4103,15 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
secondArgument.nodeBuilder()
);
}
final boolean withCount = overflowCtx.WITH() != null;
//noinspection unchecked,rawtypes
functionArguments.add( new SqmOverflow( secondArgument, fillerExpression, withCount ) );
}
else {
// ON OVERFLOW ERROR
functionArguments.add( new SqmOverflow<>( secondArgument, null, false ) );
}
}
else {
functionArguments.add( secondArgument );
}
final SqmOrderByClause withinGroup = getWithinGroup( ctx );
final SqmPredicate filterExpression = getFilterExpression( ctx );
return applyOverClause(
ctx,
functionTemplate.generateOrderedSetAggregateSqmExpression(
functionArguments,
filterExpression,
withinGroup,
null,
creationContext.getQueryEngine(),
creationContext.getJpaMetamodel().getTypeConfiguration()
)
);
return functionArguments;
}
@Override
@ -4481,23 +4472,20 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override
public SqmExpression<?> visitEveryFunction(HqlParser.EveryFunctionContext ctx) {
final SqmPredicate filterExpression = getFilterExpression( ctx );
final ParseTree argumentChild = ctx.getChild( 2 );
if ( argumentChild instanceof HqlParser.SubqueryContext ) {
final SqmSubQuery<?> subquery = (SqmSubQuery<?>) argumentChild.accept( this );
if ( ctx.subquery() != null ) {
final SqmSubQuery<?> subquery = (SqmSubQuery<?>) ctx.subquery().accept( this );
return new SqmEvery<>( subquery, creationContext.getNodeBuilder() );
}
else if ( argumentChild instanceof HqlParser.PredicateContext ) {
else if ( ctx.predicate() != null ) {
if ( getCreationOptions().useStrictJpaCompliance() ) {
throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.FUNCTION_CALL );
}
final SqmExpression<?> argument = (SqmExpression<?>) argumentChild.accept( this );
final SqmExpression<?> argument = (SqmExpression<?>) ctx.predicate().accept( this );
return applyOverClause(
ctx,
ctx.overClause(),
getFunctionDescriptor( "every" ).generateAggregateSqmExpression(
singletonList( argument ),
filterExpression,
getFilterExpression( ctx ),
resolveExpressibleTypeBasic( Boolean.class ),
creationContext.getQueryEngine(),
creationContext.getJpaMetamodel().getTypeConfiguration()
@ -4509,10 +4497,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION );
}
return new SqmEvery<>(
createCollectionReferenceSubQuery(
(HqlParser.SimplePathContext) ctx.getChild( 3 ),
(TerminalNode) ctx.getChild( 1 )
),
createCollectionReferenceSubQuery( ctx.simplePath(), (TerminalNode) ctx.getChild( 1 ) ),
creationContext.getNodeBuilder()
);
}
@ -4520,23 +4505,20 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
@Override
public SqmExpression<?> visitAnyFunction(HqlParser.AnyFunctionContext ctx) {
final SqmPredicate filterExpression = getFilterExpression( ctx );
final ParseTree argumentChild = ctx.getChild( 2 );
if ( argumentChild instanceof HqlParser.SubqueryContext ) {
final SqmSubQuery<?> subquery = (SqmSubQuery<?>) argumentChild.accept( this );
if ( ctx.subquery() != null ) {
final SqmSubQuery<?> subquery = (SqmSubQuery<?>) ctx.subquery().accept( this );
return new SqmAny<>( subquery, creationContext.getNodeBuilder() );
}
else if ( argumentChild instanceof HqlParser.PredicateContext ) {
else if ( ctx.predicate() != null ) {
if ( getCreationOptions().useStrictJpaCompliance() ) {
throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.FUNCTION_CALL );
}
final SqmExpression<?> argument = (SqmExpression<?>) argumentChild.accept( this );
final SqmExpression<?> argument = (SqmExpression<?>) ctx.predicate().accept( this );
return applyOverClause(
ctx,
ctx.overClause(),
getFunctionDescriptor( "any" ).generateAggregateSqmExpression(
singletonList( argument ),
filterExpression,
getFilterExpression( ctx ),
resolveExpressibleTypeBasic( Boolean.class ),
creationContext.getQueryEngine(),
creationContext.getJpaMetamodel().getTypeConfiguration()
@ -4548,10 +4530,7 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.HQL_COLLECTION_FUNCTION );
}
return new SqmAny<>(
createCollectionReferenceSubQuery(
(HqlParser.SimplePathContext) ctx.getChild( 3 ),
(TerminalNode) ctx.getChild( 1 )
),
createCollectionReferenceSubQuery( ctx.simplePath(), (TerminalNode) ctx.getChild( 1 ) ),
creationContext.getNodeBuilder()
);
}
@ -4620,44 +4599,6 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
return (SqmSubQuery<X>) subQuery;
}
private SqmOrderByClause getWithinGroup(ParseTree functionCtx) {
HqlParser.WithinGroupClauseContext ctx = null;
for ( int i = functionCtx.getChildCount() - 3; i < functionCtx.getChildCount(); i++ ) {
final ParseTree child = functionCtx.getChild( i );
if ( child instanceof HqlParser.WithinGroupClauseContext ) {
ctx = (HqlParser.WithinGroupClauseContext) child;
break;
}
}
if ( ctx != null ) {
return visitOrderByClause( (HqlParser.OrderByClauseContext) ctx.getChild( 3 ), false );
}
return null;
}
private Boolean getFromFirst(ParseTree functionCtx) {
// The clause is either on index 3 or 4 is where the
final int end = Math.min( functionCtx.getChildCount(), 5 );
for ( int i = 3; i < end; i++ ) {
final ParseTree child = functionCtx.getChild( i );
if ( child instanceof HqlParser.NthSideClauseContext ) {
final HqlParser.NthSideClauseContext subCtx = (HqlParser.NthSideClauseContext) child.getChild( 6 );
return ( (TerminalNode) subCtx.getChild( 1 ) ).getSymbol().getType() == HqlParser.FIRST;
}
}
return null;
}
private Boolean getRespectNullsClause(ParseTree functionCtx) {
for ( int i = functionCtx.getChildCount() - 3; i < functionCtx.getChildCount(); i++ ) {
final ParseTree child = functionCtx.getChild( i );
if ( child instanceof HqlParser.NullsClauseContext ) {
return ( (TerminalNode) child.getChild( 0 ) ).getSymbol().getType() == HqlParser.RESPECT;
}
}
return null;
}
private SqmPredicate getFilterExpression(ParseTree functionCtx) {
for ( int i = functionCtx.getChildCount() - 2; i < functionCtx.getChildCount(); i++ ) {
final ParseTree child = functionCtx.getChild( i );
@ -4668,48 +4609,36 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
return null;
}
private SqmExpression<?> applyOverClause(ParseTree functionCtx, SqmFunction<?> function) {
final ParseTree lastChild = functionCtx.getChild( functionCtx.getChildCount() - 1 );
if ( lastChild instanceof HqlParser.OverClauseContext ) {
return applyOverClause( (HqlParser.OverClauseContext) lastChild, function );
}
private SqmExpression<?> applyOverClause(HqlParser.OverClauseContext ctx, SqmFunction<?> function) {
if ( ctx == null) {
return function;
}
private SqmExpression<?> applyOverClause(HqlParser.OverClauseContext ctx, SqmFunction<?> function) {
final List<SqmExpression<?>> partitions;
final List<SqmSortSpecification> orderList;
if ( ctx.partitionClause() != null ) {
final HqlParser.PartitionClauseContext partitionClause = ctx.partitionClause();
partitions = new ArrayList<>( ( partitionClause.getChildCount() >> 1 ) - 1 );
for ( int i = 2; i < partitionClause.getChildCount(); i += 2 ) {
partitions.add( (SqmExpression<?>) partitionClause.getChild( i ).accept( this ) );
}
}
else {
partitions = emptyList();
}
final List<SqmSortSpecification> orderList = ctx.orderByClause() != null
? visitOrderByClause( ctx.orderByClause(), false ).getSortSpecifications()
: emptyList();
final FrameMode mode;
final FrameKind startKind;
final SqmExpression<?> startExpression;
final FrameKind endKind;
final SqmExpression<?> endExpression;
final FrameExclusion exclusion;
int index = 2;
if ( ctx.getChild( index ) instanceof HqlParser.PartitionClauseContext ) {
final ParseTree subCtx = ctx.getChild( index );
partitions = new ArrayList<>( ( subCtx.getChildCount() >> 1 ) - 1 );
for ( int i = 2; i < subCtx.getChildCount(); i += 2 ) {
partitions.add( (SqmExpression<?>) subCtx.getChild( i ).accept( this ) );
}
index++;
}
else {
partitions = Collections.emptyList();
}
if ( index < ctx.getChildCount() && ctx.getChild( index ) instanceof HqlParser.OrderByClauseContext ) {
orderList = visitOrderByClause(
(HqlParser.OrderByClauseContext) ctx.getChild( index ),
false
).getSortSpecifications();
index++;
}
else {
orderList = Collections.emptyList();
}
if ( index < ctx.getChildCount() && ctx.getChild( index ) instanceof HqlParser.FrameClauseContext ) {
final ParseTree frameCtx = ctx.getChild( index );
switch ( ( (TerminalNode) frameCtx.getChild( 0 ) ).getSymbol().getType() ) {
final HqlParser.FrameClauseContext frameClause = ctx.frameClause();
if ( frameClause != null ) {
switch ( ( (TerminalNode) frameClause.getChild( 0 ) ).getSymbol().getType() ) {
case HqlParser.RANGE:
mode = FrameMode.RANGE;
break;
@ -4720,14 +4649,14 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
mode = FrameMode.GROUPS;
break;
default:
throw new IllegalArgumentException( "Unexpected frame mode: " + frameCtx.getChild( 0 ) );
throw new IllegalArgumentException( "Unexpected frame mode: " + frameClause.getChild( 0 ) );
}
final int frameStartIndex;
if ( frameCtx.getChild( 1 ) instanceof TerminalNode ) {
if ( frameClause.getChild( 1 ) instanceof TerminalNode ) {
frameStartIndex = 2;
endKind = getFrameKind( frameCtx.getChild( 4 ) );
endKind = getFrameKind( frameClause.getChild( 4 ) );
endExpression = endKind == FrameKind.OFFSET_FOLLOWING || endKind == FrameKind.OFFSET_PRECEDING
? (SqmExpression<?>) frameCtx.getChild( 4 ).getChild( 0 ).accept( this )
? (SqmExpression<?>) frameClause.getChild( 4 ).getChild( 0 ).accept( this )
: null;
}
else {
@ -4735,11 +4664,11 @@ public class SemanticQueryBuilder<R> extends HqlParserBaseVisitor<Object> implem
endKind = FrameKind.CURRENT_ROW;
endExpression = null;
}
startKind = getFrameKind( frameCtx.getChild( frameStartIndex ) );
startKind = getFrameKind( frameClause.getChild( frameStartIndex ) );
startExpression = startKind == FrameKind.OFFSET_FOLLOWING || startKind == FrameKind.OFFSET_PRECEDING
? (SqmExpression<?>) frameCtx.getChild( frameStartIndex ).getChild( 0 ).accept( this )
? (SqmExpression<?>) frameClause.getChild( frameStartIndex ).getChild( 0 ).accept( this )
: null;
final ParseTree lastChild = frameCtx.getChild( frameCtx.getChildCount() - 1 );
final ParseTree lastChild = frameClause.getChild( frameClause.getChildCount() - 1 );
if ( lastChild instanceof HqlParser.FrameExclusionContext ) {
switch ( ( (TerminalNode) lastChild.getChild( 1 ) ).getSymbol().getType() ) {
case HqlParser.CURRENT:

View File

@ -15,5 +15,5 @@ public enum FunctionKind {
NORMAL,
AGGREGATE,
ORDERED_SET_AGGREGATE,
WINDOW;
WINDOW
}

View File

@ -40,7 +40,7 @@ public class SelfRenderingSqmOrderedSetAggregateFunction<T> extends SelfRenderin
FunctionRenderingSupport renderingSupport,
List<? extends SqmTypedNode<?>> arguments,
SqmPredicate filter,
SqmOrderByClause withinGroup,
SqmOrderByClause withinGroupClause,
ReturnableType<T> impliedResultType,
ArgumentsValidator argumentsValidator,
FunctionReturnTypeResolver returnTypeResolver,
@ -57,7 +57,7 @@ public class SelfRenderingSqmOrderedSetAggregateFunction<T> extends SelfRenderin
nodeBuilder,
name
);
this.withinGroup = withinGroup;
this.withinGroup = withinGroupClause;
}
@Override

View File

@ -30,7 +30,6 @@ import jakarta.persistence.TypedQuery;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
/**