From 75888b94f284dffba48b3f055bb555a70538e636 Mon Sep 17 00:00:00 2001 From: Gavin King Date: Mon, 7 Feb 2022 13:26:31 +0100 Subject: [PATCH] use CastFunction to do typecasts --- .../dialect/function/AvgFunction.java | 26 +++++----- .../dialect/function/CastFunction.java | 5 +- .../function/CommonFunctionFactory.java | 10 ++-- .../dialect/function/CountFunction.java | 3 +- .../dialect/function/ListaggFunction.java | 4 +- .../function/TimestampaddFunction.java | 4 +- .../sql/ast/tree/expression/CastTarget.java | 1 - .../orm/test/query/hql/AvgFunctionTest.java | 51 +++++++++++++++++++ 8 files changed, 73 insertions(+), 31 deletions(-) create mode 100644 hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AvgFunctionTest.java diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/AvgFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/AvgFunction.java index 518fc422ea..7d4bb2aa45 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/AvgFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/AvgFunction.java @@ -6,25 +6,25 @@ */ package org.hibernate.dialect.function; -import java.util.Collections; +import java.util.Arrays; import java.util.List; import org.hibernate.dialect.Dialect; import org.hibernate.metamodel.mapping.JdbcMapping; -import org.hibernate.query.sqm.CastType; import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; import org.hibernate.query.sqm.function.FunctionKind; import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator; import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; -import org.hibernate.query.sqm.produce.function.internal.PatternRenderer; import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.expression.CastTarget; import org.hibernate.sql.ast.tree.expression.Distinct; import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.sql.ast.tree.predicate.Predicate; +import org.hibernate.type.BasicType; import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.spi.TypeConfiguration; @@ -35,27 +35,27 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUM */ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor { - public static final String FUNCTION_NAME = "avg"; - private final Dialect dialect; private final SqlAstNodeRenderingMode defaultArgumentRenderingMode; - private final String doubleCastType; + private final CastFunction castFunction; + private final BasicType doubleType; public AvgFunction( Dialect dialect, TypeConfiguration typeConfiguration, - SqlAstNodeRenderingMode defaultArgumentRenderingMode, - String doubleCastType) { + SqlAstNodeRenderingMode defaultArgumentRenderingMode) { super( - FUNCTION_NAME, + "avg", FunctionKind.AGGREGATE, new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ), StandardFunctionReturnTypeResolvers.invariant( typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE ) ) ); - this.dialect = dialect; this.defaultArgumentRenderingMode = defaultArgumentRenderingMode; - this.doubleCastType = doubleCastType; + doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE ); + //This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction + //However, since no Dialects currently override the cast() function, it's OK for now + castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() ); } @Override @@ -101,9 +101,7 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor { final JdbcMapping sourceMapping = realArg.getExpressionType().getJdbcMappings().get( 0 ); // Only cast to float/double if this is an integer if ( sourceMapping.getJdbcType().isInteger() ) { - final String cast = dialect.castPattern( sourceMapping.getCastType(), CastType.DOUBLE ); - new PatternRenderer( cast.replace( "?2", doubleCastType ) ) - .render( sqlAppender, Collections.singletonList( realArg ), translator ); + castFunction.render( sqlAppender, Arrays.asList( realArg, new CastTarget(doubleType) ), translator ); } else { translator.render( realArg, defaultArgumentRenderingMode ); diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java index ce068d04ca..cb54fc4741 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java @@ -67,10 +67,7 @@ public class CastFunction extends AbstractSqmSelfRenderingFunctionDescriptor { private CastType getCastType(JdbcMapping sourceMapping) { final CastType castType = sourceMapping.getCastType(); - if ( castType == CastType.BOOLEAN ) { - return booleanCastType; - } - return castType; + return castType == CastType.BOOLEAN ? booleanCastType : castType; } // @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java index 7f9051cc77..66c9479bde 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java @@ -19,7 +19,6 @@ import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.type.BasicType; import org.hibernate.type.BasicTypeRegistry; -import org.hibernate.type.SqlTypes; import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.spi.TypeConfiguration; @@ -1705,7 +1704,7 @@ public class CommonFunctionFactory { .register(); functionRegistry.register( - CountFunction.FUNCTION_NAME, + "count", new CountFunction( dialect, typeConfiguration, @@ -1720,19 +1719,18 @@ public class CommonFunctionFactory { Dialect dialect, SqlAstNodeRenderingMode inferenceArgumentRenderingMode) { functionRegistry.register( - AvgFunction.FUNCTION_NAME, + "avg", new AvgFunction( dialect, typeConfiguration, - inferenceArgumentRenderingMode, - dialect.getTypeName( SqlTypes.DOUBLE ) + inferenceArgumentRenderingMode ) ); } public void listagg(String emptyWithinReplacement) { functionRegistry.register( - ListaggFunction.FUNCTION_NAME, + "listagg", new ListaggFunction( emptyWithinReplacement, typeConfiguration ) ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CountFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CountFunction.java index 7fd3b6b94c..6f7e5c5438 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/CountFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CountFunction.java @@ -41,7 +41,6 @@ import org.hibernate.type.spi.TypeConfiguration; */ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor { - public static final String FUNCTION_NAME = "count"; private final Dialect dialect; private final SqlAstNodeRenderingMode defaultArgumentRenderingMode; private final String concatOperator; @@ -54,7 +53,7 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor { String concatOperator, String concatArgumentCastType) { super( - FUNCTION_NAME, + "count", FunctionKind.AGGREGATE, StandardArgumentsValidators.exactly( 1 ), StandardFunctionReturnTypeResolvers.invariant( diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/ListaggFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/ListaggFunction.java index b47ddab9f9..de313d545b 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/ListaggFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/ListaggFunction.java @@ -31,13 +31,11 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.STR */ public class ListaggFunction extends AbstractSqmSelfRenderingFunctionDescriptor { - public static final String FUNCTION_NAME = "listagg"; - private final String emptyWithinReplacement; public ListaggFunction(String emptyWithinReplacement, TypeConfiguration typeConfiguration) { super( - FUNCTION_NAME, + "listagg", FunctionKind.ORDERED_SET_AGGREGATE, new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 2 ), STRING, STRING ), StandardFunctionReturnTypeResolvers.invariant( diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/TimestampaddFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/TimestampaddFunction.java index c56bd2814a..d53cb2f44b 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/TimestampaddFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/TimestampaddFunction.java @@ -60,8 +60,10 @@ public class TimestampaddFunction StandardFunctionReturnTypeResolvers.useArgType( 3 ) ); this.dialect = dialect; - this.castFunction = new CastFunction( dialect, Types.BOOLEAN ); this.integerType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.INTEGER ); + //This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction + //However, since no Dialects currently override the cast() function, it's OK for now + this.castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() ); } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java index 79d0b67617..1d6cb501ef 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java @@ -6,7 +6,6 @@ */ package org.hibernate.sql.ast.tree.expression; -import org.hibernate.metamodel.mapping.BasicValuedMapping; import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.sql.ast.SqlAstWalker; diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AvgFunctionTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AvgFunctionTest.java new file mode 100644 index 0000000000..c806b36919 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/AvgFunctionTest.java @@ -0,0 +1,51 @@ +package org.hibernate.orm.test.query.hql; + +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@DomainModel(annotatedClasses = AvgFunctionTest.Value.class) +@SessionFactory +public class AvgFunctionTest { + + @Test + public void test(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + session.persist( new Value(0) ); + session.persist( new Value(1) ); + session.persist( new Value(2) ); + session.persist( new Value(3) ); + assertThat( + session.createQuery("select avg(value) from Value", Double.class) + .getSingleResult(), + is(1.5) + ); + assertThat( + session.createQuery("select avg(integerValue) from Value", Double.class) + .getSingleResult(), + is(1.5) + ); + } + ); + } + + @Entity(name="Value") + public static class Value { + public Value() {} + public Value(int value) { + this.value = value; + this.integerValue = value; + } + @Id + double value; + int integerValue; + } + +}