Make sure timestampdiff returns a double for the SECOND unit as per JPA 3.1

This commit is contained in:
Christian Beikov 2022-03-21 17:41:22 +01:00
parent b84a6e3a7f
commit 7020a1a563
3 changed files with 64 additions and 44 deletions

View File

@ -7,22 +7,28 @@
package org.hibernate.dialect.function;
import java.util.List;
import java.util.function.Supplier;
import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.sqm.TemporalUnit;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.function.SelfRenderingFunctionSqlAstExpression;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.produce.function.internal.PatternRenderer;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.query.sqm.tree.expression.SqmDurationUnit;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.DurationUnit;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.type.BasicType;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration;
@ -54,8 +60,9 @@ public class TimestampdiffFunction
StandardArgumentsValidators.exactly( 3 ),
TEMPORAL_UNIT, TEMPORAL, TEMPORAL
),
StandardFunctionReturnTypeResolvers.invariant(
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.LONG )
new TimestampdiffFunctionReturnTypeResolver(
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.LONG ),
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
),
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, TEMPORAL_UNIT, TEMPORAL, TEMPORAL )
);
@ -81,35 +88,6 @@ public class TimestampdiffFunction
return new PatternRenderer( dialect.timestampdiffPattern( unit, lhsTemporalType, rhsTemporalType ) );
}
// @Override
// protected <T> SelfRenderingSqlFunctionExpression<T> generateSqmFunctionExpression(
// List<SqmTypedNode<?>> arguments,
// ReturnableType<T> impliedResultType,
// QueryEngine queryEngine,
// TypeConfiguration typeConfiguration) {
// SqmExtractUnit<?> field = (SqmExtractUnit<?>) arguments.get(0);
// SqmExpression<?> from = (SqmExpression<?>) arguments.get(1);
// SqmExpression<?> to = (SqmExpression<?>) arguments.get(2);
// return queryEngine.getSqmFunctionRegistry()
// .patternDescriptorBuilder(
// "timestampdiff",
// dialect.timestampdiffPattern(
// field.getUnit(),
// typeConfiguration.isSqlTimestampType( from.getNodeType() ),
// typeConfiguration.isSqlTimestampType( to.getNodeType() )
// )
// )
// .setInvariantType( StandardBasicTypes.LONG )
// .setExactArgumentCount( 3 )
// .descriptor()
// .generateSqmExpression(
// arguments,
// impliedResultType,
// queryEngine,
// typeConfiguration
// );
// }
public SelfRenderingFunctionSqlAstExpression expression(
ReturnableType<?> impliedResultType,
SqlAstNode... sqlAstArguments) {
@ -130,4 +108,51 @@ public class TimestampdiffFunction
return "(TEMPORAL_UNIT field, TEMPORAL start, TEMPORAL end)";
}
/**
* A special resolver that resolves to DOUBLE for {@link TemporalUnit#SECOND} and otherwise to LONG.
*/
private static class TimestampdiffFunctionReturnTypeResolver implements FunctionReturnTypeResolver {
private final BasicType<Long> longType;
private final BasicType<Double> doubleType;
public TimestampdiffFunctionReturnTypeResolver(BasicType<Long> longType, BasicType<Double> doubleType) {
this.longType = longType;
this.doubleType = doubleType;
}
@Override
public ReturnableType<?> resolveFunctionReturnType(
ReturnableType<?> impliedType,
List<? extends SqmTypedNode<?>> arguments,
TypeConfiguration typeConfiguration) {
final BasicType<?> invariantType;
if ( ( (SqmDurationUnit<?>) arguments.get( 0 ) ).getUnit() == TemporalUnit.SECOND ) {
invariantType = doubleType;
}
else {
invariantType = longType;
}
return StandardFunctionReturnTypeResolvers.isAssignableTo( invariantType, impliedType )
? impliedType : invariantType;
}
@Override
public BasicValuedMapping resolveFunctionReturnType(Supplier<BasicValuedMapping> impliedTypeAccess, List<? extends SqlAstNode> arguments) {
final BasicType<?> invariantType;
if ( ( (SqmDurationUnit<?>) arguments.get( 0 ) ).getUnit() == TemporalUnit.SECOND ) {
invariantType = doubleType;
}
else {
invariantType = longType;
}
return StandardFunctionReturnTypeResolvers.useImpliedTypeIfPossible( invariantType, impliedTypeAccess.get() );
}
@Override
public String getReturnType() {
return longType + "|" + doubleType;
}
}
}

View File

@ -11,6 +11,7 @@ import java.util.List;
import java.util.Locale;
import java.util.function.Supplier;
import org.hibernate.Internal;
import org.hibernate.QueryException;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
@ -119,7 +120,8 @@ public class StandardFunctionReturnTypeResolvers {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Internal helpers
private static boolean isAssignableTo(
@Internal
public static boolean isAssignableTo(
ReturnableType<?> defined, ReturnableType<?> implied) {
if ( implied == null ) {
return false;
@ -144,7 +146,8 @@ public class StandardFunctionReturnTypeResolvers {
|| isNumeric( impliedTypeCode ) && isNumeric( definedTypeCode );
}
private static BasicValuedMapping useImpliedTypeIfPossible(
@Internal
public static BasicValuedMapping useImpliedTypeIfPossible(
BasicValuedMapping defined,
BasicValuedMapping implied) {
if ( defined == null ) {

View File

@ -5431,17 +5431,9 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
Expression left = getActualExpression( cleanly( () -> toSqlExpression( lhs.accept( this ) ) ) );
Expression right = getActualExpression( cleanly( () -> toSqlExpression( rhs.accept( this ) ) ) );
TypeConfiguration typeConfiguration = getCreationContext().getMappingMetamodel().getTypeConfiguration();
TemporalType leftTimestamp = typeConfiguration.getSqlTemporalType( expression.getLeftHandOperand().getNodeType() );
TemporalType rightTimestamp = typeConfiguration.getSqlTemporalType( expression.getRightHandOperand().getNodeType() );
// when we're dealing with Dates, we use
// DAY as the smallest unit, otherwise we
// use SECOND granularity with fractions as that is what the DurationJavaType expects
TemporalUnit baseUnit = ( rightTimestamp == TemporalType.TIMESTAMP || leftTimestamp == TemporalType.TIMESTAMP ) ?
SECOND :
DAY;
// The result of timestamp subtraction is always a `Duration`, unless a unit is applied
// So use SECOND granularity with fractions as that is what the `DurationJavaType` expects
final TemporalUnit baseUnit = SECOND; // todo: alternatively repurpose NATIVE to mean "INTERVAL SECOND"
if ( adjustedTimestamp != null ) {
if ( appliedByUnit != null ) {