HHH-17335 Add array_replace function

This commit is contained in:
Christian Beikov 2023-10-24 16:59:15 +02:00
parent 937116ed8a
commit 24fa18f954
19 changed files with 374 additions and 26 deletions

View File

@ -1131,6 +1131,7 @@ The following functions deal with SQL array types, which are not supported on ev
| `array_set()` | Creates array copy with given element at given index
| `array_remove()` | Creates array copy with given element removed
| `array_slice()` | Creates a sub-array of the based on lower and upper index
| `array_replace()` | Creates array copy replacing a given element with another
|===
===== `array()`
@ -1316,6 +1317,19 @@ include::{array-example-dir-hql}/ArrayRemoveIndexTest.java[tags=hql-array-remove
----
====
[[hql-array-replace-functions]]
===== `array_replace()`
Returns an array copy which has elements matching the second argument replaced by the third argument.
[[hql-array-replace-example]]
====
[source, JAVA, indent=0]
----
include::{array-example-dir-hql}/ArrayReplaceTest.java[tags=hql-array-replace-example]
----
====
[[hql-user-defined-functions]]
==== Native and user-defined functions

View File

@ -477,6 +477,7 @@ public class CockroachLegacyDialect extends Dialect {
functionFactory.arrayRemove();
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
functionContributions.getFunctionRegistry().register(
"trunc",

View File

@ -385,6 +385,7 @@ public class H2LegacyDialect extends Dialect {
functionFactory.arrayRemove_h2();
functionFactory.arrayRemoveIndex_h2();
functionFactory.arraySlice();
functionFactory.arrayReplace_h2();
}
else {
// Use group_concat until 2.x as listagg was buggy

View File

@ -263,6 +263,7 @@ public class HSQLLegacyDialect extends Dialect {
functionFactory.arrayRemove_hsql();
functionFactory.arrayRemoveIndex_unnest( false );
functionFactory.arraySlice_unnest();
functionFactory.arrayReplace_unnest();
}
@Override

View File

@ -299,6 +299,7 @@ public class OracleLegacyDialect extends Dialect {
functionFactory.arrayRemove_oracle();
functionFactory.arrayRemoveIndex_oracle();
functionFactory.arraySlice_oracle();
functionFactory.arrayReplace_oracle();
}
@Override

View File

@ -597,6 +597,7 @@ public class PostgreSQLLegacyDialect extends Dialect {
functionFactory.arrayRemove();
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
if ( getVersion().isSameOrAfter( 9, 4 ) ) {
functionFactory.makeDateTimeTimestamp();

View File

@ -464,6 +464,7 @@ public class CockroachDialect extends Dialect {
functionFactory.arrayRemove();
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
functionContributions.getFunctionRegistry().register(
"trunc",

View File

@ -324,6 +324,7 @@ public class H2Dialect extends Dialect {
functionFactory.arrayRemove_h2();
functionFactory.arrayRemoveIndex_h2();
functionFactory.arraySlice();
functionFactory.arrayReplace_h2();
}
@Override

View File

@ -203,6 +203,7 @@ public class HSQLDialect extends Dialect {
functionFactory.arrayRemove_hsql();
functionFactory.arrayRemoveIndex_unnest( false );
functionFactory.arraySlice_unnest();
functionFactory.arrayReplace_unnest();
}
@Override

View File

@ -439,6 +439,38 @@ public class OracleArrayJdbcType extends ArrayJdbcType {
false
)
);
database.addAuxiliaryDatabaseObject(
new NamedAuxiliaryDatabaseObject(
arrayTypeName + "_replace",
database.getDefaultNamespace(),
new String[]{
"create or replace function " + arrayTypeName + "_replace(arr in " + arrayTypeName +
", old in " + getRawTypeName( elementType ) + ", elem in " + getRawTypeName( elementType ) + ") return " + arrayTypeName + " deterministic is " +
"res " + arrayTypeName + ":=" + arrayTypeName + "(); begin " +
"if arr is null then return null; end if; " +
"if old is null then " +
"for i in 1 .. arr.count loop " +
"res.extend; " +
"res(res.last) := coalesce(arr(i),elem); " +
"end loop; " +
"else " +
"for i in 1 .. arr.count loop " +
"res.extend; " +
"if arr(i) = old then " +
"res(res.last) := elem; " +
"else " +
"res(res.last) := arr(i); " +
"end if; " +
"end loop; " +
"end if; " +
"return res; " +
"end;"
},
new String[] { "drop function " + arrayTypeName + "_replace" },
emptySet(),
false
)
);
}
protected String createOrReplaceConcatFunction(String arrayTypeName) {

View File

@ -328,6 +328,7 @@ public class OracleDialect extends Dialect {
functionFactory.arrayRemove_oracle();
functionFactory.arrayRemoveIndex_oracle();
functionFactory.arraySlice_oracle();
functionFactory.arrayReplace_oracle();
}
@Override

View File

@ -645,6 +645,7 @@ public class PostgreSQLDialect extends Dialect {
functionFactory.arrayRemove();
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
functionFactory.makeDateTimeTimestamp();
// Note that PostgreSQL doesn't support the OVER clause for ordered set-aggregate functions

View File

@ -23,6 +23,7 @@ import org.hibernate.dialect.function.array.ArrayContainsOperatorFunction;
import org.hibernate.dialect.function.array.ArrayContainsQuantifiedUnnestFunction;
import org.hibernate.dialect.function.array.ArrayGetUnnestFunction;
import org.hibernate.dialect.function.array.ArrayRemoveIndexUnnestFunction;
import org.hibernate.dialect.function.array.ArrayReplaceUnnestFunction;
import org.hibernate.dialect.function.array.ArraySetUnnestFunction;
import org.hibernate.dialect.function.array.ArraySliceUnnestFunction;
import org.hibernate.dialect.function.array.ArrayViaArgumentReturnTypeResolver;
@ -30,6 +31,7 @@ import org.hibernate.dialect.function.array.ElementViaArrayArgumentReturnTypeRes
import org.hibernate.dialect.function.array.H2ArrayContainsQuantifiedEmulation;
import org.hibernate.dialect.function.array.H2ArrayRemoveFunction;
import org.hibernate.dialect.function.array.H2ArrayRemoveIndexFunction;
import org.hibernate.dialect.function.array.H2ArrayReplaceFunction;
import org.hibernate.dialect.function.array.H2ArraySetFunction;
import org.hibernate.dialect.function.array.HSQLArrayPositionFunction;
import org.hibernate.dialect.function.array.HSQLArrayRemoveFunction;
@ -42,6 +44,7 @@ import org.hibernate.dialect.function.array.OracleArrayLengthFunction;
import org.hibernate.dialect.function.array.OracleArrayPositionFunction;
import org.hibernate.dialect.function.array.OracleArrayRemoveFunction;
import org.hibernate.dialect.function.array.OracleArrayRemoveIndexFunction;
import org.hibernate.dialect.function.array.OracleArrayReplaceFunction;
import org.hibernate.dialect.function.array.OracleArraySetFunction;
import org.hibernate.dialect.function.array.OracleArraySliceFunction;
import org.hibernate.dialect.function.array.PostgreSQLArrayConcatFunction;
@ -3133,4 +3136,42 @@ public class CommonFunctionFactory {
public void arraySlice_oracle() {
functionRegistry.register( "array_slice", new OracleArraySliceFunction() );
}
/**
* H2 array_replace() function
*/
public void arrayReplace_h2() {
functionRegistry.register( "array_replace", new H2ArrayReplaceFunction() );
}
/**
* HSQL array_replace() function
*/
public void arrayReplace_unnest() {
functionRegistry.register( "array_replace", new ArrayReplaceUnnestFunction() );
}
/**
* CockroachDB and PostgreSQL array_replace() function
*/
public void arrayReplace() {
functionRegistry.namedDescriptorBuilder( "array_replace" )
.setArgumentsValidator(
StandardArgumentsValidators.composite(
StandardArgumentsValidators.exactly( 3 ),
new ArrayAndElementArgumentValidator( 0, 1, 2 )
)
)
.setReturnTypeResolver( ArrayViaArgumentReturnTypeResolver.DEFAULT_INSTANCE )
.setArgumentTypeResolver( new ArrayAndElementArgumentTypeResolver( 0, 1, 2 ) )
.setArgumentListSignature( "(ARRAY array, OBJECT old, OBJECT new)" )
.register();
}
/**
* Oracle array_replace() function
*/
public void arrayReplace_oracle() {
functionRegistry.register( "array_replace", new OracleArrayReplaceFunction() );
}
}

View File

@ -6,6 +6,7 @@
*/
package org.hibernate.dialect.function.array;
import org.hibernate.internal.util.collections.ArrayHelper;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.query.ReturnableType;
@ -25,11 +26,11 @@ public class ArrayAndElementArgumentTypeResolver implements FunctionArgumentType
public static final FunctionArgumentTypeResolver DEFAULT_INSTANCE = new ArrayAndElementArgumentTypeResolver( 0, 1 );
private final int arrayIndex;
private final int elementIndex;
private final int[] elementIndexes;
public ArrayAndElementArgumentTypeResolver(int arrayIndex, int elementIndex) {
public ArrayAndElementArgumentTypeResolver(int arrayIndex, int... elementIndexes) {
this.arrayIndex = arrayIndex;
this.elementIndex = elementIndex;
this.elementIndexes = elementIndexes;
}
@Override
@ -38,16 +39,18 @@ public class ArrayAndElementArgumentTypeResolver implements FunctionArgumentType
int argumentIndex,
SqmToSqlAstConverter converter) {
if ( argumentIndex == arrayIndex ) {
final SqmTypedNode<?> argument = function.getArguments().get( elementIndex );
final DomainType<?> sqmType = argument.getExpressible().getSqmType();
if ( sqmType instanceof ReturnableType<?> ) {
return ArrayTypeHelper.resolveArrayType(
sqmType,
converter.getCreationContext().getSessionFactory().getTypeConfiguration()
);
for ( int elementIndex : elementIndexes ) {
final SqmTypedNode<?> argument = function.getArguments().get( elementIndex );
final DomainType<?> sqmType = argument.getExpressible().getSqmType();
if ( sqmType instanceof ReturnableType<?> ) {
return ArrayTypeHelper.resolveArrayType(
sqmType,
converter.getCreationContext().getSessionFactory().getTypeConfiguration()
);
}
}
}
else if ( argumentIndex == elementIndex ) {
else if ( ArrayHelper.contains( elementIndexes, argumentIndex ) ) {
final SqmTypedNode<?> argument = function.getArguments().get( arrayIndex );
final SqmExpressible<?> sqmType = argument.getNodeType();
if ( sqmType instanceof BasicPluralType<?, ?> ) {

View File

@ -22,11 +22,11 @@ public class ArrayAndElementArgumentValidator extends ArrayArgumentValidator {
public static final ArgumentsValidator DEFAULT_INSTANCE = new ArrayAndElementArgumentValidator( 0, 1 );
private final int elementIndex;
private final int[] elementIndexes;
public ArrayAndElementArgumentValidator(int arrayIndex, int elementIndex) {
public ArrayAndElementArgumentValidator(int arrayIndex, int... elementIndexes) {
super( arrayIndex );
this.elementIndex = elementIndex;
this.elementIndexes = elementIndexes;
}
@Override
@ -35,18 +35,20 @@ public class ArrayAndElementArgumentValidator extends ArrayArgumentValidator {
String functionName,
TypeConfiguration typeConfiguration) {
final BasicType<?> expectedElementType = getElementType( arguments, functionName, typeConfiguration );
final SqmTypedNode<?> elementArgument = arguments.get( elementIndex );
final SqmExpressible<?> elementType = elementArgument.getExpressible().getSqmType();
if ( expectedElementType != null && elementType != null && expectedElementType != elementType ) {
throw new FunctionArgumentException(
String.format(
"Parameter %d of function '%s()' has type %s, but argument is of type '%s'",
elementIndex,
functionName,
expectedElementType.getJavaTypeDescriptor().getTypeName(),
elementType.getTypeName()
)
);
for ( int elementIndex : elementIndexes ) {
final SqmTypedNode<?> elementArgument = arguments.get( elementIndex );
final SqmExpressible<?> elementType = elementArgument.getExpressible().getSqmType();
if ( expectedElementType != null && elementType != null && expectedElementType != elementType ) {
throw new FunctionArgumentException(
String.format(
"Parameter %d of function '%s()' has type %s, but argument is of type '%s'",
elementIndex,
functionName,
expectedElementType.getJavaTypeDescriptor().getTypeName(),
elementType.getTypeName()
)
);
}
}
}
}

View File

@ -0,0 +1,53 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
/**
* Implement the array replace function by using {@code unnest}.
*/
public class ArrayReplaceUnnestFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public ArrayReplaceUnnestFunction() {
super(
"array_replace",
StandardArgumentsValidators.composite(
StandardArgumentsValidators.exactly( 3 ),
new ArrayAndElementArgumentValidator( 0, 1, 2 )
),
ArrayViaArgumentReturnTypeResolver.DEFAULT_INSTANCE,
new ArrayAndElementArgumentTypeResolver( 0, 1, 2 )
);
}
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
final Expression arrayExpression = (Expression) sqlAstArguments.get( 0 );
final Expression oldExpression = (Expression) sqlAstArguments.get( 1 );
final Expression newExpression = (Expression) sqlAstArguments.get( 2 );
sqlAppender.append( "case when ");
arrayExpression.accept( walker );
sqlAppender.append( " is null then null else coalesce((select array_agg(case when t.val is not distinct from " );
oldExpression.accept( walker );
sqlAppender.append( " then " );
newExpression.accept( walker );
sqlAppender.append( " else t.val end) from unnest(" );
arrayExpression.accept( walker );
sqlAppender.append( ") t(val)),array[]) end" );
}
}

View File

@ -0,0 +1,49 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
/**
* H2 array_replace function.
*/
public class H2ArrayReplaceFunction extends ArrayReplaceUnnestFunction {
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
final Expression arrayExpression = (Expression) sqlAstArguments.get( 0 );
final Expression oldExpression = (Expression) sqlAstArguments.get( 1 );
final Expression newExpression = (Expression) sqlAstArguments.get( 2 );
sqlAppender.append( "case when ");
arrayExpression.accept( walker );
sqlAppender.append( " is null then null else coalesce((select array_agg(case when array_get(");
arrayExpression.accept( walker );
sqlAppender.append(",i.idx) is not distinct from ");
oldExpression.accept( walker );
sqlAppender.append( " then " );
newExpression.accept( walker );
sqlAppender.append( " else array_get(" );
arrayExpression.accept( walker );
sqlAppender.append(",i.idx) end) from system_range(1," );
sqlAppender.append( Integer.toString( getMaximumArraySize() ) );
sqlAppender.append( ") i(idx) where i.idx<=coalesce(cardinality(");
arrayExpression.accept( walker );
sqlAppender.append("),0)),array[]) end" );
}
protected int getMaximumArraySize() {
return 1000;
}
}

View File

@ -0,0 +1,39 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;
import java.util.List;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
/**
* Oracle array_replace function.
*/
public class OracleArrayReplaceFunction extends ArrayReplaceUnnestFunction {
@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
SqlAstTranslator<?> walker) {
final String arrayTypeName = ArrayTypeHelper.getArrayTypeName(
( (Expression) sqlAstArguments.get( 0 ) ).getExpressionType(),
walker
);
sqlAppender.append( arrayTypeName );
sqlAppender.append( "_replace(" );
sqlAstArguments.get( 0 ).accept( walker );
sqlAppender.append( ',' );
sqlAstArguments.get( 1 ).accept( walker );
sqlAppender.append( ',' );
sqlAstArguments.get( 2 ).accept( walker );
sqlAppender.append( ')' );
}
}

View File

@ -0,0 +1,105 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.orm.test.function.array;
import java.util.List;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.ServiceRegistry;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.hibernate.testing.orm.junit.Setting;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
/**
* @author Christian Beikov
*/
@DomainModel(annotatedClasses = EntityWithArrays.class)
@SessionFactory
@RequiresDialectFeature( feature = DialectFeatureChecks.SupportsStructuralArrays.class)
// Make sure this stuff runs on a dedicated connection pool,
// otherwise we might run into ORA-21700: object does not exist or is marked for delete
// because the JDBC connection or database session caches something that should have been invalidated
@ServiceRegistry(settings = @Setting(name = AvailableSettings.CONNECTION_PROVIDER, value = ""))
public class ArrayReplaceTest {
@BeforeEach
public void prepareData(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.persist( new EntityWithArrays( 1L, new String[]{} ) );
em.persist( new EntityWithArrays( 2L, new String[]{ "abc", null, "def" } ) );
em.persist( new EntityWithArrays( 3L, null ) );
} );
}
@AfterEach
public void cleanup(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.createMutationQuery( "delete from EntityWithArrays" ).executeUpdate();
} );
}
@Test
public void testReplace(SessionFactoryScope scope) {
scope.inSession( em -> {
//tag::hql-array-replace-example[]
List<Tuple> results = em.createQuery( "select e.id, array_replace(e.theArray, 'abc', 'xyz') from EntityWithArrays e order by e.id", Tuple.class )
.getResultList();
//end::hql-array-replace-example[]
assertEquals( 3, results.size() );
assertEquals( 1L, results.get( 0 ).get( 0 ) );
assertArrayEquals( new String[] {}, results.get( 0 ).get( 1, String[].class ) );
assertEquals( 2L, results.get( 1 ).get( 0 ) );
assertArrayEquals( new String[] { "xyz", null, "def" }, results.get( 1 ).get( 1, String[].class ) );
assertEquals( 3L, results.get( 2 ).get( 0 ) );
assertNull( results.get( 2 ).get( 1, String[].class ) );
} );
}
@Test
public void testReplaceNullElement(SessionFactoryScope scope) {
scope.inSession( em -> {
List<Tuple> results = em.createQuery( "select e.id, array_replace(e.theArray, null, 'aaa') from EntityWithArrays e order by e.id", Tuple.class )
.getResultList();
assertEquals( 3, results.size() );
assertEquals( 1L, results.get( 0 ).get( 0 ) );
assertArrayEquals( new String[] {}, results.get( 0 ).get( 1, String[].class ) );
assertEquals( 2L, results.get( 1 ).get( 0 ) );
assertArrayEquals( new String[] { "abc", "aaa", "def" }, results.get( 1 ).get( 1, String[].class ) );
assertEquals( 3L, results.get( 2 ).get( 0 ) );
assertNull( results.get( 2 ).get( 1, String[].class ) );
} );
}
@Test
public void testReplaceNonExisting(SessionFactoryScope scope) {
scope.inSession( em -> {
List<Tuple> results = em.createQuery( "select e.id, array_replace(e.theArray, 'xyz', 'aaa') from EntityWithArrays e order by e.id", Tuple.class )
.getResultList();
assertEquals( 3, results.size() );
assertEquals( 1L, results.get( 0 ).get( 0 ) );
assertArrayEquals( new String[] {}, results.get( 0 ).get( 1, String[].class ) );
assertEquals( 2L, results.get( 1 ).get( 0 ) );
assertArrayEquals( new String[] { "abc", null, "def" }, results.get( 1 ).get( 1, String[].class ) );
assertEquals( 3L, results.get( 2 ).get( 0 ) );
assertNull( results.get( 2 ).get( 1, String[].class ) );
} );
}
}