use CastFunction to do typecasts

This commit is contained in:
Gavin King 2022-02-07 13:26:31 +01:00
parent 416eeafaa2
commit 75888b94f2
8 changed files with 73 additions and 31 deletions

View File

@ -6,25 +6,25 @@
*/
package org.hibernate.dialect.function;
import java.util.Collections;
import java.util.Arrays;
import java.util.List;
import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.query.sqm.CastType;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.function.FunctionKind;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.produce.function.internal.PatternRenderer;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.CastTarget;
import org.hibernate.sql.ast.tree.expression.Distinct;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.type.BasicType;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration;
@ -35,27 +35,27 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUM
*/
public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public static final String FUNCTION_NAME = "avg";
private final Dialect dialect;
private final SqlAstNodeRenderingMode defaultArgumentRenderingMode;
private final String doubleCastType;
private final CastFunction castFunction;
private final BasicType<Double> doubleType;
public AvgFunction(
Dialect dialect,
TypeConfiguration typeConfiguration,
SqlAstNodeRenderingMode defaultArgumentRenderingMode,
String doubleCastType) {
SqlAstNodeRenderingMode defaultArgumentRenderingMode) {
super(
FUNCTION_NAME,
"avg",
FunctionKind.AGGREGATE,
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
StandardFunctionReturnTypeResolvers.invariant(
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
)
);
this.dialect = dialect;
this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
this.doubleCastType = doubleCastType;
doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
//This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction
//However, since no Dialects currently override the cast() function, it's OK for now
castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() );
}
@Override
@ -101,9 +101,7 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
final JdbcMapping sourceMapping = realArg.getExpressionType().getJdbcMappings().get( 0 );
// Only cast to float/double if this is an integer
if ( sourceMapping.getJdbcType().isInteger() ) {
final String cast = dialect.castPattern( sourceMapping.getCastType(), CastType.DOUBLE );
new PatternRenderer( cast.replace( "?2", doubleCastType ) )
.render( sqlAppender, Collections.singletonList( realArg ), translator );
castFunction.render( sqlAppender, Arrays.asList( realArg, new CastTarget(doubleType) ), translator );
}
else {
translator.render( realArg, defaultArgumentRenderingMode );

View File

@ -67,10 +67,7 @@ public class CastFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
private CastType getCastType(JdbcMapping sourceMapping) {
final CastType castType = sourceMapping.getCastType();
if ( castType == CastType.BOOLEAN ) {
return booleanCastType;
}
return castType;
return castType == CastType.BOOLEAN ? booleanCastType : castType;
}
// @Override

View File

@ -19,7 +19,6 @@ import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration;
@ -1705,7 +1704,7 @@ public class CommonFunctionFactory {
.register();
functionRegistry.register(
CountFunction.FUNCTION_NAME,
"count",
new CountFunction(
dialect,
typeConfiguration,
@ -1720,19 +1719,18 @@ public class CommonFunctionFactory {
Dialect dialect,
SqlAstNodeRenderingMode inferenceArgumentRenderingMode) {
functionRegistry.register(
AvgFunction.FUNCTION_NAME,
"avg",
new AvgFunction(
dialect,
typeConfiguration,
inferenceArgumentRenderingMode,
dialect.getTypeName( SqlTypes.DOUBLE )
inferenceArgumentRenderingMode
)
);
}
public void listagg(String emptyWithinReplacement) {
functionRegistry.register(
ListaggFunction.FUNCTION_NAME,
"listagg",
new ListaggFunction( emptyWithinReplacement, typeConfiguration )
);
}

View File

@ -41,7 +41,6 @@ import org.hibernate.type.spi.TypeConfiguration;
*/
public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public static final String FUNCTION_NAME = "count";
private final Dialect dialect;
private final SqlAstNodeRenderingMode defaultArgumentRenderingMode;
private final String concatOperator;
@ -54,7 +53,7 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
String concatOperator,
String concatArgumentCastType) {
super(
FUNCTION_NAME,
"count",
FunctionKind.AGGREGATE,
StandardArgumentsValidators.exactly( 1 ),
StandardFunctionReturnTypeResolvers.invariant(

View File

@ -31,13 +31,11 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.STR
*/
public class ListaggFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public static final String FUNCTION_NAME = "listagg";
private final String emptyWithinReplacement;
public ListaggFunction(String emptyWithinReplacement, TypeConfiguration typeConfiguration) {
super(
FUNCTION_NAME,
"listagg",
FunctionKind.ORDERED_SET_AGGREGATE,
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 2 ), STRING, STRING ),
StandardFunctionReturnTypeResolvers.invariant(

View File

@ -60,8 +60,10 @@ public class TimestampaddFunction
StandardFunctionReturnTypeResolvers.useArgType( 3 )
);
this.dialect = dialect;
this.castFunction = new CastFunction( dialect, Types.BOOLEAN );
this.integerType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.INTEGER );
//This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction
//However, since no Dialects currently override the cast() function, it's OK for now
this.castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() );
}
@Override

View File

@ -6,7 +6,6 @@
*/
package org.hibernate.sql.ast.tree.expression;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.sql.ast.SqlAstWalker;

View File

@ -0,0 +1,51 @@
package org.hibernate.orm.test.query.hql;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.Test;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
@DomainModel(annotatedClasses = AvgFunctionTest.Value.class)
@SessionFactory
public class AvgFunctionTest {
@Test
public void test(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
session.persist( new Value(0) );
session.persist( new Value(1) );
session.persist( new Value(2) );
session.persist( new Value(3) );
assertThat(
session.createQuery("select avg(value) from Value", Double.class)
.getSingleResult(),
is(1.5)
);
assertThat(
session.createQuery("select avg(integerValue) from Value", Double.class)
.getSingleResult(),
is(1.5)
);
}
);
}
@Entity(name="Value")
public static class Value {
public Value() {}
public Value(int value) {
this.value = value;
this.integerValue = value;
}
@Id
double value;
int integerValue;
}
}