HHH-15985 Custom trunc and round function for PostgreSQL and Cockroach

This commit is contained in:
Marco Belladelli 2023-01-04 18:27:36 +01:00 committed by Christian Beikov
parent 74689f26a5
commit 54402da721
6 changed files with 229 additions and 4 deletions

View File

@ -42,6 +42,7 @@ import org.hibernate.dialect.SpannerDialect;
import org.hibernate.dialect.TimeZoneSupport;
import org.hibernate.dialect.function.CommonFunctionFactory;
import org.hibernate.dialect.function.FormatFunction;
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
import org.hibernate.dialect.identity.CockroachDBIdentityColumnSupport;
import org.hibernate.dialect.identity.IdentityColumnSupport;
import org.hibernate.dialect.pagination.LimitHandler;
@ -383,9 +384,9 @@ public class CockroachLegacyDialect extends Dialect {
functionFactory.degrees();
functionFactory.radians();
functionFactory.pi();
functionFactory.trunc(); //TODO: emulate second arg
functionFactory.log();
functionFactory.log10_log();
functionFactory.round();
functionFactory.bitandorxornot_operator();
functionFactory.bitAndOr();
@ -407,6 +408,11 @@ public class CockroachLegacyDialect extends Dialect {
functionFactory.listagg_stringAgg( "string" );
functionFactory.inverseDistributionOrderedSetAggregates();
functionFactory.hypotheticalOrderedSetAggregates_windowEmulation();
functionContributions.getFunctionRegistry().register(
"trunc", new PostgreSQLTruncRoundFunction( "trunc", getVersion().isSameOrAfter( 22, 2 ) )
);
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
}
@Override

View File

@ -44,6 +44,7 @@ import org.hibernate.dialect.aggregate.AggregateSupport;
import org.hibernate.dialect.aggregate.PostgreSQLAggregateSupport;
import org.hibernate.dialect.function.CommonFunctionFactory;
import org.hibernate.dialect.function.PostgreSQLMinMaxFunction;
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
import org.hibernate.dialect.identity.IdentityColumnSupport;
import org.hibernate.dialect.identity.PostgreSQLIdentityColumnSupport;
import org.hibernate.dialect.pagination.LimitHandler;
@ -605,6 +606,14 @@ public class PostgreSQLLegacyDialect extends Dialect {
functionContributions.getFunctionRegistry().register( "min", new PostgreSQLMinMaxFunction( "min" ) );
functionContributions.getFunctionRegistry().register( "max", new PostgreSQLMinMaxFunction( "max" ) );
}
functionContributions.getFunctionRegistry().register(
"round", new PostgreSQLTruncRoundFunction( "round", true )
);
functionContributions.getFunctionRegistry().register(
"trunc", new PostgreSQLTruncRoundFunction( "trunc", true )
);
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
}
/**

View File

@ -27,6 +27,7 @@ import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.dialect.function.CommonFunctionFactory;
import org.hibernate.dialect.function.FormatFunction;
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
import org.hibernate.dialect.identity.CockroachDBIdentityColumnSupport;
import org.hibernate.dialect.identity.IdentityColumnSupport;
import org.hibernate.dialect.pagination.LimitHandler;
@ -389,9 +390,9 @@ public class CockroachDialect extends Dialect {
functionFactory.degrees();
functionFactory.radians();
functionFactory.pi();
functionFactory.trunc(); //TODO: emulate second arg
functionFactory.log();
functionFactory.log10_log();
functionFactory.round();
functionFactory.bitandorxornot_operator();
functionFactory.bitAndOr();
@ -413,6 +414,11 @@ public class CockroachDialect extends Dialect {
functionFactory.listagg_stringAgg( "string" );
functionFactory.inverseDistributionOrderedSetAggregates();
functionFactory.hypotheticalOrderedSetAggregates_windowEmulation();
functionContributions.getFunctionRegistry().register(
"trunc", new PostgreSQLTruncRoundFunction( "trunc", getVersion().isSameOrAfter( 22, 2 ) )
);
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
}
@Override

View File

@ -28,6 +28,7 @@ import org.hibernate.boot.model.TypeContributions;
import org.hibernate.dialect.aggregate.AggregateSupport;
import org.hibernate.dialect.aggregate.PostgreSQLAggregateSupport;
import org.hibernate.dialect.function.CommonFunctionFactory;
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
import org.hibernate.dialect.function.PostgreSQLMinMaxFunction;
import org.hibernate.dialect.identity.IdentityColumnSupport;
import org.hibernate.dialect.identity.PostgreSQLIdentityColumnSupport;
@ -523,8 +524,6 @@ public class PostgreSQLDialect extends Dialect {
CommonFunctionFactory functionFactory = new CommonFunctionFactory(functionContributions);
functionFactory.round_roundFloor(); //Postgres round(x,n) does not accept double
functionFactory.trunc_truncFloor();
functionFactory.cot();
functionFactory.radians();
functionFactory.degrees();
@ -588,6 +587,14 @@ public class PostgreSQLDialect extends Dialect {
functionContributions.getFunctionRegistry().register( "min", new PostgreSQLMinMaxFunction( "min" ) );
functionContributions.getFunctionRegistry().register( "max", new PostgreSQLMinMaxFunction( "max" ) );
}
functionContributions.getFunctionRegistry().register(
"round", new PostgreSQLTruncRoundFunction( "round", true )
);
functionContributions.getFunctionRegistry().register(
"trunc", new PostgreSQLTruncRoundFunction( "trunc", true )
);
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
}
/**

View File

@ -0,0 +1,99 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function;
import java.util.List;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER;
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
/**
* PostgreSQL only supports the two-argument {@code trunc} and {@code round} functions
* with the following signatures:
* <ul>
* <li>{@code trunc(numeric, integer)}</li>
* <li>{@code round(numeric, integer)}</li>
* </ul>
* <p>
* This custom function falls back to using {@code floor} as a workaround only when necessary,
* e.g. when there are 2 arguments to the function and either:
* <ul>
* <li>The first argument is not of type {@code numeric}</li>
* or
* <li>The dialect doesn't support the two-argument {@code trunc} function</li>
* </ul>
*
* @author Marco Belladelli
* @see <a href="https://www.postgresql.org/docs/current/functions-math.html">PostgreSQL documentation</a>
*/
public class PostgreSQLTruncRoundFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
private final boolean supportsTwoArguments;
public PostgreSQLTruncRoundFunction(String name, boolean supportsTwoArguments) {
super(
name,
new ArgumentTypesValidator( StandardArgumentsValidators.between( 1, 2 ), NUMERIC, INTEGER ),
StandardFunctionReturnTypeResolvers.useArgType( 1 ),
StandardFunctionArgumentTypeResolvers.invariant( NUMERIC, INTEGER )
);
this.supportsTwoArguments = supportsTwoArguments;
}
@Override
public void render(SqlAppender sqlAppender, List<? extends SqlAstNode> arguments, SqlAstTranslator<?> walker) {
final int numberOfArguments = arguments.size();
final Expression firstArg = (Expression) arguments.get( 0 );
final JdbcType jdbcType = firstArg.getExpressionType().getJdbcMappings().get( 0 ).getJdbcType();
if ( numberOfArguments == 1 || supportsTwoArguments && jdbcType.isDecimal() ) {
// use native two-argument function
sqlAppender.appendSql( getName() );
sqlAppender.appendSql( "(" );
firstArg.accept( walker );
if ( numberOfArguments > 1 ) {
sqlAppender.appendSql( ", " );
arguments.get( 1 ).accept( walker );
}
sqlAppender.appendSql( ")" );
}
else {
// workaround using floor
if ( getName().equals( "trunc" ) ) {
sqlAppender.appendSql( "sign(" );
firstArg.accept( walker );
sqlAppender.appendSql( ")*floor(abs(" );
firstArg.accept( walker );
sqlAppender.appendSql( ")*1e" );
arguments.get( 1 ).accept( walker );
}
else {
sqlAppender.appendSql( "floor(" );
firstArg.accept( walker );
sqlAppender.appendSql( "*1e" );
arguments.get( 1 ).accept( walker );
sqlAppender.appendSql( "+0.5" );
}
sqlAppender.appendSql( ")/1e" );
arguments.get( 1 ).accept( walker );
}
}
@Override
public String getArgumentListSignature() {
return "(NUMERIC number[, INTEGER places])";
}
}

View File

@ -0,0 +1,98 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.orm.test.dialect.function;
import java.math.BigDecimal;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.testing.jdbc.SQLStatementInspector;
import org.hibernate.testing.orm.domain.StandardDomainModel;
import org.hibernate.testing.orm.domain.animal.Human;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Marco Belladelli
*/
@DomainModel(standardModels = StandardDomainModel.ANIMAL)
@SessionFactory(statementInspectorClass = SQLStatementInspector.class)
public class PostgreSQLTruncRoundFunctionTest {
@AfterEach
public void tearDown(SessionFactoryScope scope) {
scope.inTransaction( session -> session.createMutationQuery( "delete from Human" ).executeUpdate() );
}
@Test
@RequiresDialect(PostgreSQLDialect.class)
@RequiresDialect(value = CockroachDialect.class, majorVersion = 22, minorVersion = 2, comment = "CockroachDB didn't support the two-argument trunc before version 22.2")
public void testTrunc(SessionFactoryScope scope) {
testFunction( scope, "trunc", "floor" );
}
@Test
@RequiresDialect(PostgreSQLDialect.class)
public void testRound(SessionFactoryScope scope) {
testFunction( scope, "round", "floor" );
}
@Test
@RequiresDialect(value = CockroachDialect.class, comment = "CockroachDB natively supports round with two args for both decimal and float types")
public void testRoundWithoutWorkaround(SessionFactoryScope scope) {
testFunction( scope, "round", "round" );
}
private void testFunction(SessionFactoryScope scope, String function, String workaround) {
final SQLStatementInspector sqlStatementInspector = (SQLStatementInspector) scope.getStatementInspector();
scope.inTransaction( session -> {
Human human = new Human();
human.setId( 1L );
human.setHeightInches( 1.78253d );
human.setFloatValue( 1.78253f );
human.setBigDecimalValue( new BigDecimal( "1.78253" ) );
session.persist( human );
} );
scope.inTransaction( session -> {
sqlStatementInspector.clear();
assertEquals(
1.78d,
session.createQuery(
String.format( "select %s(h.heightInches, 2) from Human h", function ),
Double.class
).getSingleResult()
);
assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( workaround ) );
sqlStatementInspector.clear();
assertEquals(
1.78f,
session.createQuery(
String.format( "select %s(h.floatValue, 2) from Human h", function ),
Float.class
).getSingleResult()
);
assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( workaround ) );
sqlStatementInspector.clear();
assertEquals(
0,
session.createQuery(
String.format( "select %s(h.bigDecimalValue, 2) from Human h", function ),
BigDecimal.class
).getSingleResult().compareTo( new BigDecimal( "1.78" ) )
);
assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( function ) );
} );
}
}