HHH-15985 Custom trunc and round function for PostgreSQL and Cockroach
This commit is contained in:
parent
74689f26a5
commit
54402da721
|
@ -42,6 +42,7 @@ import org.hibernate.dialect.SpannerDialect;
|
|||
import org.hibernate.dialect.TimeZoneSupport;
|
||||
import org.hibernate.dialect.function.CommonFunctionFactory;
|
||||
import org.hibernate.dialect.function.FormatFunction;
|
||||
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
|
||||
import org.hibernate.dialect.identity.CockroachDBIdentityColumnSupport;
|
||||
import org.hibernate.dialect.identity.IdentityColumnSupport;
|
||||
import org.hibernate.dialect.pagination.LimitHandler;
|
||||
|
@ -383,9 +384,9 @@ public class CockroachLegacyDialect extends Dialect {
|
|||
functionFactory.degrees();
|
||||
functionFactory.radians();
|
||||
functionFactory.pi();
|
||||
functionFactory.trunc(); //TODO: emulate second arg
|
||||
functionFactory.log();
|
||||
functionFactory.log10_log();
|
||||
functionFactory.round();
|
||||
|
||||
functionFactory.bitandorxornot_operator();
|
||||
functionFactory.bitAndOr();
|
||||
|
@ -407,6 +408,11 @@ public class CockroachLegacyDialect extends Dialect {
|
|||
functionFactory.listagg_stringAgg( "string" );
|
||||
functionFactory.inverseDistributionOrderedSetAggregates();
|
||||
functionFactory.hypotheticalOrderedSetAggregates_windowEmulation();
|
||||
|
||||
functionContributions.getFunctionRegistry().register(
|
||||
"trunc", new PostgreSQLTruncRoundFunction( "trunc", getVersion().isSameOrAfter( 22, 2 ) )
|
||||
);
|
||||
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -44,6 +44,7 @@ import org.hibernate.dialect.aggregate.AggregateSupport;
|
|||
import org.hibernate.dialect.aggregate.PostgreSQLAggregateSupport;
|
||||
import org.hibernate.dialect.function.CommonFunctionFactory;
|
||||
import org.hibernate.dialect.function.PostgreSQLMinMaxFunction;
|
||||
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
|
||||
import org.hibernate.dialect.identity.IdentityColumnSupport;
|
||||
import org.hibernate.dialect.identity.PostgreSQLIdentityColumnSupport;
|
||||
import org.hibernate.dialect.pagination.LimitHandler;
|
||||
|
@ -605,6 +606,14 @@ public class PostgreSQLLegacyDialect extends Dialect {
|
|||
functionContributions.getFunctionRegistry().register( "min", new PostgreSQLMinMaxFunction( "min" ) );
|
||||
functionContributions.getFunctionRegistry().register( "max", new PostgreSQLMinMaxFunction( "max" ) );
|
||||
}
|
||||
|
||||
functionContributions.getFunctionRegistry().register(
|
||||
"round", new PostgreSQLTruncRoundFunction( "round", true )
|
||||
);
|
||||
functionContributions.getFunctionRegistry().register(
|
||||
"trunc", new PostgreSQLTruncRoundFunction( "trunc", true )
|
||||
);
|
||||
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.hibernate.boot.model.FunctionContributions;
|
|||
import org.hibernate.boot.model.TypeContributions;
|
||||
import org.hibernate.dialect.function.CommonFunctionFactory;
|
||||
import org.hibernate.dialect.function.FormatFunction;
|
||||
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
|
||||
import org.hibernate.dialect.identity.CockroachDBIdentityColumnSupport;
|
||||
import org.hibernate.dialect.identity.IdentityColumnSupport;
|
||||
import org.hibernate.dialect.pagination.LimitHandler;
|
||||
|
@ -389,9 +390,9 @@ public class CockroachDialect extends Dialect {
|
|||
functionFactory.degrees();
|
||||
functionFactory.radians();
|
||||
functionFactory.pi();
|
||||
functionFactory.trunc(); //TODO: emulate second arg
|
||||
functionFactory.log();
|
||||
functionFactory.log10_log();
|
||||
functionFactory.round();
|
||||
|
||||
functionFactory.bitandorxornot_operator();
|
||||
functionFactory.bitAndOr();
|
||||
|
@ -413,6 +414,11 @@ public class CockroachDialect extends Dialect {
|
|||
functionFactory.listagg_stringAgg( "string" );
|
||||
functionFactory.inverseDistributionOrderedSetAggregates();
|
||||
functionFactory.hypotheticalOrderedSetAggregates_windowEmulation();
|
||||
|
||||
functionContributions.getFunctionRegistry().register(
|
||||
"trunc", new PostgreSQLTruncRoundFunction( "trunc", getVersion().isSameOrAfter( 22, 2 ) )
|
||||
);
|
||||
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.hibernate.boot.model.TypeContributions;
|
|||
import org.hibernate.dialect.aggregate.AggregateSupport;
|
||||
import org.hibernate.dialect.aggregate.PostgreSQLAggregateSupport;
|
||||
import org.hibernate.dialect.function.CommonFunctionFactory;
|
||||
import org.hibernate.dialect.function.PostgreSQLTruncRoundFunction;
|
||||
import org.hibernate.dialect.function.PostgreSQLMinMaxFunction;
|
||||
import org.hibernate.dialect.identity.IdentityColumnSupport;
|
||||
import org.hibernate.dialect.identity.PostgreSQLIdentityColumnSupport;
|
||||
|
@ -523,8 +524,6 @@ public class PostgreSQLDialect extends Dialect {
|
|||
|
||||
CommonFunctionFactory functionFactory = new CommonFunctionFactory(functionContributions);
|
||||
|
||||
functionFactory.round_roundFloor(); //Postgres round(x,n) does not accept double
|
||||
functionFactory.trunc_truncFloor();
|
||||
functionFactory.cot();
|
||||
functionFactory.radians();
|
||||
functionFactory.degrees();
|
||||
|
@ -588,6 +587,14 @@ public class PostgreSQLDialect extends Dialect {
|
|||
functionContributions.getFunctionRegistry().register( "min", new PostgreSQLMinMaxFunction( "min" ) );
|
||||
functionContributions.getFunctionRegistry().register( "max", new PostgreSQLMinMaxFunction( "max" ) );
|
||||
}
|
||||
|
||||
functionContributions.getFunctionRegistry().register(
|
||||
"round", new PostgreSQLTruncRoundFunction( "round", true )
|
||||
);
|
||||
functionContributions.getFunctionRegistry().register(
|
||||
"trunc", new PostgreSQLTruncRoundFunction( "trunc", true )
|
||||
);
|
||||
functionContributions.getFunctionRegistry().registerAlternateKey( "truncate", "trunc" );
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
|
||||
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
|
||||
*/
|
||||
package org.hibernate.dialect.function;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
|
||||
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
|
||||
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
|
||||
import org.hibernate.sql.ast.SqlAstTranslator;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.sql.ast.tree.SqlAstNode;
|
||||
import org.hibernate.sql.ast.tree.expression.Expression;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
|
||||
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER;
|
||||
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
|
||||
|
||||
/**
|
||||
* PostgreSQL only supports the two-argument {@code trunc} and {@code round} functions
|
||||
* with the following signatures:
|
||||
* <ul>
|
||||
* <li>{@code trunc(numeric, integer)}</li>
|
||||
* <li>{@code round(numeric, integer)}</li>
|
||||
* </ul>
|
||||
* <p>
|
||||
* This custom function falls back to using {@code floor} as a workaround only when necessary,
|
||||
* e.g. when there are 2 arguments to the function and either:
|
||||
* <ul>
|
||||
* <li>The first argument is not of type {@code numeric}</li>
|
||||
* or
|
||||
* <li>The dialect doesn't support the two-argument {@code trunc} function</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Marco Belladelli
|
||||
* @see <a href="https://www.postgresql.org/docs/current/functions-math.html">PostgreSQL documentation</a>
|
||||
*/
|
||||
public class PostgreSQLTruncRoundFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
||||
private final boolean supportsTwoArguments;
|
||||
|
||||
public PostgreSQLTruncRoundFunction(String name, boolean supportsTwoArguments) {
|
||||
super(
|
||||
name,
|
||||
new ArgumentTypesValidator( StandardArgumentsValidators.between( 1, 2 ), NUMERIC, INTEGER ),
|
||||
StandardFunctionReturnTypeResolvers.useArgType( 1 ),
|
||||
StandardFunctionArgumentTypeResolvers.invariant( NUMERIC, INTEGER )
|
||||
);
|
||||
this.supportsTwoArguments = supportsTwoArguments;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void render(SqlAppender sqlAppender, List<? extends SqlAstNode> arguments, SqlAstTranslator<?> walker) {
|
||||
final int numberOfArguments = arguments.size();
|
||||
final Expression firstArg = (Expression) arguments.get( 0 );
|
||||
final JdbcType jdbcType = firstArg.getExpressionType().getJdbcMappings().get( 0 ).getJdbcType();
|
||||
if ( numberOfArguments == 1 || supportsTwoArguments && jdbcType.isDecimal() ) {
|
||||
// use native two-argument function
|
||||
sqlAppender.appendSql( getName() );
|
||||
sqlAppender.appendSql( "(" );
|
||||
firstArg.accept( walker );
|
||||
if ( numberOfArguments > 1 ) {
|
||||
sqlAppender.appendSql( ", " );
|
||||
arguments.get( 1 ).accept( walker );
|
||||
}
|
||||
sqlAppender.appendSql( ")" );
|
||||
}
|
||||
else {
|
||||
// workaround using floor
|
||||
if ( getName().equals( "trunc" ) ) {
|
||||
sqlAppender.appendSql( "sign(" );
|
||||
firstArg.accept( walker );
|
||||
sqlAppender.appendSql( ")*floor(abs(" );
|
||||
firstArg.accept( walker );
|
||||
sqlAppender.appendSql( ")*1e" );
|
||||
arguments.get( 1 ).accept( walker );
|
||||
}
|
||||
else {
|
||||
sqlAppender.appendSql( "floor(" );
|
||||
firstArg.accept( walker );
|
||||
sqlAppender.appendSql( "*1e" );
|
||||
arguments.get( 1 ).accept( walker );
|
||||
sqlAppender.appendSql( "+0.5" );
|
||||
}
|
||||
sqlAppender.appendSql( ")/1e" );
|
||||
arguments.get( 1 ).accept( walker );
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getArgumentListSignature() {
|
||||
return "(NUMERIC number[, INTEGER places])";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
|
||||
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
|
||||
*/
|
||||
package org.hibernate.orm.test.dialect.function;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
import org.hibernate.dialect.CockroachDialect;
|
||||
import org.hibernate.dialect.PostgreSQLDialect;
|
||||
|
||||
import org.hibernate.testing.jdbc.SQLStatementInspector;
|
||||
import org.hibernate.testing.orm.domain.StandardDomainModel;
|
||||
import org.hibernate.testing.orm.domain.animal.Human;
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.RequiresDialect;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* @author Marco Belladelli
|
||||
*/
|
||||
@DomainModel(standardModels = StandardDomainModel.ANIMAL)
|
||||
@SessionFactory(statementInspectorClass = SQLStatementInspector.class)
|
||||
public class PostgreSQLTruncRoundFunctionTest {
|
||||
@AfterEach
|
||||
public void tearDown(SessionFactoryScope scope) {
|
||||
scope.inTransaction( session -> session.createMutationQuery( "delete from Human" ).executeUpdate() );
|
||||
}
|
||||
|
||||
@Test
|
||||
@RequiresDialect(PostgreSQLDialect.class)
|
||||
@RequiresDialect(value = CockroachDialect.class, majorVersion = 22, minorVersion = 2, comment = "CockroachDB didn't support the two-argument trunc before version 22.2")
|
||||
public void testTrunc(SessionFactoryScope scope) {
|
||||
testFunction( scope, "trunc", "floor" );
|
||||
}
|
||||
|
||||
@Test
|
||||
@RequiresDialect(PostgreSQLDialect.class)
|
||||
public void testRound(SessionFactoryScope scope) {
|
||||
testFunction( scope, "round", "floor" );
|
||||
}
|
||||
|
||||
@Test
|
||||
@RequiresDialect(value = CockroachDialect.class, comment = "CockroachDB natively supports round with two args for both decimal and float types")
|
||||
public void testRoundWithoutWorkaround(SessionFactoryScope scope) {
|
||||
testFunction( scope, "round", "round" );
|
||||
}
|
||||
|
||||
private void testFunction(SessionFactoryScope scope, String function, String workaround) {
|
||||
final SQLStatementInspector sqlStatementInspector = (SQLStatementInspector) scope.getStatementInspector();
|
||||
scope.inTransaction( session -> {
|
||||
Human human = new Human();
|
||||
human.setId( 1L );
|
||||
human.setHeightInches( 1.78253d );
|
||||
human.setFloatValue( 1.78253f );
|
||||
human.setBigDecimalValue( new BigDecimal( "1.78253" ) );
|
||||
session.persist( human );
|
||||
} );
|
||||
|
||||
scope.inTransaction( session -> {
|
||||
sqlStatementInspector.clear();
|
||||
assertEquals(
|
||||
1.78d,
|
||||
session.createQuery(
|
||||
String.format( "select %s(h.heightInches, 2) from Human h", function ),
|
||||
Double.class
|
||||
).getSingleResult()
|
||||
);
|
||||
assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( workaround ) );
|
||||
sqlStatementInspector.clear();
|
||||
assertEquals(
|
||||
1.78f,
|
||||
session.createQuery(
|
||||
String.format( "select %s(h.floatValue, 2) from Human h", function ),
|
||||
Float.class
|
||||
).getSingleResult()
|
||||
);
|
||||
assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( workaround ) );
|
||||
sqlStatementInspector.clear();
|
||||
assertEquals(
|
||||
0,
|
||||
session.createQuery(
|
||||
String.format( "select %s(h.bigDecimalValue, 2) from Human h", function ),
|
||||
BigDecimal.class
|
||||
).getSingleResult().compareTo( new BigDecimal( "1.78" ) )
|
||||
);
|
||||
assertTrue( sqlStatementInspector.getSqlQueries().get( 0 ).contains( function ) );
|
||||
} );
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue