From 78209dc5067ba06fb18b9bb859536ca10ae17a29 Mon Sep 17 00:00:00 2001 From: Jan Schatteman Date: Fri, 11 Jun 2021 23:10:53 +0200 Subject: [PATCH] Add filter clause for aggregate functions in HQL Signed-off-by: Jan Schatteman --- .../org/hibernate/grammars/hql/HqlLexer.g4 | 1 + .../org/hibernate/grammars/hql/HqlParser.g4 | 21 +- .../function/CaseWhenEveryAnyEmulation.java | 66 +++++ .../function/CommonFunctionFactory.java | 49 +--- .../dialect/function/EveryAnyEmulation.java | 78 +++++ .../function/SQLServerEveryAnyEmulation.java | 66 +++++ .../hql/internal/SemanticQueryBuilder.java | 50 +++- .../AbstractSqmFunctionDescriptor.java | 18 ++ ...actSqmSelfRenderingFunctionDescriptor.java | 7 +- .../function/NamedSqmFunctionDescriptor.java | 9 +- .../query/hql/AggregateFilterClauseTest.java | 268 ++++++++++++++++++ 11 files changed, 564 insertions(+), 69 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/function/CaseWhenEveryAnyEmulation.java create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/function/EveryAnyEmulation.java create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/function/SQLServerEveryAnyEmulation.java create mode 100644 hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AggregateFilterClauseTest.java diff --git a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlLexer.g4 b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlLexer.g4 index 3b373b39cb..79ab64a33e 100644 --- a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlLexer.g4 +++ b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlLexer.g4 @@ -182,6 +182,7 @@ EXISTS : [eE] [xX] [iI] [sS] [tT] [sS]; EXP : [eE] [xX] [pP]; EXTRACT : [eE] [xX] [tT] [rR] [aA] [cC] [tT]; FETCH : [fF] [eE] [tT] [cC] [hH]; +FILTER : [fF] [iI] [lL] [tT] [eE] [rR]; FIRST : [fF] [iI] [rR] [sS] [tT]; FLOOR : [fF] [lL] [oO] [oO] [rR]; FROM : [fF] [rR] [oO] [mM]; diff --git a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 index 23ee395428..ccfbb2a51e 100644 --- a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 +++ b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 @@ -690,31 +690,35 @@ aggregateFunction ; avgFunction - : AVG LEFT_PAREN DISTINCT? expression RIGHT_PAREN + : AVG LEFT_PAREN DISTINCT? expression RIGHT_PAREN filterClause? ; sumFunction - : SUM LEFT_PAREN DISTINCT? expression RIGHT_PAREN + : SUM LEFT_PAREN DISTINCT? expression RIGHT_PAREN filterClause? ; minFunction - : MIN LEFT_PAREN DISTINCT? expression RIGHT_PAREN + : MIN LEFT_PAREN DISTINCT? expression RIGHT_PAREN filterClause? ; maxFunction - : MAX LEFT_PAREN DISTINCT? expression RIGHT_PAREN + : MAX LEFT_PAREN DISTINCT? expression RIGHT_PAREN filterClause? ; countFunction - : COUNT LEFT_PAREN DISTINCT? (expression | ASTERISK) RIGHT_PAREN + : COUNT LEFT_PAREN DISTINCT? (expression | ASTERISK) RIGHT_PAREN filterClause? ; everyFunction - : (EVERY|ALL) LEFT_PAREN (predicate | subQuery) RIGHT_PAREN + : (EVERY|ALL) LEFT_PAREN (predicate | subQuery) RIGHT_PAREN filterClause? ; anyFunction - : (ANY|SOME) LEFT_PAREN (predicate | subQuery) RIGHT_PAREN + : (ANY|SOME) LEFT_PAREN (predicate | subQuery) RIGHT_PAREN filterClause? + ; + +filterClause + : FILTER LEFT_PAREN whereClause RIGHT_PAREN ; standardFunction @@ -1102,7 +1106,7 @@ rollup * The lexer hands us recognized keywords using their specific tokens. This is important * for the recognition of sqm structure, especially in terms of performance! * - * However we want to continue to allow users to use mopst keywords as identifiers (e.g., attribute names). + * However we want to continue to allow users to use most keywords as identifiers (e.g., attribute names). * This parser rule helps with that. Here we expect that the caller already understands their * context enough to know that keywords-as-identifiers are allowed. */ @@ -1149,6 +1153,7 @@ identifier | EXP | EXTRACT | FETCH + | FILTER | FLOOR | FROM | FOR diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CaseWhenEveryAnyEmulation.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CaseWhenEveryAnyEmulation.java new file mode 100644 index 0000000000..26011a5693 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CaseWhenEveryAnyEmulation.java @@ -0,0 +1,66 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.dialect.function; + +import java.util.List; + +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.predicate.Predicate; +import org.hibernate.type.StandardBasicTypes; + +/** + * @author Jan Schatteman + */ +public class CaseWhenEveryAnyEmulation extends AbstractSqmSelfRenderingFunctionDescriptor { + + private final boolean every; + + public CaseWhenEveryAnyEmulation(boolean every) { + super( + every ? "every" : "any", + true, + StandardArgumentsValidators.exactly( 1 ), + StandardFunctionReturnTypeResolvers.invariant( StandardBasicTypes.BOOLEAN ) + ); + this.every = every; + } + + @Override + public void render( + SqlAppender sqlAppender, + List sqlAstArguments, + Predicate filter, + SqlAstTranslator walker) { + if ( every ) { + sqlAppender.appendSql( "min(case when " ); + } + else { + sqlAppender.appendSql( "max(case when " ); + } + if ( filter != null ) { + filter.accept( walker ); + sqlAppender.appendSql( " then case when " ); + sqlAstArguments.get( 0 ).accept( walker ); + sqlAppender.appendSql( " then 1 else 0 end else null end)" ); + } + else { + sqlAstArguments.get( 0 ).accept( walker ); + sqlAppender.appendSql( " then 1 else 0 end)" ); + } + } + + @Override + public void render( + SqlAppender sqlAppender, List sqlAstArguments, SqlAstTranslator walker) { + this.render( sqlAppender, sqlAstArguments, null, walker ); + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java index 4caafc751f..01aab71bd3 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java @@ -647,7 +647,7 @@ public class CommonFunctionFactory { .register(); //MySQL has it but how is that even useful? -// queryEngine.getSqmFunctionRegistry().namedTemplateBuilder( "bit_xor" ) + // queryEngine.getSqmFunctionRegistry().namedTemplateBuilder( "bit_xor" ) // .setExactArgumentCount( 1 ) // .register(); } @@ -689,26 +689,14 @@ public class CommonFunctionFactory { .register(); queryEngine.getSqmFunctionRegistry().registerAlternateKey( "any", "bool_or" ); } - /** * These are aggregate functions taking one argument, * for databases that have to emulate the boolean * aggregation functions using sum() and case. */ public static void everyAny_sumCase(QueryEngine queryEngine) { - queryEngine.getSqmFunctionRegistry().patternAggregateDescriptorBuilder( "every", - "(sum(case when ?1 then 0 else 1 end)=0)" ) - .setExactArgumentCount( 1 ) - .setInvariantType( StandardBasicTypes.BOOLEAN ) - .setArgumentListSignature("(predicate)") - .register(); - - queryEngine.getSqmFunctionRegistry().patternAggregateDescriptorBuilder( "any", - "(sum(case when ?1 then 1 else 0 end)>0)" ) - .setExactArgumentCount( 1 ) - .setInvariantType( StandardBasicTypes.BOOLEAN ) - .setArgumentListSignature("(predicate)") - .register(); + queryEngine.getSqmFunctionRegistry().register( "every", new EveryAnyEmulation( true ) ); + queryEngine.getSqmFunctionRegistry().register( "any", new EveryAnyEmulation( false ) ); } /** @@ -716,39 +704,18 @@ public class CommonFunctionFactory { * for SQL Server. */ public static void everyAny_sumIif(QueryEngine queryEngine) { - queryEngine.getSqmFunctionRegistry().patternAggregateDescriptorBuilder( "every", - "min(iif(?1,1,0))" ) - .setExactArgumentCount( 1 ) - .setInvariantType( StandardBasicTypes.BOOLEAN ) - .setArgumentListSignature("(predicate)") - .register(); - - queryEngine.getSqmFunctionRegistry().patternAggregateDescriptorBuilder( "any", - "max(iif(?1,1,0))" ) - .setExactArgumentCount( 1 ) - .setInvariantType( StandardBasicTypes.BOOLEAN ) - .setArgumentListSignature("(predicate)") - .register(); + queryEngine.getSqmFunctionRegistry().register( "every", new SQLServerEveryAnyEmulation( true ) ); + queryEngine.getSqmFunctionRegistry().register( "any", new SQLServerEveryAnyEmulation( false ) ); } + /** * These are aggregate functions taking one argument, * for Oracle. */ public static void everyAny_sumCaseCase(QueryEngine queryEngine) { - queryEngine.getSqmFunctionRegistry().patternAggregateDescriptorBuilder( "every", - "min(case when ?1 then 1 else 0 end)" ) - .setExactArgumentCount( 1 ) - .setInvariantType( StandardBasicTypes.BOOLEAN ) - .setArgumentListSignature("(predicate)") - .register(); - - queryEngine.getSqmFunctionRegistry().patternAggregateDescriptorBuilder( "any", - "max(case when ?1 then 1 else 0 end)" ) - .setExactArgumentCount( 1 ) - .setInvariantType( StandardBasicTypes.BOOLEAN ) - .setArgumentListSignature("(predicate)") - .register(); + queryEngine.getSqmFunctionRegistry().register( "every", new CaseWhenEveryAnyEmulation( true ) ); + queryEngine.getSqmFunctionRegistry().register( "any", new CaseWhenEveryAnyEmulation( false ) ); } public static void yearMonthDay(QueryEngine queryEngine) { diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/EveryAnyEmulation.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/EveryAnyEmulation.java new file mode 100644 index 0000000000..f6593df297 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/EveryAnyEmulation.java @@ -0,0 +1,78 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.dialect.function; + +import java.util.List; + +import org.hibernate.metamodel.model.domain.AllowableFunctionReturnType; +import org.hibernate.query.spi.QueryEngine; +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.function.SelfRenderingSqmFunction; +import org.hibernate.query.sqm.produce.function.ArgumentsValidator; +import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.query.sqm.tree.SqmTypedNode; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.predicate.Predicate; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.spi.TypeConfiguration; + +/** + * @author Jan Schatteman + */ +public class EveryAnyEmulation extends AbstractSqmSelfRenderingFunctionDescriptor { + + private final boolean every; + + public EveryAnyEmulation(boolean every) { + super( + every ? "every" : "any", + true, + StandardArgumentsValidators.exactly( 1 ), + StandardFunctionReturnTypeResolvers.invariant( StandardBasicTypes.BOOLEAN ) + ); + this.every = every; + } + + @Override + public void render( + SqlAppender sqlAppender, + List sqlAstArguments, + Predicate filter, + SqlAstTranslator walker) { + sqlAppender.appendSql( "(sum(case when " ); + if ( filter != null ) { + filter.accept( walker ); + sqlAppender.appendSql( " then case when " ); + sqlAstArguments.get( 0 ).accept( walker ); + if ( every ) { + sqlAppender.appendSql( " then 0 else 1 end else null end)=0)" ); + } + else { + sqlAppender.appendSql( " then 1 else 0 end else null end)>0)" ); + } + } + else { + sqlAstArguments.get( 0 ).accept( walker ); + if ( every ) { + sqlAppender.appendSql( " then 0 else 1 end)=0)" ); + } + else { + sqlAppender.appendSql( " then 1 else 0 end)>0)" ); + } + } + } + + @Override + public void render( + SqlAppender sqlAppender, List sqlAstArguments, SqlAstTranslator walker) { + this.render( sqlAppender, sqlAstArguments, null, walker ); + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/SQLServerEveryAnyEmulation.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/SQLServerEveryAnyEmulation.java new file mode 100644 index 0000000000..ead8e9e701 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/SQLServerEveryAnyEmulation.java @@ -0,0 +1,66 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.dialect.function; + +import java.util.List; + +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.predicate.Predicate; +import org.hibernate.type.StandardBasicTypes; + +/** + * @author Jan Schatteman + */ +public class SQLServerEveryAnyEmulation extends AbstractSqmSelfRenderingFunctionDescriptor { + + private final boolean every; + + public SQLServerEveryAnyEmulation(boolean every) { + super( + every ? "every" : "any", + true, + StandardArgumentsValidators.exactly( 1 ), + StandardFunctionReturnTypeResolvers.invariant( StandardBasicTypes.BOOLEAN ) + ); + this.every = every; + } + + @Override + public void render( + SqlAppender sqlAppender, + List sqlAstArguments, + Predicate filter, + SqlAstTranslator walker) { + if ( every ) { + sqlAppender.appendSql( "min(iif(" ); + } + else { + sqlAppender.appendSql( "max(iif(" ); + } + if ( filter != null ) { + filter.accept( walker ); + sqlAppender.appendSql( ",iif(" ); + sqlAstArguments.get( 0 ).accept( walker ); + sqlAppender.appendSql( ",1,0),null))" ); + } + else { + sqlAstArguments.get( 0 ).accept( walker ); + sqlAppender.appendSql( ",1,0))" ); + } + } + + @Override + public void render( + SqlAppender sqlAppender, List sqlAstArguments, SqlAstTranslator walker) { + this.render( sqlAppender, sqlAstArguments, null, walker ); + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java index cc90ad1c7b..befa9c517e 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java @@ -3548,8 +3548,9 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem public SqmExpression visitMaxFunction(HqlParser.MaxFunctionContext ctx) { final SqmExpression arg = (SqmExpression) ctx.expression().accept( this ); //ignore DISTINCT - return getFunctionDescriptor("max").generateSqmExpression( - arg, + return getFunctionDescriptor("max").generateAggregateSqmExpression( + singletonList( arg ), + getFilterExpression( ctx.filterClause() ), (AllowableFunctionReturnType) arg.getNodeType(), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -3560,8 +3561,9 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem public SqmExpression visitMinFunction(HqlParser.MinFunctionContext ctx) { final SqmExpression arg = (SqmExpression) ctx.expression().accept( this ); //ignore DISTINCT - return getFunctionDescriptor("min").generateSqmExpression( - arg, + return getFunctionDescriptor("min").generateAggregateSqmExpression( + singletonList( arg ), + getFilterExpression( ctx.filterClause() ), (AllowableFunctionReturnType) arg.getNodeType(), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -3575,8 +3577,9 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem ? new SqmDistinct<>(arg, getCreationContext().getNodeBuilder()) : arg; - return getFunctionDescriptor("sum").generateSqmExpression( - argument, + return getFunctionDescriptor("sum").generateAggregateSqmExpression( + singletonList( argument ), + getFilterExpression(ctx.filterClause()), null, creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -3593,8 +3596,13 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem final SqmExpression argument = (SqmExpression) ctx.predicate().accept( this ); - return getFunctionDescriptor("every").generateSqmExpression( - argument, + if ( argument instanceof SqmSubQuery && ctx.filterClause() != null ) { + throw new SemanticException( "Quantified expression cannot have a filter clause!" ); + } + + return getFunctionDescriptor("every").generateAggregateSqmExpression( + singletonList( argument ), + getFilterExpression( ctx.filterClause() ), resolveExpressableTypeBasic( Boolean.class ), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -3611,8 +3619,13 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem final SqmExpression argument = (SqmExpression) ctx.predicate().accept( this ); - return getFunctionDescriptor("any").generateSqmExpression( - argument, + if ( argument instanceof SqmSubQuery && ctx.filterClause() != null ) { + throw new SemanticException( "Quantified expression cannot have a filter clause!" ); + } + + return getFunctionDescriptor("any").generateAggregateSqmExpression( + singletonList( argument ), + getFilterExpression( ctx.filterClause() ), resolveExpressableTypeBasic( Boolean.class ), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -3626,8 +3639,9 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem ? new SqmDistinct<>( arg, getCreationContext().getNodeBuilder() ) : arg; - return getFunctionDescriptor("avg").generateSqmExpression( - argument, + return getFunctionDescriptor("avg").generateAggregateSqmExpression( + singletonList( argument ), + getFilterExpression( ctx.filterClause() ), resolveExpressableTypeBasic( Double.class ), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() @@ -3643,14 +3657,22 @@ public class SemanticQueryBuilder extends HqlParserBaseVisitor implem ? new SqmDistinct<>( arg, getCreationContext().getNodeBuilder() ) : arg; - return getFunctionDescriptor("count").generateSqmExpression( - argument, + return getFunctionDescriptor("count").generateAggregateSqmExpression( + singletonList( argument ), + getFilterExpression( ctx.filterClause() ), resolveExpressableTypeBasic( Long.class ), creationContext.getQueryEngine(), creationContext.getJpaMetamodel().getTypeConfiguration() ); } + private SqmPredicate getFilterExpression(HqlParser.FilterClauseContext filterClauseCtx) { + if (filterClauseCtx == null) { + return null; + } + return (SqmPredicate) filterClauseCtx.whereClause().predicate().accept( this ); + } + @Override public SqmExpression visitCube(HqlParser.CubeContext ctx) { return new SqmSummarization<>( diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java index f0651e65ee..024c9e7950 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmFunctionDescriptor.java @@ -108,6 +108,24 @@ public abstract class AbstractSqmFunctionDescriptor implements SqmFunctionDescri ); } + @Override + public final SelfRenderingSqmFunction generateAggregateSqmExpression( + List> arguments, + SqmPredicate filter, + AllowableFunctionReturnType impliedResultType, + QueryEngine queryEngine, + TypeConfiguration typeConfiguration) { + argumentsValidator.validate( arguments ); + + return generateSqmAggregateFunctionExpression( + arguments, + filter, + impliedResultType, + queryEngine, + typeConfiguration + ); + } + /** * Return an SQM node or subtree representing an invocation of this function * with the given arguments. This method may be overridden in the case of diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmSelfRenderingFunctionDescriptor.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmSelfRenderingFunctionDescriptor.java index 94aabba095..4e7d4a6c45 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmSelfRenderingFunctionDescriptor.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/AbstractSqmSelfRenderingFunctionDescriptor.java @@ -24,7 +24,7 @@ import java.util.List; * @author Gavin King */ public abstract class AbstractSqmSelfRenderingFunctionDescriptor - extends AbstractSqmFunctionDescriptor { + extends AbstractSqmFunctionDescriptor implements FunctionRenderingSupport { private final boolean isAggregate; @@ -59,7 +59,7 @@ public abstract class AbstractSqmSelfRenderingFunctionDescriptor } @Override - public SelfRenderingSqmFunction generateAggregateSqmExpression( + public SelfRenderingSqmAggregateFunction generateSqmAggregateFunctionExpression( List> arguments, SqmPredicate filter, AllowableFunctionReturnType impliedResultType, @@ -70,7 +70,7 @@ public abstract class AbstractSqmSelfRenderingFunctionDescriptor } return new SelfRenderingSqmAggregateFunction<>( this, - this::render, + this, arguments, filter, impliedResultType, @@ -95,5 +95,4 @@ public abstract class AbstractSqmSelfRenderingFunctionDescriptor SqlAstTranslator walker) { render( sqlAppender, sqlAstArguments, walker ); } - } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/NamedSqmFunctionDescriptor.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/NamedSqmFunctionDescriptor.java index acf1c7cc7d..c5e3a1852c 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/function/NamedSqmFunctionDescriptor.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/function/NamedSqmFunctionDescriptor.java @@ -115,11 +115,16 @@ public class NamedSqmFunctionDescriptor if ( !firstPass ) { sqlAppender.appendSql( ", " ); } - if ( caseWrapper && !( arg instanceof Distinct ) && !( arg instanceof Star ) ) { + if ( caseWrapper && !( arg instanceof Distinct ) ) { sqlAppender.appendSql( "case when " ); filter.accept( translator ); sqlAppender.appendSql( " then " ); - translator.render( arg, argumentRenderingMode ); + if ( ( arg instanceof Star ) ) { + sqlAppender.appendSql( "1" ); + } + else { + translator.render( arg, argumentRenderingMode ); + } sqlAppender.appendSql( " else null end" ); } else { diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AggregateFilterClauseTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AggregateFilterClauseTest.java new file mode 100644 index 0000000000..c981d741f0 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AggregateFilterClauseTest.java @@ -0,0 +1,268 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.hql; + +import java.util.Date; + +import org.hibernate.query.Query; +import org.hibernate.query.SemanticException; + +import org.hibernate.testing.orm.domain.StandardDomainModel; +import org.hibernate.testing.orm.domain.gambit.EntityOfBasics; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +/** + * @author Jan Schatteman + */ +@ServiceRegistry +@DomainModel(standardModels = StandardDomainModel.GAMBIT) +@SessionFactory +public class AggregateFilterClauseTest { + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( + em -> { + Date now = new Date(); + + EntityOfBasics entity1 = new EntityOfBasics(); + entity1.setId( 1 ); + entity1.setTheInt( 5 ); + entity1.setTheInteger( -1 ); + entity1.setTheDouble( 5.0 ); + entity1.setTheDate( now ); + entity1.setTheBoolean( true ); + em.persist( entity1 ); + + EntityOfBasics entity2 = new EntityOfBasics(); + entity2.setId( 2 ); + entity2.setTheInt( 6 ); + entity2.setTheInteger( -2 ); + entity2.setTheDouble( 6.0 ); + entity2.setTheBoolean( true ); + em.persist( entity2 ); + + EntityOfBasics entity3 = new EntityOfBasics(); + entity3.setId( 3 ); + entity3.setTheInt( 7 ); + entity3.setTheInteger( 3 ); + entity3.setTheDouble( 7.0 ); + entity3.setTheBoolean( false ); + entity3.setTheDate( new Date(now.getTime() + 200000L) ); + em.persist( entity3 ); + + EntityOfBasics entity4 = new EntityOfBasics(); + entity4.setId( 4 ); + entity4.setTheInt( 13 ); + entity4.setTheInteger( 4 ); + entity4.setTheDouble( 13.0 ); + entity4.setTheBoolean( false ); + entity4.setTheDate( new Date(now.getTime() + 300000L) ); + em.persist( entity4 ); + + EntityOfBasics entity5 = new EntityOfBasics(); + entity5.setId( 5 ); + entity5.setTheInteger( 5 ); + entity5.setTheDouble( 9.0 ); + entity5.setTheBoolean( false ); + em.persist( entity5 ); + } + ); + } + + @AfterEach + public void tearDown(SessionFactoryScope scope) { + scope.inTransaction( + session -> session.createQuery( "delete from EntityOfBasics" ).executeUpdate() + ); + } + + @Test + public void testSimpleSum(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Long expected = 31L; + Query q = session.createQuery( "select sum(eob.theInt) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testSumWithFilterClause(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Long expected = 11L; + Query q = session.createQuery( "select sum(eob.theInt) filter(where eob.theBoolean = true) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testSimpleAvg(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Double expected = 8.0D; + Query q = session.createQuery( "select avg(eob.theDouble) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testAvgWithFilterClause(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Double expected = 5.5D; + Query q = session.createQuery( "select avg(eob.theDouble) filter(where eob.theBoolean = true) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testSimpleMin(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Double expected = 5D; + Query q = session.createQuery( "select min(eob.theDouble) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testMinWithFilterClause(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Double expected = 7D; + Query q = session.createQuery( "select min(eob.theDouble) filter(where eob.theBoolean = false) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testSimpleMax(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Double expected = 13D; + Query q = session.createQuery( "select max(eob.theDouble) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testMaxWithFilterClause(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Double expected = 6D; + Query q = session.createQuery( "select max(eob.theDouble) filter(where eob.theBoolean = true) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testSimpleCount(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Long expected = 5L; + Query q = session.createQuery( "select count(*) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + + expected = 3L; + q = session.createQuery( "select count(eob.theDate) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + public void testCountWithFilterClause(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Long expected = 3L; + Query q = session.createQuery( "select count(*) filter(where eob.theBoolean = false) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + + expected = 2L; + q = session.createQuery( "select count(eob.theDate) filter(where eob.theBoolean = false) from EntityOfBasics eob" ); + assertEquals( expected, q.getSingleResult(), "expected " + expected + ", got " + q.getSingleResult() ); + } + ); + } + + @Test + // poor test verification, but ok ... + public void testSimpleEveryAll(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Query q = session.createQuery( "select every( eob.theInteger > 0 ) from EntityOfBasics eob" ); + assertFalse( (Boolean) q.getSingleResult() ); + + q = session.createQuery( "select any( eob.theInteger < 0 ) from EntityOfBasics eob" ); + assertTrue( (Boolean) q.getSingleResult() ); + } + ); + } + + @Test + // poor test verification, but ok ... + public void testEveryAllWithFilterClause(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Query q = session.createQuery( "select every( eob.theInteger > 0 ) filter ( where eob.theBoolean = false ) from EntityOfBasics eob" ); + assertTrue( (Boolean) q.getSingleResult() ); + + q = session.createQuery( "select any( eob.theInteger < 0 ) filter ( where eob.theBoolean = false ) from EntityOfBasics eob" ); + assertFalse( (Boolean) q.getSingleResult() ); + } + ); + } + + @Test + public void testIllegalSubquery(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Exception e = Assertions.assertThrows( + IllegalArgumentException.class, + () -> { + session.createQuery( "select every( eob.theInteger > 0 ) filter ( where select 1 ) from EntityOfBasics eob" ); + } + + ); + assertEquals( SemanticException.class, e.getCause().getClass() ); + } + ); + scope.inTransaction( + session -> { + Exception e = Assertions.assertThrows( + IllegalArgumentException.class, + () -> { + session.createQuery( "select any( eob.theInteger > 0 ) filter ( where select 1 ) from EntityOfBasics eob" ); + } + + ); + assertEquals( SemanticException.class, e.getCause().getClass() ); + } + ); + } +}