HHH-17375 Overload length function with array_length semantics on array input

This commit is contained in:
Christian Beikov 2024-05-07 09:38:26 +02:00
parent b74992198c
commit 35102836c7
4 changed files with 203 additions and 8 deletions

View File

@ -1271,7 +1271,7 @@ include::{array-example-dir-hql}/ArrayPositionsTest.java[tags=hql-array-position
==== ====
[[hql-array-length-functions]] [[hql-array-length-functions]]
===== `array_length()` ===== `array_length()` or `length()`
Returns size of the passed array. Returns `null` if the array is `null`. Returns size of the passed array. Returns `null` if the array is `null`.
@ -1283,6 +1283,17 @@ include::{array-example-dir-hql}/ArrayLengthTest.java[tags=hql-array-length-exam
---- ----
==== ====
Alternatively, it is also possible to use the `length()` function,
which is overloaded to also accept an array argument.
[[hql-array-length-hql-example]]
====
[source, JAVA, indent=0]
----
include::{array-example-dir-hql}/ArrayLengthTest.java[tags=hql-array-length-hql-example]
----
====
[[hql-array-concat-functions]] [[hql-array-concat-functions]]
===== `array_concat()` or `||` ===== `array_concat()` or `||`

View File

@ -1605,34 +1605,34 @@ public class CommonFunctionFactory {
* Transact SQL-style * Transact SQL-style
*/ */
public void characterLength_len() { public void characterLength_len() {
functionRegistry.namedDescriptorBuilder( "len" ) functionRegistry.namedDescriptorBuilder( "character_length", "len" )
.setInvariantType(integerType) .setInvariantType(integerType)
.setExactArgumentCount( 1 ) .setExactArgumentCount( 1 )
.setParameterTypes(STRING_OR_CLOB) .setParameterTypes(STRING_OR_CLOB)
.register(); .register();
functionRegistry.registerAlternateKey( "character_length", "len" ); functionRegistry.registerAlternateKey( "len", "character_length" );
functionRegistry.registerAlternateKey( "length", "len" ); functionRegistry.registerAlternateKey( "length", "character_length" );
} }
/** /**
* Oracle-style * Oracle-style
*/ */
public void characterLength_length(SqlAstNodeRenderingMode argumentRenderingMode) { public void characterLength_length(SqlAstNodeRenderingMode argumentRenderingMode) {
functionRegistry.namedDescriptorBuilder( "length" ) functionRegistry.namedDescriptorBuilder( "character_length", "length" )
.setInvariantType(integerType) .setInvariantType(integerType)
.setExactArgumentCount( 1 ) .setExactArgumentCount( 1 )
.setParameterTypes(STRING_OR_CLOB) .setParameterTypes(STRING_OR_CLOB)
.setArgumentRenderingMode( argumentRenderingMode ) .setArgumentRenderingMode( argumentRenderingMode )
.register(); .register();
functionRegistry.registerAlternateKey( "character_length", "length" ); functionRegistry.registerAlternateKey( "length", "character_length" );
} }
public void characterLength_length(String clobPattern) { public void characterLength_length(String clobPattern) {
functionRegistry.register( functionRegistry.register(
"length", "character_length",
new LengthFunction( "length", "length(?1)", clobPattern, typeConfiguration ) new LengthFunction( "length", "length(?1)", clobPattern, typeConfiguration )
); );
functionRegistry.registerAlternateKey( "character_length", "length" ); functionRegistry.registerAlternateKey( "length", "character_length" );
} }
public void octetLength() { public void octetLength() {
@ -2861,6 +2861,7 @@ public class CommonFunctionFactory {
) )
.setArgumentListSignature( "(ARRAY array)" ) .setArgumentListSignature( "(ARRAY array)" )
.register(); .register();
functionRegistry.register( "length", new DynamicDispatchFunction( functionRegistry, "character_length", "array_length" ) );
} }
/** /**
@ -2868,6 +2869,7 @@ public class CommonFunctionFactory {
*/ */
public void arrayLength_oracle() { public void arrayLength_oracle() {
functionRegistry.register( "array_length", new OracleArrayLengthFunction( typeConfiguration ) ); functionRegistry.register( "array_length", new OracleArrayLengthFunction( typeConfiguration ) );
functionRegistry.register( "length", new DynamicDispatchFunction( functionRegistry, "character_length", "array_length" ) );
} }
/** /**

View File

@ -0,0 +1,170 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function;
import java.util.List;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.spi.QueryEngine;
import org.hibernate.query.sqm.function.FunctionKind;
import org.hibernate.query.sqm.function.SelfRenderingSqmFunction;
import org.hibernate.query.sqm.function.SqmFunctionDescriptor;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.query.sqm.tree.predicate.SqmPredicate;
import org.hibernate.query.sqm.tree.select.SqmOrderByClause;
import org.hibernate.type.spi.TypeConfiguration;
/**
* A function that dynamically dispatches to other functions,
* depending on which function validates successfully first.
* This can be used for overload implementations.
*
* @since 6.6
*/
public class DynamicDispatchFunction implements SqmFunctionDescriptor, ArgumentsValidator {
private final SqmFunctionRegistry functionRegistry;
private final String[] functionNames;
private final FunctionKind functionKind;
public DynamicDispatchFunction(SqmFunctionRegistry functionRegistry, String... functionNames) {
this.functionRegistry = functionRegistry;
this.functionNames = functionNames;
FunctionKind functionKind = null;
// Sanity check
for ( String overload : functionNames ) {
final SqmFunctionDescriptor functionDescriptor = functionRegistry.findFunctionDescriptor( overload );
if ( functionDescriptor == null ) {
throw new IllegalArgumentException( "No function registered under the name '" + overload + "'" );
}
if ( functionKind == null ) {
functionKind = functionDescriptor.getFunctionKind();
}
else if ( functionKind != functionDescriptor.getFunctionKind() ) {
throw new IllegalArgumentException( "Function has function kind " + functionDescriptor.getFunctionKind() + ", but other overloads have " + functionKind + ". An overloaded function needs a single function kind." );
}
}
this.functionKind = functionKind;
}
@Override
public FunctionKind getFunctionKind() {
return functionKind;
}
@Override
public <T> SelfRenderingSqmFunction<T> generateSqmExpression(
List<? extends SqmTypedNode<?>> arguments,
ReturnableType<T> impliedResultType,
QueryEngine queryEngine) {
final SqmFunctionDescriptor functionDescriptor = validateGetFunction(
arguments,
queryEngine.getTypeConfiguration()
);
return functionDescriptor.generateSqmExpression( arguments, impliedResultType, queryEngine );
}
@Override
public <T> SelfRenderingSqmFunction<T> generateAggregateSqmExpression(
List<? extends SqmTypedNode<?>> arguments,
SqmPredicate filter,
ReturnableType<T> impliedResultType,
QueryEngine queryEngine) {
final SqmFunctionDescriptor functionDescriptor = validateGetFunction(
arguments,
queryEngine.getTypeConfiguration()
);
return functionDescriptor.generateAggregateSqmExpression(
arguments,
filter,
impliedResultType,
queryEngine
);
}
@Override
public <T> SelfRenderingSqmFunction<T> generateOrderedSetAggregateSqmExpression(
List<? extends SqmTypedNode<?>> arguments,
SqmPredicate filter,
SqmOrderByClause withinGroupClause,
ReturnableType<T> impliedResultType,
QueryEngine queryEngine) {
final SqmFunctionDescriptor functionDescriptor = validateGetFunction(
arguments,
queryEngine.getTypeConfiguration()
);
return functionDescriptor.generateOrderedSetAggregateSqmExpression(
arguments,
filter,
withinGroupClause,
impliedResultType,
queryEngine
);
}
@Override
public <T> SelfRenderingSqmFunction<T> generateWindowSqmExpression(
List<? extends SqmTypedNode<?>> arguments,
SqmPredicate filter,
Boolean respectNulls,
Boolean fromFirst,
ReturnableType<T> impliedResultType,
QueryEngine queryEngine) {
final SqmFunctionDescriptor functionDescriptor = validateGetFunction(
arguments,
queryEngine.getTypeConfiguration()
);
return functionDescriptor.generateWindowSqmExpression(
arguments,
filter,
respectNulls,
fromFirst,
impliedResultType,
queryEngine
);
}
@Override
public ArgumentsValidator getArgumentsValidator() {
return this;
}
@Override
public void validate(
List<? extends SqmTypedNode<?>> arguments,
String functionName,
TypeConfiguration typeConfiguration) {
validateGetFunction( arguments, typeConfiguration );
}
private SqmFunctionDescriptor validateGetFunction(
List<? extends SqmTypedNode<?>> arguments,
TypeConfiguration typeConfiguration) {
RuntimeException exception = null;
for ( String overload : functionNames ) {
final SqmFunctionDescriptor functionDescriptor = functionRegistry.findFunctionDescriptor( overload );
if ( functionDescriptor == null ) {
throw new IllegalArgumentException( "No function registered under the name '" + overload + "'" );
}
try {
functionDescriptor.getArgumentsValidator().validate( arguments, overload, typeConfiguration );
return functionDescriptor;
}
catch (RuntimeException ex) {
if ( exception == null ) {
exception = ex;
}
else {
exception.addSuppressed( ex );
}
}
}
throw exception;
}
}

View File

@ -113,4 +113,16 @@ public class ArrayLengthTest {
} ); } );
} }
@Test
public void testLengthThreeHql(SessionFactoryScope scope) {
scope.inSession( em -> {
//tag::hql-array-length-hql-example[]
List<EntityWithArrays> results = em.createQuery( "from EntityWithArrays e where length(e.theArray) = 3", EntityWithArrays.class )
.getResultList();
//end::hql-array-length-hql-example[]
assertEquals( 1, results.size() );
assertEquals( 2L, results.get( 0 ).getId() );
} );
}
} }