HHH-17335 Add array_get function

This commit is contained in:
Christian Beikov 2023-10-23 18:38:32 +02:00
parent d5404fdd49
commit 8c4ed1ef48
17 changed files with 350 additions and 0 deletions

View File

@ -1123,6 +1123,11 @@ The following functions deal with SQL array types, which are not supported on ev
| `array_position()` | Determines the position of an element in an array | `array_position()` | Determines the position of an element in an array
| `array_length()` | Determines the length of an array | `array_length()` | Determines the length of an array
| `array_concat()` | Concatenates array with each other in order | `array_concat()` | Concatenates array with each other in order
| `array_contains_all()` | Determines if one array holds all elements of another array
| `array_contains_all_nullable()` | Determines if one array holds all elements of another array, supporting null elements
| `array_contains_any()` | Determines if one array holds at least one element of another array
| `array_contains_any_nullable()` | Determines if one array holds at least one element of another array, supporting null elements
| `array_get()` | Accesses the element of an array by index
|=== |===
===== `array()` ===== `array()`
@ -1241,6 +1246,20 @@ include::{array-example-dir-hql}/ArrayContainsAnyTest.java[tags=hql-array-contai
---- ----
==== ====
[[hql-array-get-functions]]
===== `array_get()`
Returns the element of an array at the given 1-based index. Returns `null` if either of the arguments is `null`,
and also if the index is bigger than the array length.
[[hql-array-get-example]]
====
[source, JAVA, indent=0]
----
include::{array-example-dir-hql}/ArrayGetTest.java[tags=hql-array-get-example]
----
====
[[hql-user-defined-functions]] [[hql-user-defined-functions]]
==== Native and user-defined functions ==== Native and user-defined functions

View File

@ -472,6 +472,7 @@ public class CockroachLegacyDialect extends Dialect {
functionFactory.arrayContainsAny_operator(); functionFactory.arrayContainsAny_operator();
functionFactory.arrayContainsAllNullable_operator(); functionFactory.arrayContainsAllNullable_operator();
functionFactory.arrayContainsAnyNullable_operator(); functionFactory.arrayContainsAnyNullable_operator();
functionFactory.arrayGet_bracket();
functionContributions.getFunctionRegistry().register( functionContributions.getFunctionRegistry().register(
"trunc", "trunc",

View File

@ -380,6 +380,7 @@ public class H2LegacyDialect extends Dialect {
functionFactory.arrayContainsAny_h2(); functionFactory.arrayContainsAny_h2();
functionFactory.arrayContainsAllNullable_h2(); functionFactory.arrayContainsAllNullable_h2();
functionFactory.arrayContainsAnyNullable_h2(); functionFactory.arrayContainsAnyNullable_h2();
functionFactory.arrayGet_h2();
} }
else { else {
// Use group_concat until 2.x as listagg was buggy // Use group_concat until 2.x as listagg was buggy

View File

@ -258,6 +258,7 @@ public class HSQLLegacyDialect extends Dialect {
functionFactory.arrayContainsAny_hsql(); functionFactory.arrayContainsAny_hsql();
functionFactory.arrayContainsAllNullable_hsql(); functionFactory.arrayContainsAllNullable_hsql();
functionFactory.arrayContainsAnyNullable_hsql(); functionFactory.arrayContainsAnyNullable_hsql();
functionFactory.arrayGet_unnest();
} }
@Override @Override

View File

@ -294,6 +294,7 @@ public class OracleLegacyDialect extends Dialect {
functionFactory.arrayContainsAny_oracle(); functionFactory.arrayContainsAny_oracle();
functionFactory.arrayContainsAllNullable_oracle(); functionFactory.arrayContainsAllNullable_oracle();
functionFactory.arrayContainsAnyNullable_oracle(); functionFactory.arrayContainsAnyNullable_oracle();
functionFactory.arrayGet_oracle();
} }
@Override @Override

View File

@ -592,6 +592,7 @@ public class PostgreSQLLegacyDialect extends Dialect {
functionFactory.arrayContainsAny_operator(); functionFactory.arrayContainsAny_operator();
functionFactory.arrayContainsAllNullable_operator(); functionFactory.arrayContainsAllNullable_operator();
functionFactory.arrayContainsAnyNullable_operator(); functionFactory.arrayContainsAnyNullable_operator();
functionFactory.arrayGet_bracket();
if ( getVersion().isSameOrAfter( 9, 4 ) ) { if ( getVersion().isSameOrAfter( 9, 4 ) ) {
functionFactory.makeDateTimeTimestamp(); functionFactory.makeDateTimeTimestamp();

View File

@ -459,6 +459,7 @@ public class CockroachDialect extends Dialect {
functionFactory.arrayContainsAny_operator(); functionFactory.arrayContainsAny_operator();
functionFactory.arrayContainsAllNullable_operator(); functionFactory.arrayContainsAllNullable_operator();
functionFactory.arrayContainsAnyNullable_operator(); functionFactory.arrayContainsAnyNullable_operator();
functionFactory.arrayGet_bracket();
functionContributions.getFunctionRegistry().register( functionContributions.getFunctionRegistry().register(
"trunc", "trunc",

View File

@ -320,6 +320,7 @@ public class H2Dialect extends Dialect {
functionFactory.arrayContainsAny_h2(); functionFactory.arrayContainsAny_h2();
functionFactory.arrayContainsAllNullable_h2(); functionFactory.arrayContainsAllNullable_h2();
functionFactory.arrayContainsAnyNullable_h2(); functionFactory.arrayContainsAnyNullable_h2();
functionFactory.arrayGet_h2();
} }
@Override @Override

View File

@ -198,6 +198,7 @@ public class HSQLDialect extends Dialect {
functionFactory.arrayContainsAny_hsql(); functionFactory.arrayContainsAny_hsql();
functionFactory.arrayContainsAllNullable_hsql(); functionFactory.arrayContainsAllNullable_hsql();
functionFactory.arrayContainsAnyNullable_hsql(); functionFactory.arrayContainsAnyNullable_hsql();
functionFactory.arrayGet_unnest();
} }
@Override @Override

View File

@ -327,6 +327,22 @@ public class OracleArrayJdbcType extends ArrayJdbcType {
false false
) )
); );
database.addAuxiliaryDatabaseObject(
new NamedAuxiliaryDatabaseObject(
arrayTypeName + "_get",
database.getDefaultNamespace(),
new String[]{
"create or replace function " + arrayTypeName + "_get(arr in " + arrayTypeName +
", idx in number) return " + getRawTypeName( elementType ) + " deterministic is begin " +
"if arr is null or idx is null or arr.count < idx then return null; end if; " +
"return arr(idx); " +
"end;"
},
new String[] { "drop function " + arrayTypeName + "_get" },
emptySet(),
false
)
);
} }
protected String createOrReplaceConcatFunction(String arrayTypeName) { protected String createOrReplaceConcatFunction(String arrayTypeName) {

View File

@ -323,6 +323,7 @@ public class OracleDialect extends Dialect {
functionFactory.arrayContainsAny_oracle(); functionFactory.arrayContainsAny_oracle();
functionFactory.arrayContainsAllNullable_oracle(); functionFactory.arrayContainsAllNullable_oracle();
functionFactory.arrayContainsAnyNullable_oracle(); functionFactory.arrayContainsAnyNullable_oracle();
functionFactory.arrayGet_oracle();
} }
@Override @Override

View File

@ -640,6 +640,7 @@ public class PostgreSQLDialect extends Dialect {
functionFactory.arrayContainsAny_operator(); functionFactory.arrayContainsAny_operator();
functionFactory.arrayContainsAllNullable_operator(); functionFactory.arrayContainsAllNullable_operator();
functionFactory.arrayContainsAnyNullable_operator(); functionFactory.arrayContainsAnyNullable_operator();
functionFactory.arrayGet_bracket();
functionFactory.makeDateTimeTimestamp(); functionFactory.makeDateTimeTimestamp();
// Note that PostgreSQL doesn't support the OVER clause for ordered set-aggregate functions // Note that PostgreSQL doesn't support the OVER clause for ordered set-aggregate functions

View File

@ -21,11 +21,14 @@ import org.hibernate.dialect.function.array.ArrayConstructorFunction;
import org.hibernate.dialect.function.array.ArrayContainsQuantifiedOperatorFunction; import org.hibernate.dialect.function.array.ArrayContainsQuantifiedOperatorFunction;
import org.hibernate.dialect.function.array.ArrayContainsOperatorFunction; import org.hibernate.dialect.function.array.ArrayContainsOperatorFunction;
import org.hibernate.dialect.function.array.ArrayContainsQuantifiedUnnestFunction; import org.hibernate.dialect.function.array.ArrayContainsQuantifiedUnnestFunction;
import org.hibernate.dialect.function.array.ArrayGetUnnestFunction;
import org.hibernate.dialect.function.array.ElementViaArrayArgumentReturnTypeResolver;
import org.hibernate.dialect.function.array.H2ArrayContainsQuantifiedEmulation; import org.hibernate.dialect.function.array.H2ArrayContainsQuantifiedEmulation;
import org.hibernate.dialect.function.array.HSQLArrayPositionFunction; import org.hibernate.dialect.function.array.HSQLArrayPositionFunction;
import org.hibernate.dialect.function.array.OracleArrayConcatFunction; import org.hibernate.dialect.function.array.OracleArrayConcatFunction;
import org.hibernate.dialect.function.array.OracleArrayContainsAllFunction; import org.hibernate.dialect.function.array.OracleArrayContainsAllFunction;
import org.hibernate.dialect.function.array.OracleArrayContainsAnyFunction; import org.hibernate.dialect.function.array.OracleArrayContainsAnyFunction;
import org.hibernate.dialect.function.array.OracleArrayGetFunction;
import org.hibernate.dialect.function.array.OracleArrayLengthFunction; import org.hibernate.dialect.function.array.OracleArrayLengthFunction;
import org.hibernate.dialect.function.array.OracleArrayPositionFunction; import org.hibernate.dialect.function.array.OracleArrayPositionFunction;
import org.hibernate.dialect.function.array.PostgreSQLArrayConcatFunction; import org.hibernate.dialect.function.array.PostgreSQLArrayConcatFunction;
@ -36,6 +39,7 @@ import org.hibernate.dialect.function.array.OracleArrayConstructorFunction;
import org.hibernate.dialect.function.array.OracleArrayContainsFunction; import org.hibernate.dialect.function.array.OracleArrayContainsFunction;
import org.hibernate.dialect.function.array.OracleArrayContainsNullFunction; import org.hibernate.dialect.function.array.OracleArrayContainsNullFunction;
import org.hibernate.query.sqm.function.SqmFunctionRegistry; import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers; import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
@ -2925,4 +2929,51 @@ public class CommonFunctionFactory {
public void arrayConcat_oracle() { public void arrayConcat_oracle() {
functionRegistry.register( "array_concat", new OracleArrayConcatFunction() ); functionRegistry.register( "array_concat", new OracleArrayConcatFunction() );
} }
/**
* H2 array_get() function via bracket syntax
*/
public void arrayGet_h2() {
functionRegistry.patternDescriptorBuilder( "array_get", "case when array_length(?1)>=?2 then ?1[?2] end" )
.setReturnTypeResolver( ElementViaArrayArgumentReturnTypeResolver.DEFAULT_INSTANCE )
.setArgumentsValidator(
StandardArgumentsValidators.composite(
ArrayArgumentValidator.DEFAULT_INSTANCE,
new ArgumentTypesValidator( null, ANY, INTEGER )
)
)
.setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.invariant( ANY, INTEGER ) )
.setArgumentListSignature( "(ARRAY array, INTEGER index)" )
.register();
}
/**
* CockroachDB and PostgreSQL array_get() function via bracket syntax
*/
public void arrayGet_bracket() {
functionRegistry.patternDescriptorBuilder( "array_get", "?1[?2]" )
.setReturnTypeResolver( ElementViaArrayArgumentReturnTypeResolver.DEFAULT_INSTANCE )
.setArgumentsValidator(
StandardArgumentsValidators.composite(
ArrayArgumentValidator.DEFAULT_INSTANCE,
new ArgumentTypesValidator( null, ANY, INTEGER )
)
)
.setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.invariant( ANY, INTEGER ) )
.setArgumentListSignature( "(ARRAY array, INTEGER index)" )
.register();
}
/**
* HSQL array_get() function
*/
public void arrayGet_unnest() {
functionRegistry.register( "array_get", new ArrayGetUnnestFunction() );
}
/**
* Oracle array_get() function
*/
public void arrayGet_oracle() {
functionRegistry.register( "array_get", new OracleArrayGetFunction() );
}
} }

View File

@ -0,0 +1,53 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.ANY;
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER;
/**
* Implement the array get function by using {@code unnest}.
*/
public class ArrayGetUnnestFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public ArrayGetUnnestFunction() {
super(
"array_get",
StandardArgumentsValidators.composite(
ArrayArgumentValidator.DEFAULT_INSTANCE,
new ArgumentTypesValidator( null, ANY, INTEGER )
),
ElementViaArrayArgumentReturnTypeResolver.DEFAULT_INSTANCE,
StandardFunctionArgumentTypeResolvers.invariant( ANY, INTEGER )
);
}
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
final Expression arrayExpression = (Expression) sqlAstArguments.get( 0 );
final Expression indexExpression = (Expression) sqlAstArguments.get( 1 );
sqlAppender.append( "(select t.val from unnest(" );
arrayExpression.accept( walker );
sqlAppender.append( ") with ordinality t(val, idx) where t.idx=" );
indexExpression.accept( walker );
sqlAppender.append( ')' );
}
}

View File

@ -0,0 +1,70 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import java.util.function.Supplier;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.type.BasicPluralType;
import org.hibernate.type.spi.TypeConfiguration;
/**
* A {@link FunctionReturnTypeResolver} that resolves the array element type based on an argument.
* The inferred type and implied type have precedence though.
*/
public class ElementViaArrayArgumentReturnTypeResolver implements FunctionReturnTypeResolver {
public static final FunctionReturnTypeResolver DEFAULT_INSTANCE = new ElementViaArrayArgumentReturnTypeResolver( 0 );
private final int arrayIndex;
private ElementViaArrayArgumentReturnTypeResolver(int arrayIndex) {
this.arrayIndex = arrayIndex;
}
@Override
public ReturnableType<?> resolveFunctionReturnType(
ReturnableType<?> impliedType,
Supplier<MappingModelExpressible<?>> inferredTypeSupplier,
List<? extends SqmTypedNode<?>> arguments,
TypeConfiguration typeConfiguration) {
final MappingModelExpressible<?> inferredType = inferredTypeSupplier.get();
if ( inferredType != null ) {
if ( inferredType instanceof ReturnableType<?> ) {
return (ReturnableType<?>) inferredType;
}
else if ( inferredType instanceof BasicValuedMapping ) {
return (ReturnableType<?>) ( (BasicValuedMapping) inferredType ).getJdbcMapping();
}
}
if ( impliedType != null ) {
return impliedType;
}
final SqmExpressible<?> expressible = arguments.get( arrayIndex ).getExpressible();
final DomainType<?> type;
if ( expressible != null && ( type = expressible.getSqmType() ) instanceof BasicPluralType<?, ?> ) {
return ( (BasicPluralType<?, ?>) type ).getElementType();
}
return null;
}
@Override
public BasicValuedMapping resolveFunctionReturnType(
Supplier<BasicValuedMapping> impliedTypeAccess,
List<? extends SqlAstNode> arguments) {
return null;
}
}

View File

@ -0,0 +1,46 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
/**
* Oracle array_get function.
*/
public class OracleArrayGetFunction extends ArrayGetUnnestFunction {
public OracleArrayGetFunction() {
}
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
JdbcMappingContainer expressionType = null;
for ( SqlAstNode sqlAstArgument : sqlAstArguments ) {
expressionType = ( (Expression) sqlAstArgument ).getExpressionType();
if ( expressionType != null ) {
break;
}
}
final String arrayTypeName = ArrayTypeHelper.getArrayTypeName( expressionType, walker );
sqlAppender.append( arrayTypeName );
sqlAppender.append( "_get(" );
sqlAstArguments.get( 0 ).accept( walker );
sqlAppender.append( ',' );
sqlAstArguments.get( 1 ).accept( walker );
sqlAppender.append( ')' );
}
}

View File

@ -0,0 +1,85 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.orm.test.function.array;
import java.util.List;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.ServiceRegistry;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.hibernate.testing.orm.junit.Setting;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Christian Beikov
*/
@DomainModel(annotatedClasses = EntityWithArrays.class)
@SessionFactory
@RequiresDialectFeature( feature = DialectFeatureChecks.SupportsStructuralArrays.class)
// Make sure this stuff runs on a dedicated connection pool,
// otherwise we might run into ORA-21700: object does not exist or is marked for delete
// because the JDBC connection or database session caches something that should have been invalidated
@ServiceRegistry(settings = @Setting(name = AvailableSettings.CONNECTION_PROVIDER, value = ""))
public class ArrayGetTest {
@BeforeEach
public void prepareData(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.persist( new EntityWithArrays( 1L, new String[]{} ) );
em.persist( new EntityWithArrays( 2L, new String[]{ "abc", null, "def" } ) );
em.persist( new EntityWithArrays( 3L, null ) );
} );
}
@AfterEach
public void cleanup(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.createMutationQuery( "delete from EntityWithArrays" ).executeUpdate();
} );
}
@Test
public void testGet(SessionFactoryScope scope) {
scope.inSession( em -> {
//tag::hql-array-get-example[]
List<EntityWithArrays> results = em.createQuery( "from EntityWithArrays e where array_get(e.theArray, 1) = 'abc'", EntityWithArrays.class )
.getResultList();
//end::hql-array-get-example[]
assertEquals( 1, results.size() );
assertEquals( 2L, results.get( 0 ).getId() );
} );
}
@Test
public void testGetNullElement(SessionFactoryScope scope) {
scope.inSession( em -> {
List<EntityWithArrays> results = em.createQuery( "from EntityWithArrays e where array_length(e.theArray) >= 2 and array_get(e.theArray, 2) is null", EntityWithArrays.class )
.getResultList();
assertEquals( 1, results.size() );
assertEquals( 2L, results.get( 0 ).getId() );
} );
}
@Test
public void testGetNotExisting(SessionFactoryScope scope) {
scope.inSession( em -> {
List<EntityWithArrays> results = em.createQuery( "from EntityWithArrays e where array_get(e.theArray,100) is null", EntityWithArrays.class )
.getResultList();
assertEquals( 3, results.size() );
} );
}
}