From 54402da721cb0e2273c073511c2d3cf4045aa432 Mon Sep 17 00:00:00 2001 From: Marco Belladelli Date: Wed, 4 Jan 2023 18:27:36 +0100 Subject: [PATCH] HHH-15985 Custom trunc and round function for PostgreSQL and Cockroach --- .../dialect/CockroachLegacyDialect.java | 8 +- .../dialect/PostgreSQLLegacyDialect.java | 9 ++ .../hibernate/dialect/CockroachDialect.java | 8 +- .../hibernate/dialect/PostgreSQLDialect.java | 11 ++- .../PostgreSQLTruncRoundFunction.java | 99 +++++++++++++++++++ .../PostgreSQLTruncRoundFunctionTest.java | 98 ++++++++++++++++++ 6 files changed, 229 insertions(+), 4 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/function/PostgreSQLTruncRoundFunction.java create mode 100644 hibernate-core/src/test/java/org/hibernate/orm/test/dialect/function/PostgreSQLTruncRoundFunctionTest.java diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java index ec9bc2c70a..979103fb78 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java @@ -42,6 +42,7 @@ import org.hibernate.dialect.SpannerDialect; import org.hibernate.dialect.TimeZoneSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.FormatFunction; +import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction; import org.hibernate.dialect.identity.CockroachDBIdentityColumnSupport; import org.hibernate.dialect.identity.IdentityColumnSupport; import org.hibernate.dialect.pagination.LimitHandler; @@ -383,9 +384,9 @@ public class CockroachLegacyDialect extends Dialect { functionFactory.degrees(); functionFactory.radians(); functionFactory.pi(); - functionFactory.trunc(); //TODO: emulate second arg functionFactory.log(); functionFactory.log10_log(); + functionFactory.round(); functionFactory.bitandorxornot_operator(); functionFactory.bitAndOr(); @@ -407,6 +408,11 @@ public class CockroachLegacyDialect extends Dialect { functionFactory.listagg_stringAgg( "string" ); functionFactory.inverseDistributionOrderedSetAggregates(); functionFactory.hypotheticalOrderedSetAggregates_windowEmulation(); + + functionContributions.getFunctionRegistry().register( + "trunc", new PostgreSQLTruncRoundFunction( "trunc", getVersion().isSameOrAfter( 22, 2 ) ) + ); + functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" ); } @Override diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/PostgreSQLLegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/PostgreSQLLegacyDialect.java index 4edddcadfa..57cbfbdebb 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/PostgreSQLLegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/PostgreSQLLegacyDialect.java @@ -44,6 +44,7 @@ import org.hibernate.dialect.aggregate.AggregateSupport; import org.hibernate.dialect.aggregate.PostgreSQLAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.PostgreSQLMinMaxFunction; +import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction; import org.hibernate.dialect.identity.IdentityColumnSupport; import org.hibernate.dialect.identity.PostgreSQLIdentityColumnSupport; import org.hibernate.dialect.pagination.LimitHandler; @@ -605,6 +606,14 @@ public class PostgreSQLLegacyDialect extends Dialect { functionContributions.getFunctionRegistry().register( "min", new PostgreSQLMinMaxFunction( "min" ) ); functionContributions.getFunctionRegistry().register( "max", new PostgreSQLMinMaxFunction( "max" ) ); } + + functionContributions.getFunctionRegistry().register( + "round", new PostgreSQLTruncRoundFunction( "round", true ) + ); + functionContributions.getFunctionRegistry().register( + "trunc", new PostgreSQLTruncRoundFunction( "trunc", true ) + ); + functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" ); } /** diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java index df128aec12..d8afd9dafe 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java @@ -27,6 +27,7 @@ import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.FormatFunction; +import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction; import org.hibernate.dialect.identity.CockroachDBIdentityColumnSupport; import org.hibernate.dialect.identity.IdentityColumnSupport; import org.hibernate.dialect.pagination.LimitHandler; @@ -389,9 +390,9 @@ public class CockroachDialect extends Dialect { functionFactory.degrees(); functionFactory.radians(); functionFactory.pi(); - functionFactory.trunc(); //TODO: emulate second arg functionFactory.log(); functionFactory.log10_log(); + functionFactory.round(); functionFactory.bitandorxornot_operator(); functionFactory.bitAndOr(); @@ -413,6 +414,11 @@ public class CockroachDialect extends Dialect { functionFactory.listagg_stringAgg( "string" ); functionFactory.inverseDistributionOrderedSetAggregates(); functionFactory.hypotheticalOrderedSetAggregates_windowEmulation(); + + functionContributions.getFunctionRegistry().register( + "trunc", new PostgreSQLTruncRoundFunction( "trunc", getVersion().isSameOrAfter( 22, 2 ) ) + ); + functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" ); } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/PostgreSQLDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/PostgreSQLDialect.java index 1b1d769ef0..552972101c 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/PostgreSQLDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/PostgreSQLDialect.java @@ -28,6 +28,7 @@ import org.hibernate.boot.model.TypeContributions; import org.hibernate.dialect.aggregate.AggregateSupport; import org.hibernate.dialect.aggregate.PostgreSQLAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; +import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction; import org.hibernate.dialect.function.PostgreSQLMinMaxFunction; import org.hibernate.dialect.identity.IdentityColumnSupport; import org.hibernate.dialect.identity.PostgreSQLIdentityColumnSupport; @@ -523,8 +524,6 @@ public class PostgreSQLDialect extends Dialect { CommonFunctionFactory functionFactory = new CommonFunctionFactory(functionContributions); - functionFactory.round_roundFloor(); //Postgres round(x,n) does not accept double - functionFactory.trunc_truncFloor(); functionFactory.cot(); functionFactory.radians(); functionFactory.degrees(); @@ -588,6 +587,14 @@ public class PostgreSQLDialect extends Dialect { functionContributions.getFunctionRegistry().register( "min", new PostgreSQLMinMaxFunction( "min" ) ); functionContributions.getFunctionRegistry().register( "max", new PostgreSQLMinMaxFunction( "max" ) ); } + + functionContributions.getFunctionRegistry().register( + "round", new PostgreSQLTruncRoundFunction( "round", true ) + ); + functionContributions.getFunctionRegistry().register( + "trunc", new PostgreSQLTruncRoundFunction( "trunc", true ) + ); + functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" ); } /** diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/PostgreSQLTruncRoundFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/PostgreSQLTruncRoundFunction.java new file mode 100644 index 0000000000..2b3df38145 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/PostgreSQLTruncRoundFunction.java @@ -0,0 +1,99 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.dialect.function; + +import java.util.List; + +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER; +import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC; + +/** + * PostgreSQL only supports the two-argument {@code trunc} and {@code round} functions + * with the following signatures: + * + *

+ * This custom function falls back to using {@code floor} as a workaround only when necessary, + * e.g. when there are 2 arguments to the function and either: + *

+ * + * @author Marco Belladelli + * @see PostgreSQL documentation + */ +public class PostgreSQLTruncRoundFunction extends AbstractSqmSelfRenderingFunctionDescriptor { + private final boolean supportsTwoArguments; + + public PostgreSQLTruncRoundFunction(String name, boolean supportsTwoArguments) { + super( + name, + new ArgumentTypesValidator( StandardArgumentsValidators.between( 1, 2 ), NUMERIC, INTEGER ), + StandardFunctionReturnTypeResolvers.useArgType( 1 ), + StandardFunctionArgumentTypeResolvers.invariant( NUMERIC, INTEGER ) + ); + this.supportsTwoArguments = supportsTwoArguments; + } + + @Override + public void render(SqlAppender sqlAppender, List arguments, SqlAstTranslator walker) { + final int numberOfArguments = arguments.size(); + final Expression firstArg = (Expression) arguments.get( 0 ); + final JdbcType jdbcType = firstArg.getExpressionType().getJdbcMappings().get( 0 ).getJdbcType(); + if ( numberOfArguments == 1 || supportsTwoArguments && jdbcType.isDecimal() ) { + // use native two-argument function + sqlAppender.appendSql( getName() ); + sqlAppender.appendSql( "(" ); + firstArg.accept( walker ); + if ( numberOfArguments > 1 ) { + sqlAppender.appendSql( ", " ); + arguments.get( 1 ).accept( walker ); + } + sqlAppender.appendSql( ")" ); + } + else { + // workaround using floor + if ( getName().equals( "trunc" ) ) { + sqlAppender.appendSql( "sign(" ); + firstArg.accept( walker ); + sqlAppender.appendSql( ")*floor(abs(" ); + firstArg.accept( walker ); + sqlAppender.appendSql( ")*1e" ); + arguments.get( 1 ).accept( walker ); + } + else { + sqlAppender.appendSql( "floor(" ); + firstArg.accept( walker ); + sqlAppender.appendSql( "*1e" ); + arguments.get( 1 ).accept( walker ); + sqlAppender.appendSql( "+0.5" ); + } + sqlAppender.appendSql( ")/1e" ); + arguments.get( 1 ).accept( walker ); + } + } + + @Override + public String getArgumentListSignature() { + return "(NUMERIC number[, INTEGER places])"; + } +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/dialect/function/PostgreSQLTruncRoundFunctionTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/dialect/function/PostgreSQLTruncRoundFunctionTest.java new file mode 100644 index 0000000000..b1d5432059 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/dialect/function/PostgreSQLTruncRoundFunctionTest.java @@ -0,0 +1,98 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.dialect.function; + +import java.math.BigDecimal; + +import org.hibernate.dialect.CockroachDialect; +import org.hibernate.dialect.PostgreSQLDialect; + +import org.hibernate.testing.jdbc.SQLStatementInspector; +import org.hibernate.testing.orm.domain.StandardDomainModel; +import org.hibernate.testing.orm.domain.animal.Human; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author Marco Belladelli + */ +@DomainModel(standardModels = StandardDomainModel.ANIMAL) +@SessionFactory(statementInspectorClass = SQLStatementInspector.class) +public class PostgreSQLTruncRoundFunctionTest { + @AfterEach + public void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> session.createMutationQuery( "delete from Human" ).executeUpdate() ); + } + + @Test + @RequiresDialect(PostgreSQLDialect.class) + @RequiresDialect(value = CockroachDialect.class, majorVersion = 22, minorVersion = 2, comment = "CockroachDB didn't support the two-argument trunc before version 22.2") + public void testTrunc(SessionFactoryScope scope) { + testFunction( scope, "trunc", "floor" ); + } + + @Test + @RequiresDialect(PostgreSQLDialect.class) + public void testRound(SessionFactoryScope scope) { + testFunction( scope, "round", "floor" ); + } + + @Test + @RequiresDialect(value = CockroachDialect.class, comment = "CockroachDB natively supports round with two args for both decimal and float types") + public void testRoundWithoutWorkaround(SessionFactoryScope scope) { + testFunction( scope, "round", "round" ); + } + + private void testFunction(SessionFactoryScope scope, String function, String workaround) { + final SQLStatementInspector sqlStatementInspector = (SQLStatementInspector) scope.getStatementInspector(); + scope.inTransaction( session -> { + Human human = new Human(); + human.setId( 1L ); + human.setHeightInches( 1.78253d ); + human.setFloatValue( 1.78253f ); + human.setBigDecimalValue( new BigDecimal( "1.78253" ) ); + session.persist( human ); + } ); + + scope.inTransaction( session -> { + sqlStatementInspector.clear(); + assertEquals( + 1.78d, + session.createQuery( + String.format( "select %s(h.heightInches, 2) from Human h", function ), + Double.class + ).getSingleResult() + ); + assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( workaround ) ); + sqlStatementInspector.clear(); + assertEquals( + 1.78f, + session.createQuery( + String.format( "select %s(h.floatValue, 2) from Human h", function ), + Float.class + ).getSingleResult() + ); + assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( workaround ) ); + sqlStatementInspector.clear(); + assertEquals( + 0, + session.createQuery( + String.format( "select %s(h.bigDecimalValue, 2) from Human h", function ), + BigDecimal.class + ).getSingleResult().compareTo( new BigDecimal( "1.78" ) ) + ); + assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( function ) ); + } ); + } +}