use CastFunction to do typecasts

This commit is contained in:
Gavin King 2022-02-07 13:26:31 +01:00
parent 416eeafaa2
commit 75888b94f2
8 changed files with 73 additions and 31 deletions

View File

@ -6,25 +6,25 @@
*/ */
package org.hibernate.dialect.function; package org.hibernate.dialect.function;
import java.util.Collections; import java.util.Arrays;
import java.util.List; import java.util.List;
import org.hibernate.dialect.Dialect; import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.query.sqm.CastType;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.function.FunctionKind; import org.hibernate.query.sqm.function.FunctionKind;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator; import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.produce.function.internal.PatternRenderer;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.CastTarget;
import org.hibernate.sql.ast.tree.expression.Distinct; import org.hibernate.sql.ast.tree.expression.Distinct;
import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.predicate.Predicate; import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.type.BasicType;
import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration; import org.hibernate.type.spi.TypeConfiguration;
@ -35,27 +35,27 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUM
*/ */
public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor { public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public static final String FUNCTION_NAME = "avg";
private final Dialect dialect;
private final SqlAstNodeRenderingMode defaultArgumentRenderingMode; private final SqlAstNodeRenderingMode defaultArgumentRenderingMode;
private final String doubleCastType; private final CastFunction castFunction;
private final BasicType<Double> doubleType;
public AvgFunction( public AvgFunction(
Dialect dialect, Dialect dialect,
TypeConfiguration typeConfiguration, TypeConfiguration typeConfiguration,
SqlAstNodeRenderingMode defaultArgumentRenderingMode, SqlAstNodeRenderingMode defaultArgumentRenderingMode) {
String doubleCastType) {
super( super(
FUNCTION_NAME, "avg",
FunctionKind.AGGREGATE, FunctionKind.AGGREGATE,
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ), new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
StandardFunctionReturnTypeResolvers.invariant( StandardFunctionReturnTypeResolvers.invariant(
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE ) typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
) )
); );
this.dialect = dialect;
this.defaultArgumentRenderingMode = defaultArgumentRenderingMode; this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
this.doubleCastType = doubleCastType; doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
//This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction
//However, since no Dialects currently override the cast() function, it's OK for now
castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() );
} }
@Override @Override
@ -101,9 +101,7 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
final JdbcMapping sourceMapping = realArg.getExpressionType().getJdbcMappings().get( 0 ); final JdbcMapping sourceMapping = realArg.getExpressionType().getJdbcMappings().get( 0 );
// Only cast to float/double if this is an integer // Only cast to float/double if this is an integer
if ( sourceMapping.getJdbcType().isInteger() ) { if ( sourceMapping.getJdbcType().isInteger() ) {
final String cast = dialect.castPattern( sourceMapping.getCastType(), CastType.DOUBLE ); castFunction.render( sqlAppender, Arrays.asList( realArg, new CastTarget(doubleType) ), translator );
new PatternRenderer( cast.replace( "?2", doubleCastType ) )
.render( sqlAppender, Collections.singletonList( realArg ), translator );
} }
else { else {
translator.render( realArg, defaultArgumentRenderingMode ); translator.render( realArg, defaultArgumentRenderingMode );

View File

@ -67,10 +67,7 @@ public class CastFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
private CastType getCastType(JdbcMapping sourceMapping) { private CastType getCastType(JdbcMapping sourceMapping) {
final CastType castType = sourceMapping.getCastType(); final CastType castType = sourceMapping.getCastType();
if ( castType == CastType.BOOLEAN ) { return castType == CastType.BOOLEAN ? booleanCastType : castType;
return booleanCastType;
}
return castType;
} }
// @Override // @Override

View File

@ -19,7 +19,6 @@ import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.type.BasicType; import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry; import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration; import org.hibernate.type.spi.TypeConfiguration;
@ -1705,7 +1704,7 @@ public class CommonFunctionFactory {
.register(); .register();
functionRegistry.register( functionRegistry.register(
CountFunction.FUNCTION_NAME, "count",
new CountFunction( new CountFunction(
dialect, dialect,
typeConfiguration, typeConfiguration,
@ -1720,19 +1719,18 @@ public class CommonFunctionFactory {
Dialect dialect, Dialect dialect,
SqlAstNodeRenderingMode inferenceArgumentRenderingMode) { SqlAstNodeRenderingMode inferenceArgumentRenderingMode) {
functionRegistry.register( functionRegistry.register(
AvgFunction.FUNCTION_NAME, "avg",
new AvgFunction( new AvgFunction(
dialect, dialect,
typeConfiguration, typeConfiguration,
inferenceArgumentRenderingMode, inferenceArgumentRenderingMode
dialect.getTypeName( SqlTypes.DOUBLE )
) )
); );
} }
public void listagg(String emptyWithinReplacement) { public void listagg(String emptyWithinReplacement) {
functionRegistry.register( functionRegistry.register(
ListaggFunction.FUNCTION_NAME, "listagg",
new ListaggFunction( emptyWithinReplacement, typeConfiguration ) new ListaggFunction( emptyWithinReplacement, typeConfiguration )
); );
} }

View File

@ -41,7 +41,6 @@ import org.hibernate.type.spi.TypeConfiguration;
*/ */
public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor { public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public static final String FUNCTION_NAME = "count";
private final Dialect dialect; private final Dialect dialect;
private final SqlAstNodeRenderingMode defaultArgumentRenderingMode; private final SqlAstNodeRenderingMode defaultArgumentRenderingMode;
private final String concatOperator; private final String concatOperator;
@ -54,7 +53,7 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
String concatOperator, String concatOperator,
String concatArgumentCastType) { String concatArgumentCastType) {
super( super(
FUNCTION_NAME, "count",
FunctionKind.AGGREGATE, FunctionKind.AGGREGATE,
StandardArgumentsValidators.exactly( 1 ), StandardArgumentsValidators.exactly( 1 ),
StandardFunctionReturnTypeResolvers.invariant( StandardFunctionReturnTypeResolvers.invariant(

View File

@ -31,13 +31,11 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.STR
*/ */
public class ListaggFunction extends AbstractSqmSelfRenderingFunctionDescriptor { public class ListaggFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public static final String FUNCTION_NAME = "listagg";
private final String emptyWithinReplacement; private final String emptyWithinReplacement;
public ListaggFunction(String emptyWithinReplacement, TypeConfiguration typeConfiguration) { public ListaggFunction(String emptyWithinReplacement, TypeConfiguration typeConfiguration) {
super( super(
FUNCTION_NAME, "listagg",
FunctionKind.ORDERED_SET_AGGREGATE, FunctionKind.ORDERED_SET_AGGREGATE,
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 2 ), STRING, STRING ), new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 2 ), STRING, STRING ),
StandardFunctionReturnTypeResolvers.invariant( StandardFunctionReturnTypeResolvers.invariant(

View File

@ -60,8 +60,10 @@ public class TimestampaddFunction
StandardFunctionReturnTypeResolvers.useArgType( 3 ) StandardFunctionReturnTypeResolvers.useArgType( 3 )
); );
this.dialect = dialect; this.dialect = dialect;
this.castFunction = new CastFunction( dialect, Types.BOOLEAN );
this.integerType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.INTEGER ); this.integerType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.INTEGER );
//This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction
//However, since no Dialects currently override the cast() function, it's OK for now
this.castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() );
} }
@Override @Override

View File

@ -6,7 +6,6 @@
*/ */
package org.hibernate.sql.ast.tree.expression; package org.hibernate.sql.ast.tree.expression;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.sql.ast.SqlAstWalker; import org.hibernate.sql.ast.SqlAstWalker;

View File

@ -0,0 +1,51 @@
package org.hibernate.orm.test.query.hql;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.Test;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
@DomainModel(annotatedClasses = AvgFunctionTest.Value.class)
@SessionFactory
public class AvgFunctionTest {
@Test
public void test(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
session.persist( new Value(0) );
session.persist( new Value(1) );
session.persist( new Value(2) );
session.persist( new Value(3) );
assertThat(
session.createQuery("select avg(value) from Value", Double.class)
.getSingleResult(),
is(1.5)
);
assertThat(
session.createQuery("select avg(integerValue) from Value", Double.class)
.getSingleResult(),
is(1.5)
);
}
);
}
@Entity(name="Value")
public static class Value {
public Value() {}
public Value(int value) {
this.value = value;
this.integerValue = value;
}
@Id
double value;
int integerValue;
}
}