promote trunc() / truncate() to the list of standard HQL functions

also support the single-argument form of round() for consistency
This commit is contained in:
Gavin 2022-12-21 02:52:04 +01:00 committed by Gavin King
parent 46a4c5e8f0
commit 023e73cb46
25 changed files with 256 additions and 57 deletions

View File

@ -994,7 +994,8 @@ Of course, we also have a number of functions for working with numeric values.
| `exp()` | Exponential function | `exp(x)` | ✓
| `power()` | Exponentiation | `power(x,y)` | ✓
| `ln()` | Natural logarithm | `ln(x)` | ✓
| `round()` | Numeric rounding | `round(number, places)` | ✓
| `round()` | Numeric rounding | `round(number)`, `round(number, places)` | ✓
| `trunc()` or `truncate()` | Numeric truncation | `truncate(number)`, `truncate(number, places)` | ✗
| `floor()` | Floor function | `floor(x)` | ✓
| `ceiling()` | Ceiling function | `ceiling(x)` | ✓

View File

@ -232,7 +232,7 @@ public class CUBRIDDialect extends Dialect {
functionFactory.bitLength();
functionFactory.md5();
functionFactory.trunc();
functionFactory.truncate();
// functionFactory.truncate();
functionFactory.toCharNumberDateTimestamp();
functionFactory.substr();
//also natively supports ANSI-style substring()

View File

@ -297,7 +297,7 @@ public class DB2LegacyDialect extends Dialect {
functionFactory.ascii();
functionFactory.char_chr();
functionFactory.trunc();
functionFactory.truncate();
// functionFactory.truncate();
functionFactory.insert();
functionFactory.characterLength_length( SqlAstNodeRenderingMode.DEFAULT );
functionFactory.stddev();

View File

@ -317,6 +317,7 @@ public class DerbyLegacyDialect extends Dialect {
functionFactory.characterLength_length( SqlAstNodeRenderingMode.NO_PLAIN_PARAMETER );
functionFactory.power_expLn();
functionFactory.round_floor();
functionFactory.trunc_floor();
functionFactory.octetLength_pattern( "length(?1)" );
functionFactory.bitLength_pattern( "length(?1)*8" );

View File

@ -296,7 +296,6 @@ public class H2LegacyDialect extends Dialect {
functionFactory.log10();
functionFactory.mod_operator();
functionFactory.rand();
functionFactory.truncate();
functionFactory.soundex();
functionFactory.translate();
functionFactory.bitand();
@ -311,6 +310,9 @@ public class H2LegacyDialect extends Dialect {
if ( useLocalTime ) {
functionFactory.localtimeLocaltimestamp();
}
functionFactory.trunc();
// functionFactory.truncate();
functionFactory.dateTrunc();
functionFactory.bitLength();
functionFactory.octetLength();
functionFactory.ascii();

View File

@ -201,7 +201,7 @@ public class HSQLLegacyDialect extends Dialect {
functionFactory.log10();
functionFactory.rand();
functionFactory.trunc();
functionFactory.truncate();
// functionFactory.truncate();
functionFactory.pi();
functionFactory.soundex();
functionFactory.reverse();

View File

@ -252,7 +252,7 @@ public class IngresDialect extends Dialect {
functionFactory.repeat();
functionFactory.trim2();
functionFactory.trunc();
functionFactory.truncate();
// functionFactory.truncate();
functionFactory.initcap();
functionFactory.yearMonthDay();
functionFactory.hourMinuteSecond();

View File

@ -555,7 +555,7 @@ public class MySQLLegacyDialect extends Dialect {
//also natively supports ANSI-style substring()
functionFactory.position();
functionFactory.nowCurdateCurtime();
functionFactory.truncate();
functionFactory.trunc_truncate();
functionFactory.insert();
functionFactory.bitandorxornot_operator();
functionFactory.bitAndOr();

View File

@ -534,11 +534,11 @@ public class PostgreSQLLegacyDialect extends Dialect {
CommonFunctionFactory functionFactory = new CommonFunctionFactory(queryEngine);
functionFactory.round_floor(); //Postgres round(x,n) does not accept double
functionFactory.round_roundFloor(); //Postgres round(x,n) does not accept double
functionFactory.trunc_truncFloor();
functionFactory.cot();
functionFactory.radians();
functionFactory.degrees();
functionFactory.trunc();
functionFactory.log();
functionFactory.mod_operator();
if ( getVersion().isSameOrAfter( 12 ) ) {

View File

@ -213,7 +213,7 @@ public class RDMSOS2200Dialect extends Dialect {
functionFactory.pi();
functionFactory.rand();
functionFactory.trunc();
functionFactory.truncate();
// functionFactory.truncate();
functionFactory.soundex();
functionFactory.trim2();
functionFactory.space();

View File

@ -303,7 +303,8 @@ public class SQLServerLegacyDialect extends AbstractTransactSQLDialect {
functionFactory.log_log();
functionFactory.truncate_round();
functionFactory.trunc_round();
functionFactory.round_round();
functionFactory.everyAny_minMaxIif();
functionFactory.octetLength_pattern( "datalength(?1)" );
functionFactory.bitLength_pattern( "datalength(?1)*8" );

View File

@ -216,6 +216,8 @@ public class SybaseLegacyDialect extends AbstractTransactSQLDialect {
functionFactory.varPopSamp_varp();
functionFactory.stddevPopSamp();
functionFactory.varPopSamp();
functionFactory.trunc_floorPower();
functionFactory.round_round();
// For SQL-Server we need to cast certain arguments to varchar(16384) to be able to concat them
queryEngine.getSqmFunctionRegistry().register(

View File

@ -282,7 +282,7 @@ public class DB2Dialect extends Dialect {
functionFactory.ascii();
functionFactory.char_chr();
functionFactory.trunc();
functionFactory.truncate();
// functionFactory.truncate();
functionFactory.insert();
functionFactory.characterLength_length( SqlAstNodeRenderingMode.DEFAULT );
functionFactory.stddev();

View File

@ -302,6 +302,7 @@ public class DerbyDialect extends Dialect {
functionFactory.characterLength_length( SqlAstNodeRenderingMode.NO_PLAIN_PARAMETER );
functionFactory.power_expLn();
functionFactory.round_floor();
functionFactory.trunc_floor();
functionFactory.octetLength_pattern( "length(?1)" );
functionFactory.bitLength_pattern( "length(?1)*8" );

View File

@ -759,7 +759,8 @@ public abstract class Dialect implements ConversionContext {
* <li> <code>acos(arg)</code>
* <li> <code>atan(arg)</code>
* <li> <code>atan2(arg0, arg1)</code>
* <li> <code>round(arg0, arg1)</code>
* <li> <code>round(arg0[, arg1])</code>
* <li> <code>truncate(arg0[, arg1])</code>
* <li> <code>sinh(arg)</code>
* <li> <code>tanh(arg)</code>
* <li> <code>cosh(arg)</code>
@ -833,6 +834,8 @@ public abstract class Dialect implements ConversionContext {
//to implement such a silly thing, it would be dog slow.
functionFactory.math();
functionFactory.round();
//trig functions supported on almost every database

View File

@ -261,7 +261,6 @@ public class H2Dialect extends Dialect {
functionFactory.log10();
functionFactory.mod_operator();
functionFactory.rand();
functionFactory.truncate();
functionFactory.soundex();
functionFactory.translate();
functionFactory.bitand();
@ -277,6 +276,7 @@ public class H2Dialect extends Dialect {
functionFactory.localtimeLocaltimestamp();
}
functionFactory.trunc();
// functionFactory.truncate();
functionFactory.dateTrunc();
functionFactory.bitLength();
functionFactory.octetLength();

View File

@ -157,7 +157,7 @@ public class HSQLDialect extends Dialect {
functionFactory.log10();
functionFactory.rand();
functionFactory.trunc();
functionFactory.truncate();
// functionFactory.truncate();
functionFactory.pi();
functionFactory.soundex();
functionFactory.reverse();

View File

@ -544,7 +544,7 @@ public class MySQLDialect extends Dialect {
//also natively supports ANSI-style substring()
functionFactory.position();
functionFactory.nowCurdateCurtime();
functionFactory.truncate();
functionFactory.trunc_truncate();
functionFactory.insert();
functionFactory.bitandorxornot_operator();
functionFactory.bitAndOr();

View File

@ -517,11 +517,11 @@ public class PostgreSQLDialect extends Dialect {
CommonFunctionFactory functionFactory = new CommonFunctionFactory(queryEngine);
functionFactory.round_floor(); //Postgres round(x,n) does not accept double
functionFactory.round_roundFloor(); //Postgres round(x,n) does not accept double
functionFactory.trunc_truncFloor();
functionFactory.cot();
functionFactory.radians();
functionFactory.degrees();
functionFactory.trunc();
functionFactory.log();
functionFactory.mod_operator();
if ( getVersion().isSameOrAfter( 12 ) ) {

View File

@ -307,7 +307,8 @@ public class SQLServerDialect extends AbstractTransactSQLDialect {
functionFactory.log_log();
functionFactory.truncate_round();
functionFactory.trunc_round();
functionFactory.round_round();
functionFactory.everyAny_minMaxIif();
functionFactory.octetLength_pattern( "datalength(?1)" );
functionFactory.bitLength_pattern( "datalength(?1)*8" );

View File

@ -220,6 +220,8 @@ public class SybaseDialect extends AbstractTransactSQLDialect {
functionFactory.varPopSamp_varp();
functionFactory.stddevPopSamp();
functionFactory.varPopSamp();
functionFactory.trunc_floorPower();
functionFactory.round_round();
// For SQL-Server we need to cast certain arguments to varchar(16384) to be able to concat them
queryEngine.getSqmFunctionRegistry().register(

View File

@ -263,11 +263,82 @@ public class CommonFunctionFactory {
public void trunc() {
functionRegistry.namedDescriptorBuilder( "trunc" )
.setReturnTypeResolver( useArgType( 1 ) )
.setArgumentCountBetween( 1, 2 )
.setParameterTypes(NUMERIC, INTEGER)
.setInvariantType(doubleType)
.setArgumentListSignature( "(NUMERIC number[, INTEGER places])" )
.register();
functionRegistry.registerAlternateKey( "truncate", "trunc" );
}
/**
* MySQL
*/
public void trunc_truncate() {
functionRegistry.registerUnaryBinaryPattern(
"trunc",
"truncate(?1,0)",
"truncate(?1,?2)",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
functionRegistry.registerAlternateKey( "truncate", "trunc" );
}
/**
* SQL Server
*/
public void trunc_round() {
functionRegistry.registerUnaryBinaryPattern(
"trunc",
"round(?1,0,1)",
"round(?1,?2,1)",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
functionRegistry.registerAlternateKey( "truncate", "trunc" );
}
/**
* Sybase
*/
public void trunc_floorPower() {
functionRegistry.registerUnaryBinaryPattern(
"trunc",
"sign(?1)*floor(abs(?1))",
"sign(?1)*floor(abs(?1)*power(10,?2))/power(10,?2)",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
functionRegistry.registerAlternateKey( "truncate", "trunc" );
}
/**
* PostgreSQL (only works if the second arg is constant, as it almost always is)
*/
public void trunc_truncFloor() {
functionRegistry.registerUnaryBinaryPattern(
"trunc",
"trunc(?1)",
"sign(?1)*floor(abs(?1)*1e?2)/1e?2",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
functionRegistry.registerAlternateKey( "truncate", "trunc" );
}
/**
* Derby (only works if the second arg is constant, as it almost always is)
*/
public void trunc_floor() {
functionRegistry.registerUnaryBinaryPattern(
"trunc",
"sign(?1)*floor(abs(?1))",
"sign(?1)*floor(abs(?1)*1e?2)/1e?2",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
functionRegistry.registerAlternateKey( "truncate", "trunc" );
}
public void truncate() {
@ -416,7 +487,6 @@ public class CommonFunctionFactory {
}
public void regrLinearRegressionAggregates() {
Arrays.asList(
"regr_avgx", "regr_avgy", "regr_count", "regr_intercept", "regr_r2",
"regr_slope", "regr_sxx", "regr_sxy", "regr_syy"
@ -2087,13 +2157,6 @@ public class CommonFunctionFactory {
}
public void math() {
functionRegistry.namedDescriptorBuilder( "round" )
// To avoid truncating to a specific data type, we default to using the argument type
.setReturnTypeResolver( useArgType( 1 ) )
.setExactArgumentCount( 2 )
.setParameterTypes(NUMERIC, INTEGER)
.register();
functionRegistry.namedDescriptorBuilder( "floor" )
// To avoid truncating to a specific data type, we default to using the argument type
.setReturnTypeResolver( useArgType( 1 ) )
@ -2170,14 +2233,56 @@ public class CommonFunctionFactory {
.register();
}
public void round_floor() {
functionRegistry.patternDescriptorBuilder( "round", "floor(?1*1e?2+0.5)/1e?2")
.setReturnTypeResolver( useArgType(1) )
.setExactArgumentCount( 2 )
public void round() {
functionRegistry.namedDescriptorBuilder( "round" )
// To avoid truncating to a specific data type, we default to using the argument type
.setReturnTypeResolver( useArgType( 1 ) )
.setArgumentCountBetween( 1, 2 )
.setParameterTypes(NUMERIC, INTEGER)
.setArgumentListSignature( "(NUMERIC number[, INTEGER places])" )
.register();
}
/**
* SQL Server
*/
public void round_round() {
functionRegistry.registerUnaryBinaryPattern(
"round",
"round(?1,0)",
"round(?1,?2)",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
functionRegistry.registerAlternateKey( "truncate", "trunc" );
}
/**
* Derby (only works if the second arg is constant, as it almost always is)
*/
public void round_floor() {
functionRegistry.registerUnaryBinaryPattern(
"round",
"floor(?1+0.5)",
"floor(?1*1e?2+0.5)/1e?2",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
}
/**
* PostgreSQL (only works if the second arg is constant, as it almost always is)
*/
public void round_roundFloor() {
functionRegistry.registerUnaryBinaryPattern(
"round",
"round(?1)",
"floor(?1*1e?2+0.5)/1e?2",
NUMERIC, INTEGER,
typeConfiguration
).setArgumentListSignature( "(NUMERIC number[, INTEGER places])" );
}
public void square() {
functionRegistry.namedDescriptorBuilder( "square" )
.setExactArgumentCount( 1 )

View File

@ -11,14 +11,16 @@ import org.hibernate.query.spi.QueryEngine;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.FunctionParameterType;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.type.BasicType;
import org.hibernate.type.spi.TypeConfiguration;
import java.util.List;
import static org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers.invariant;
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.invariant;
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.useArgType;
/**
* Support for overloaded functions defined in terms of a
* list of patterns, one for each possible function arity.
@ -72,8 +74,39 @@ public class MultipatternSqmFunctionDescriptor extends AbstractSqmFunctionDescri
),
parameterTypes
),
StandardFunctionReturnTypeResolvers.invariant( type ),
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, parameterTypes )
invariant( type ),
invariant( typeConfiguration, parameterTypes )
);
this.functions = functions;
}
/**
* Construct an instance with the given function templates
* where the position of each function template in the
* given array corresponds to the arity of the function
* template. The array must be padded with leading nulls
* where there is no overloaded form corresponding to
* lower arities.
*
* @param name
* @param functions the function templates to delegate to,
*/
public MultipatternSqmFunctionDescriptor(
String name,
SqmFunctionDescriptor[] functions,
TypeConfiguration typeConfiguration,
FunctionParameterType... parameterTypes) {
super(
name,
new ArgumentTypesValidator(
StandardArgumentsValidators.between(
first(functions),
last(functions)
),
parameterTypes
),
useArgType( 1 ),
invariant( typeConfiguration, parameterTypes )
);
this.functions = functions;
}

View File

@ -20,6 +20,7 @@ import org.hibernate.type.spi.TypeConfiguration;
import org.jboss.logging.Logger;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.useArgType;
/**
* Defines a registry for {@link SqmFunctionDescriptor} instances
@ -319,6 +320,27 @@ public class SqmFunctionRegistry {
);
}
/**
* Register a unary/binary function.
*
* i.e. a function which accepts 1-2 arguments.
*/
public MultipatternSqmFunctionDescriptor registerUnaryBinaryPattern(
String name,
String pattern1,
String pattern2,
FunctionParameterType parameterType1,
FunctionParameterType parameterType2,
TypeConfiguration typeConfiguration) {
return registerPatterns(
name,
new FunctionParameterType[] { parameterType1, parameterType2 },
typeConfiguration,
null,
pattern1,
pattern2
);
}
/**
* Register a unary/binary function.
*
@ -402,6 +424,31 @@ public class SqmFunctionRegistry {
);
}
private MultipatternSqmFunctionDescriptor registerPatterns(
String name,
FunctionParameterType[] parameterTypes,
TypeConfiguration typeConfiguration,
String... patterns) {
SqmFunctionDescriptor[] descriptors =
new SqmFunctionDescriptor[patterns.length];
for ( int i = 0; i < patterns.length; i++ ) {
String pattern = patterns[i];
if ( pattern != null ) {
descriptors[i] =
patternDescriptorBuilder( name, pattern )
.setExactArgumentCount( i )
.setParameterTypes( parameterTypes )
.setReturnTypeResolver( useArgType(1) )
.descriptor();
}
}
MultipatternSqmFunctionDescriptor function =
new MultipatternSqmFunctionDescriptor( name, descriptors, typeConfiguration, parameterTypes );
register( name, function );
return function;
}
private MultipatternSqmFunctionDescriptor registerPatterns(
String name,
BasicType<?> type,

View File

@ -16,8 +16,6 @@ import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.dialect.MySQLDialect;
import org.hibernate.dialect.OracleDialect;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.dialect.SQLServerDialect;
import org.hibernate.dialect.SybaseDialect;
import org.hibernate.dialect.TiDBDialect;
import org.hibernate.testing.TestForIssue;
@ -480,28 +478,30 @@ public class FunctionTests {
}
@Test
@SkipForDialect(dialectClass = MySQLDialect.class, matchSubTypes = true)
@SkipForDialect(dialectClass = SQLServerDialect.class)
@SkipForDialect(dialectClass = SybaseDialect.class, matchSubTypes = true)
@SkipForDialect(dialectClass = DerbyDialect.class)
public void testTruncFunction(SessionFactoryScope scope) {
public void testRoundTruncFunctions(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
assertThat( session.createQuery("select trunc(32.92345)").getSingleResult(), is(32d) );
assertThat( session.createQuery("select trunc(32.92345,3)").getSingleResult(), is(32.923d) );
}
);
}
assertThat( session.createQuery("select trunc(32.92345f)").getSingleResult(), is(32f) );
assertThat( session.createQuery("select trunc(32.92345f,3)").getSingleResult(), is(32.923f) );
assertThat( session.createQuery("select trunc(-32.92345f)").getSingleResult(), is(-32f) );
assertThat( session.createQuery("select trunc(-32.92345f,3)").getSingleResult(), is(-32.923f) );
assertThat( session.createQuery("select truncate(32.92345f)").getSingleResult(), is(32f) );
assertThat( session.createQuery("select truncate(32.92345f,3)").getSingleResult(), is(32.923f) );
assertThat( session.createQuery("select round(32.92345f)").getSingleResult(), is(33f) );
assertThat( session.createQuery("select round(32.92345f,1)").getSingleResult(), is(32.9f) );
assertThat( session.createQuery("select round(32.92345f,3)").getSingleResult(), is(32.923f) );
assertThat( session.createQuery("select round(32.923451f,4)").getSingleResult(), is(32.9235f) );
@Test
@RequiresDialect(MySQLDialect.class)
@RequiresDialect(SQLServerDialect.class)
@RequiresDialect(DB2Dialect.class)
@RequiresDialect(H2Dialect.class)
public void testTruncateFunction(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
assertThat( session.createQuery("select truncate(32.92345,3)").getSingleResult(), is(32.923d) );
assertThat( session.createQuery("select trunc(32.92345d)").getSingleResult(), is(32d) );
assertThat( session.createQuery("select trunc(32.92345d,3)").getSingleResult(), is(32.923d) );
assertThat( session.createQuery("select trunc(-32.92345d)").getSingleResult(), is(-32d) );
assertThat( session.createQuery("select trunc(-32.92345d,3)").getSingleResult(), is(-32.923d) );
assertThat( session.createQuery("select truncate(32.92345d)").getSingleResult(), is(32d) );
assertThat( session.createQuery("select truncate(32.92345d,3)").getSingleResult(), is(32.923d) );
assertThat( session.createQuery("select round(32.92345d)").getSingleResult(), is(33d) );
assertThat( session.createQuery("select round(32.92345d,1)").getSingleResult(), is(32.9d) );
assertThat( session.createQuery("select round(32.92345d,3)").getSingleResult(), is(32.923d) );
assertThat( session.createQuery("select round(32.923451d,4)").getSingleResult(), is(32.9235d) );
}
);
}