HHH-17335 Add array_concat function

This commit is contained in:
Christian Beikov 2023-10-20 16:48:11 +02:00
parent 36b7374ba8
commit d46fcf1abe
23 changed files with 453 additions and 0 deletions

View File

@ -1122,6 +1122,7 @@ The following functions deal with SQL array types, which are not supported on ev
| `array_contains_null()` | Whether an array contains a null
| `array_position()` | Determines the position of an element in an array
| `array_length()` | Determines the length of an array
| `array_concat()` | Concatenates array with each other in order
|===
===== `array()`
@ -1181,6 +1182,19 @@ include::{array-example-dir-hql}/ArrayLengthTest.java[tags=hql-array-length-exam
----
====
[[hql-array-concat-functions]]
===== `array_concat()`
Concatenates arrays with each other in order. Returns `null` if one of the arguments is `null`.
[[hql-array-concat-example]]
====
[source, JAVA, indent=0]
----
include::{array-example-dir-hql}/ArrayConcatTest.java[tags=hql-array-concat-example]
----
====
[[hql-user-defined-functions]]
==== Native and user-defined functions

View File

@ -467,6 +467,7 @@ public class CockroachLegacyDialect extends Dialect {
functionFactory.arrayContainsNull_array_position();
functionFactory.arrayPosition_postgresql();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_postgresql();
functionContributions.getFunctionRegistry().register(
"trunc",

View File

@ -375,6 +375,7 @@ public class H2LegacyDialect extends Dialect {
functionFactory.arrayContains();
functionFactory.arrayContainsNull();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_operator();
}
else {
// Use group_concat until 2.x as listagg was buggy

View File

@ -253,6 +253,7 @@ public class HSQLLegacyDialect extends Dialect {
functionFactory.arrayContainsNull_hsql();
functionFactory.arrayPosition_hsql();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_operator();
}
@Override

View File

@ -289,6 +289,7 @@ public class OracleLegacyDialect extends Dialect {
functionFactory.arrayContainsNull_oracle();
functionFactory.arrayPosition_oracle();
functionFactory.arrayLength_oracle();
functionFactory.arrayConcat_oracle();
}
@Override

View File

@ -587,6 +587,7 @@ public class PostgreSQLLegacyDialect extends Dialect {
functionFactory.arrayContainsNull_array_position();
functionFactory.arrayPosition_postgresql();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_postgresql();
if ( getVersion().isSameOrAfter( 9, 4 ) ) {
functionFactory.makeDateTimeTimestamp();

View File

@ -454,6 +454,7 @@ public class CockroachDialect extends Dialect {
functionFactory.arrayContainsNull_array_position();
functionFactory.arrayPosition_postgresql();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_postgresql();
functionContributions.getFunctionRegistry().register(
"trunc",

View File

@ -315,6 +315,7 @@ public class H2Dialect extends Dialect {
functionFactory.arrayContains();
functionFactory.arrayContainsNull();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_operator();
}
@Override

View File

@ -193,6 +193,7 @@ public class HSQLDialect extends Dialect {
functionFactory.arrayContainsNull_hsql();
functionFactory.arrayPosition_hsql();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_operator();
}
@Override

View File

@ -272,6 +272,48 @@ public class OracleArrayJdbcType extends ArrayJdbcType {
false
)
);
database.addAuxiliaryDatabaseObject(
new NamedAuxiliaryDatabaseObject(
arrayTypeName + "_concat",
database.getDefaultNamespace(),
new String[]{ createOrReplaceConcatFunction( arrayTypeName ) },
new String[] { "drop function " + arrayTypeName + "_concat" },
emptySet(),
false
)
);
}
protected String createOrReplaceConcatFunction(String arrayTypeName) {
// Since Oracle has no builtin concat function for varrays and doesn't support varargs,
// we have to create a function with a fixed amount of arguments with default that fits "most" cases.
// Let's just use 5 for the time being until someone requests more.
return createOrReplaceConcatFunction( arrayTypeName, 5 );
}
protected String createOrReplaceConcatFunction(String arrayTypeName, int maxConcatParams) {
final StringBuilder sb = new StringBuilder();
sb.append( "create or replace function " ).append( arrayTypeName ).append( "_concat(" );
sb.append( "arr0 in " ).append( arrayTypeName ).append( ",arr1 in " ).append( arrayTypeName );
for ( int i = 2; i < maxConcatParams; i++ ) {
sb.append( ",arr" ).append( i ).append( " in " ).append( arrayTypeName )
.append( " default " ).append( arrayTypeName ).append( "()" );
}
sb.append( ") return " ).append( arrayTypeName ).append( " deterministic is res " ).append( arrayTypeName )
.append( "; begin if " );
String separator = "";
for ( int i = 0; i < maxConcatParams; i++ ) {
sb.append( separator ).append( "arr" ).append( i ).append( " is null" );
separator = " or ";
}
sb.append( " then return null; end if; " );
sb.append( "select * bulk collect into res from (" );
separator = "";
for ( int i = 0; i < maxConcatParams; i++ ) {
sb.append( separator ).append( "select * from table(arr" ).append( i ).append( ')' );
separator = " union all ";
}
return sb.append( "); return res; end;" ).toString();
}
private static String getRawTypeName(String typeName) {

View File

@ -318,6 +318,7 @@ public class OracleDialect extends Dialect {
functionFactory.arrayContainsNull_oracle();
functionFactory.arrayPosition_oracle();
functionFactory.arrayLength_oracle();
functionFactory.arrayConcat_oracle();
}
@Override

View File

@ -635,6 +635,7 @@ public class PostgreSQLDialect extends Dialect {
functionFactory.arrayContainsNull_array_position();
functionFactory.arrayPosition_postgresql();
functionFactory.arrayLength_cardinality();
functionFactory.arrayConcat_postgresql();
functionFactory.makeDateTimeTimestamp();
// Note that PostgreSQL doesn't support the OVER clause for ordered set-aggregate functions

View File

@ -16,11 +16,15 @@ import org.hibernate.dialect.function.array.ArrayAggFunction;
import org.hibernate.dialect.function.array.ArrayAndElementArgumentTypeResolver;
import org.hibernate.dialect.function.array.ArrayAndElementArgumentValidator;
import org.hibernate.dialect.function.array.ArrayArgumentValidator;
import org.hibernate.dialect.function.array.ArrayConcatArgumentValidator;
import org.hibernate.dialect.function.array.ArrayConcatFunction;
import org.hibernate.dialect.function.array.ArrayConstructorFunction;
import org.hibernate.dialect.function.array.ArrayContainsOperatorFunction;
import org.hibernate.dialect.function.array.HSQLArrayPositionFunction;
import org.hibernate.dialect.function.array.OracleArrayConcatFunction;
import org.hibernate.dialect.function.array.OracleArrayLengthFunction;
import org.hibernate.dialect.function.array.OracleArrayPositionFunction;
import org.hibernate.dialect.function.array.PostgreSQLArrayConcatFunction;
import org.hibernate.dialect.function.array.PostgreSQLArrayPositionFunction;
import org.hibernate.dialect.function.array.CastingArrayConstructorFunction;
import org.hibernate.dialect.function.array.OracleArrayAggEmulation;
@ -2739,4 +2743,25 @@ public class CommonFunctionFactory {
public void arrayLength_oracle() {
functionRegistry.register( "array_length", new OracleArrayLengthFunction( typeConfiguration ) );
}
/**
* H2 and HSQLDB array_concat() function
*/
public void arrayConcat_operator() {
functionRegistry.register( "array_concat", new ArrayConcatFunction( "", "||", "" ) );
}
/**
* CockroachDB and PostgreSQL array_concat() function
*/
public void arrayConcat_postgresql() {
functionRegistry.register( "array_concat", new PostgreSQLArrayConcatFunction() );
}
/**
* Oracle array_concat() function
*/
public void arrayConcat_oracle() {
functionRegistry.register( "array_concat", new OracleArrayConcatFunction() );
}
}

View File

@ -41,6 +41,14 @@ public class ArrayArgumentValidator implements ArgumentsValidator {
List<? extends SqmTypedNode<?>> arguments,
String functionName,
TypeConfiguration typeConfiguration) {
return getElementType( arrayIndex, arguments, functionName, typeConfiguration );
}
protected BasicType<?> getElementType(
int arrayIndex,
List<? extends SqmTypedNode<?>> arguments,
String functionName,
TypeConfiguration typeConfiguration) {
final SqmTypedNode<?> arrayArgument = arguments.get( arrayIndex );
final SqmExpressible<?> arrayType = arrayArgument.getExpressible().getSqmType();
if ( arrayType == null ) {

View File

@ -0,0 +1,66 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.type.BasicPluralType;
import org.hibernate.type.spi.TypeConfiguration;
/**
* A {@link ArgumentsValidator} that validates all arguments are of the same array type.
*/
public class ArrayConcatArgumentValidator implements ArgumentsValidator {
public static final ArgumentsValidator INSTANCE = new ArrayConcatArgumentValidator();
@Override
public void validate(
List<? extends SqmTypedNode<?>> arguments,
String functionName,
TypeConfiguration typeConfiguration) {
BasicPluralType<?, ?> arrayType = null;
for ( int i = 0; i < arguments.size(); i++ ) {
final DomainType<?> sqmType = arguments.get( i ).getExpressible().getSqmType();
if ( sqmType != null ) {
if ( arrayType == null ) {
if ( !( sqmType instanceof BasicPluralType<?, ?> ) ) {
throw new FunctionArgumentException(
String.format(
"Parameter %d of function '%s()' requires an array type, but argument is of type '%s'",
i,
functionName,
sqmType.getTypeName()
)
);
}
arrayType = (BasicPluralType<?, ?>) sqmType;
}
else if ( !arrayType.equals( sqmType ) ) {
throw new FunctionArgumentException(
String.format(
"Parameter %d of function '%s()' requires an array type %s, but argument is of type '%s'",
i,
functionName,
arrayType.getTypeName(),
sqmType.getTypeName()
)
);
}
}
}
}
@Override
public String getSignature() {
return "(ARRAY array0, ARRAY array1[, ARRAY array2, ...])";
}
}

View File

@ -0,0 +1,56 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
/**
* Concatenation function for arrays.
*/
public class ArrayConcatFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
private final String prefix;
private final String separator;
private final String suffix;
public ArrayConcatFunction(String prefix, String separator, String suffix) {
super(
"array_concat",
StandardArgumentsValidators.composite(
StandardArgumentsValidators.min( 2 ),
ArrayConcatArgumentValidator.INSTANCE
),
StandardFunctionReturnTypeResolvers.useFirstNonNull(),
StandardFunctionArgumentTypeResolvers.ARGUMENT_OR_IMPLIED_RESULT_TYPE
);
this.prefix = prefix;
this.separator = separator;
this.suffix = suffix;
}
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
sqlAppender.append( prefix );
sqlAstArguments.get( 0 ).accept( walker );
for ( int i = 1; i < sqlAstArguments.size(); i++ ) {
sqlAppender.append( separator );
sqlAstArguments.get( i ).accept( walker );
}
sqlAppender.append( suffix );
}
}

View File

@ -0,0 +1,44 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
/**
* Oracle concatenation function for arrays.
*/
public class OracleArrayConcatFunction extends ArrayConcatFunction {
public OracleArrayConcatFunction() {
super( "(", ",", ")" );
}
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
JdbcMappingContainer expressionType = null;
for ( SqlAstNode sqlAstArgument : sqlAstArguments ) {
expressionType = ( (Expression) sqlAstArgument ).getExpressionType();
if ( expressionType != null ) {
break;
}
}
final String arrayTypeName = ArrayTypeHelper.getArrayTypeName( expressionType, walker );
sqlAppender.append( arrayTypeName );
sqlAppender.append( "_concat" );
super.render( sqlAppender, sqlAstArguments, walker );
}
}

View File

@ -0,0 +1,42 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
/**
* PostgreSQL variant of the function to properly return {@code null} when one of the arguments is null.
*/
public class PostgreSQLArrayConcatFunction extends ArrayConcatFunction {
public PostgreSQLArrayConcatFunction() {
super( "", "||", "" );
}
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
sqlAppender.append( "case when " );
String separator = "";
for ( SqlAstNode node : sqlAstArguments ) {
sqlAppender.append( separator );
node.accept( walker );
sqlAppender.append( " is not null" );
separator = " and ";
}
sqlAppender.append( " then " );
super.render( sqlAppender, sqlAstArguments, walker );
sqlAppender.append( " end" );
}
}

View File

@ -6,6 +6,8 @@
*/
package org.hibernate.type;
import java.util.Objects;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators;
@ -51,4 +53,15 @@ public class BasicArrayType<T,E>
//noinspection unchecked
return (BasicType<X>) this;
}
@Override
public boolean equals(Object o) {
return o == this || o.getClass() == BasicArrayType.class
&& Objects.equals( baseDescriptor, ( (BasicArrayType<?, ?>) o ).baseDescriptor );
}
@Override
public int hashCode() {
return baseDescriptor.hashCode();
}
}

View File

@ -7,6 +7,7 @@
package org.hibernate.type;
import java.util.Collection;
import java.util.Objects;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.java.spi.BasicCollectionJavaType;
@ -73,4 +74,15 @@ public class BasicCollectionType<C extends Collection<E>, E>
//noinspection unchecked
return (BasicType<X>) this;
}
@Override
public boolean equals(Object o) {
return o == this || o.getClass() == BasicCollectionType.class
&& Objects.equals( baseDescriptor, ( (BasicCollectionType<?, ?>) o ).baseDescriptor );
}
@Override
public int hashCode() {
return baseDescriptor.hashCode();
}
}

View File

@ -6,6 +6,8 @@
*/
package org.hibernate.type;
import java.util.Objects;
import org.hibernate.type.descriptor.ValueBinder;
import org.hibernate.type.descriptor.ValueExtractor;
import org.hibernate.type.descriptor.converter.spi.BasicValueConverter;
@ -98,4 +100,18 @@ public class ConvertedBasicArrayType<T,S,E>
public JdbcLiteralFormatter<T> getJdbcLiteralFormatter() {
return jdbcLiteralFormatter;
}
@Override
public boolean equals(Object o) {
return o == this || super.equals( o )
&& o instanceof ConvertedBasicArrayType<?, ?, ?>
&& Objects.equals( converter, ( (ConvertedBasicArrayType<?, ?, ?>) o ).converter );
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + converter.hashCode();
return result;
}
}

View File

@ -7,6 +7,7 @@
package org.hibernate.type;
import java.util.Collection;
import java.util.Objects;
import org.hibernate.type.descriptor.ValueBinder;
import org.hibernate.type.descriptor.ValueExtractor;
@ -65,4 +66,18 @@ public class ConvertedBasicCollectionType<C extends Collection<E>, E> extends Ba
public JdbcLiteralFormatter<C> getJdbcLiteralFormatter() {
return jdbcLiteralFormatter;
}
@Override
public boolean equals(Object o) {
return o == this || super.equals( o )
&& o instanceof ConvertedBasicCollectionType<?, ?>
&& Objects.equals( converter, ( (ConvertedBasicCollectionType<?, ?>) o ).converter );
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + converter.hashCode();
return result;
}
}

View File

@ -0,0 +1,90 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.orm.test.function.array;
import java.util.List;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.ServiceRegistry;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.hibernate.testing.orm.junit.Setting;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
/**
* @author Christian Beikov
*/
@DomainModel(annotatedClasses = EntityWithArrays.class)
@SessionFactory
@RequiresDialectFeature( feature = DialectFeatureChecks.SupportsStructuralArrays.class)
// Make sure this stuff runs on a dedicated connection pool,
// otherwise we might run into ORA-21700: object does not exist or is marked for delete
// because the JDBC connection or database session caches something that should have been invalidated
@ServiceRegistry(settings = @Setting(name = AvailableSettings.CONNECTION_PROVIDER, value = ""))
public class ArrayConcatTest {
@BeforeEach
public void prepareData(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.persist( new EntityWithArrays( 1L, new String[]{} ) );
em.persist( new EntityWithArrays( 2L, new String[]{ "abc", null, "def" } ) );
em.persist( new EntityWithArrays( 3L, null ) );
} );
}
@AfterEach
public void cleanup(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.createMutationQuery( "delete from EntityWithArrays" ).executeUpdate();
} );
}
@Test
public void testConcatAppend(SessionFactoryScope scope) {
scope.inSession( em -> {
//tag::hql-array-concat-example[]
List<Tuple> results = em.createQuery( "select e.id, array_concat(e.theArray, array('xyz')) from EntityWithArrays e order by e.id", Tuple.class )
.getResultList();
//end::hql-array-concat-example[]
assertEquals( 3, results.size() );
assertEquals( 1L, results.get( 0 ).get( 0 ) );
assertArrayEquals( new String[]{ "xyz" }, results.get( 0 ).get( 1, String[].class ) );
assertEquals( 2L, results.get( 1 ).get( 0 ) );
assertArrayEquals( new String[]{ "abc", null, "def", "xyz" }, results.get( 1 ).get( 1, String[].class ) );
assertEquals( 3L, results.get( 2 ).get( 0 ) );
assertNull( results.get( 2 ).get( 1, String[].class ) );
} );
}
@Test
public void testConcatPrepend(SessionFactoryScope scope) {
scope.inSession( em -> {
List<Tuple> results = em.createQuery( "select e.id, array_concat(array('xyz'), e.theArray) from EntityWithArrays e order by e.id", Tuple.class )
.getResultList();
assertEquals( 3, results.size() );
assertEquals( 1L, results.get( 0 ).get( 0 ) );
assertArrayEquals( new String[]{ "xyz" }, results.get( 0 ).get( 1, String[].class ) );
assertEquals( 2L, results.get( 1 ).get( 0 ) );
assertArrayEquals( new String[]{ "xyz", "abc", null, "def" }, results.get( 1 ).get( 1, String[].class ) );
assertEquals( 3L, results.get( 2 ).get( 0 ) );
assertNull( results.get( 2 ).get( 1, String[].class ) );
} );
}
}