use CastFunction to do typecasts
This commit is contained in:
parent
416eeafaa2
commit
75888b94f2
|
@ -6,25 +6,25 @@
|
|||
*/
|
||||
package org.hibernate.dialect.function;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.metamodel.mapping.JdbcMapping;
|
||||
import org.hibernate.query.sqm.CastType;
|
||||
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
|
||||
import org.hibernate.query.sqm.function.FunctionKind;
|
||||
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
|
||||
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
|
||||
import org.hibernate.query.sqm.produce.function.internal.PatternRenderer;
|
||||
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
|
||||
import org.hibernate.sql.ast.SqlAstTranslator;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.sql.ast.tree.SqlAstNode;
|
||||
import org.hibernate.sql.ast.tree.expression.CastTarget;
|
||||
import org.hibernate.sql.ast.tree.expression.Distinct;
|
||||
import org.hibernate.sql.ast.tree.expression.Expression;
|
||||
import org.hibernate.sql.ast.tree.predicate.Predicate;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
|
@ -35,27 +35,27 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUM
|
|||
*/
|
||||
public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
||||
|
||||
public static final String FUNCTION_NAME = "avg";
|
||||
private final Dialect dialect;
|
||||
private final SqlAstNodeRenderingMode defaultArgumentRenderingMode;
|
||||
private final String doubleCastType;
|
||||
private final CastFunction castFunction;
|
||||
private final BasicType<Double> doubleType;
|
||||
|
||||
public AvgFunction(
|
||||
Dialect dialect,
|
||||
TypeConfiguration typeConfiguration,
|
||||
SqlAstNodeRenderingMode defaultArgumentRenderingMode,
|
||||
String doubleCastType) {
|
||||
SqlAstNodeRenderingMode defaultArgumentRenderingMode) {
|
||||
super(
|
||||
FUNCTION_NAME,
|
||||
"avg",
|
||||
FunctionKind.AGGREGATE,
|
||||
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
|
||||
StandardFunctionReturnTypeResolvers.invariant(
|
||||
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
|
||||
)
|
||||
);
|
||||
this.dialect = dialect;
|
||||
this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
|
||||
this.doubleCastType = doubleCastType;
|
||||
doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
|
||||
//This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction
|
||||
//However, since no Dialects currently override the cast() function, it's OK for now
|
||||
castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() );
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -101,9 +101,7 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
|||
final JdbcMapping sourceMapping = realArg.getExpressionType().getJdbcMappings().get( 0 );
|
||||
// Only cast to float/double if this is an integer
|
||||
if ( sourceMapping.getJdbcType().isInteger() ) {
|
||||
final String cast = dialect.castPattern( sourceMapping.getCastType(), CastType.DOUBLE );
|
||||
new PatternRenderer( cast.replace( "?2", doubleCastType ) )
|
||||
.render( sqlAppender, Collections.singletonList( realArg ), translator );
|
||||
castFunction.render( sqlAppender, Arrays.asList( realArg, new CastTarget(doubleType) ), translator );
|
||||
}
|
||||
else {
|
||||
translator.render( realArg, defaultArgumentRenderingMode );
|
||||
|
|
|
@ -67,10 +67,7 @@ public class CastFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
|||
|
||||
private CastType getCastType(JdbcMapping sourceMapping) {
|
||||
final CastType castType = sourceMapping.getCastType();
|
||||
if ( castType == CastType.BOOLEAN ) {
|
||||
return booleanCastType;
|
||||
}
|
||||
return castType;
|
||||
return castType == CastType.BOOLEAN ? booleanCastType : castType;
|
||||
}
|
||||
|
||||
// @Override
|
||||
|
|
|
@ -19,7 +19,6 @@ import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
|
|||
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.BasicTypeRegistry;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
|
@ -1705,7 +1704,7 @@ public class CommonFunctionFactory {
|
|||
.register();
|
||||
|
||||
functionRegistry.register(
|
||||
CountFunction.FUNCTION_NAME,
|
||||
"count",
|
||||
new CountFunction(
|
||||
dialect,
|
||||
typeConfiguration,
|
||||
|
@ -1720,19 +1719,18 @@ public class CommonFunctionFactory {
|
|||
Dialect dialect,
|
||||
SqlAstNodeRenderingMode inferenceArgumentRenderingMode) {
|
||||
functionRegistry.register(
|
||||
AvgFunction.FUNCTION_NAME,
|
||||
"avg",
|
||||
new AvgFunction(
|
||||
dialect,
|
||||
typeConfiguration,
|
||||
inferenceArgumentRenderingMode,
|
||||
dialect.getTypeName( SqlTypes.DOUBLE )
|
||||
inferenceArgumentRenderingMode
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void listagg(String emptyWithinReplacement) {
|
||||
functionRegistry.register(
|
||||
ListaggFunction.FUNCTION_NAME,
|
||||
"listagg",
|
||||
new ListaggFunction( emptyWithinReplacement, typeConfiguration )
|
||||
);
|
||||
}
|
||||
|
|
|
@ -41,7 +41,6 @@ import org.hibernate.type.spi.TypeConfiguration;
|
|||
*/
|
||||
public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
||||
|
||||
public static final String FUNCTION_NAME = "count";
|
||||
private final Dialect dialect;
|
||||
private final SqlAstNodeRenderingMode defaultArgumentRenderingMode;
|
||||
private final String concatOperator;
|
||||
|
@ -54,7 +53,7 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
|||
String concatOperator,
|
||||
String concatArgumentCastType) {
|
||||
super(
|
||||
FUNCTION_NAME,
|
||||
"count",
|
||||
FunctionKind.AGGREGATE,
|
||||
StandardArgumentsValidators.exactly( 1 ),
|
||||
StandardFunctionReturnTypeResolvers.invariant(
|
||||
|
|
|
@ -31,13 +31,11 @@ import static org.hibernate.query.sqm.produce.function.FunctionParameterType.STR
|
|||
*/
|
||||
public class ListaggFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
||||
|
||||
public static final String FUNCTION_NAME = "listagg";
|
||||
|
||||
private final String emptyWithinReplacement;
|
||||
|
||||
public ListaggFunction(String emptyWithinReplacement, TypeConfiguration typeConfiguration) {
|
||||
super(
|
||||
FUNCTION_NAME,
|
||||
"listagg",
|
||||
FunctionKind.ORDERED_SET_AGGREGATE,
|
||||
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 2 ), STRING, STRING ),
|
||||
StandardFunctionReturnTypeResolvers.invariant(
|
||||
|
|
|
@ -60,8 +60,10 @@ public class TimestampaddFunction
|
|||
StandardFunctionReturnTypeResolvers.useArgType( 3 )
|
||||
);
|
||||
this.dialect = dialect;
|
||||
this.castFunction = new CastFunction( dialect, Types.BOOLEAN );
|
||||
this.integerType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.INTEGER );
|
||||
//This is kinda wrong, we're supposed to use findFunctionDescriptor("cast"), not instantiate CastFunction
|
||||
//However, since no Dialects currently override the cast() function, it's OK for now
|
||||
this.castFunction = new CastFunction( dialect, dialect.getPreferredSqlTypeCodeForBoolean() );
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
*/
|
||||
package org.hibernate.sql.ast.tree.expression;
|
||||
|
||||
import org.hibernate.metamodel.mapping.BasicValuedMapping;
|
||||
import org.hibernate.metamodel.mapping.JdbcMapping;
|
||||
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
|
||||
import org.hibernate.sql.ast.SqlAstWalker;
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
package org.hibernate.orm.test.query.hql;
|
||||
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
|
||||
@DomainModel(annotatedClasses = AvgFunctionTest.Value.class)
|
||||
@SessionFactory
|
||||
public class AvgFunctionTest {
|
||||
|
||||
@Test
|
||||
public void test(SessionFactoryScope scope) {
|
||||
scope.inTransaction(
|
||||
session -> {
|
||||
session.persist( new Value(0) );
|
||||
session.persist( new Value(1) );
|
||||
session.persist( new Value(2) );
|
||||
session.persist( new Value(3) );
|
||||
assertThat(
|
||||
session.createQuery("select avg(value) from Value", Double.class)
|
||||
.getSingleResult(),
|
||||
is(1.5)
|
||||
);
|
||||
assertThat(
|
||||
session.createQuery("select avg(integerValue) from Value", Double.class)
|
||||
.getSingleResult(),
|
||||
is(1.5)
|
||||
);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@Entity(name="Value")
|
||||
public static class Value {
|
||||
public Value() {}
|
||||
public Value(int value) {
|
||||
this.value = value;
|
||||
this.integerValue = value;
|
||||
}
|
||||
@Id
|
||||
double value;
|
||||
int integerValue;
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue