From 6808f451bafb4bcee2c7816cf43a89b175b9c8e1 Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Tue, 18 Jul 2023 12:09:15 +0200 Subject: [PATCH] HHH-16878 Add support for joins in SQL DML AST --- .../sqm/sql/BaseSqmToSqlAstConverter.java | 11 ++- .../sql/ast/spi/AbstractSqlAstTranslator.java | 54 ++++++++++++- .../tree/AbstractUpdateOrDeleteStatement.java | 76 +++++++++++++++++++ .../sql/ast/tree/delete/DeleteStatement.java | 36 +++++---- .../sql/ast/tree/update/UpdateStatement.java | 42 ++++++---- .../MutationQueriesAndNotFoundActionTest.java | 57 ++++++++++++-- 6 files changed, 236 insertions(+), 40 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/sql/ast/tree/AbstractUpdateOrDeleteStatement.java diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java index 6a5325ba3c..073ddf6508 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java @@ -340,6 +340,7 @@ import org.hibernate.sql.ast.tree.expression.TrimSpecification; import org.hibernate.sql.ast.tree.expression.UnaryOperation; import org.hibernate.sql.ast.tree.from.CorrelatedPluralTableGroup; import org.hibernate.sql.ast.tree.from.CorrelatedTableGroup; +import org.hibernate.sql.ast.tree.from.FromClause; import org.hibernate.sql.ast.tree.from.NamedTableReference; import org.hibernate.sql.ast.tree.from.PluralTableGroup; import org.hibernate.sql.ast.tree.from.QueryPartTableGroup; @@ -848,9 +849,12 @@ public abstract class BaseSqmToSqlAstConverter extends Base suppliedPredicate = visitWhereClause( whereClause.getPredicate() ); } + final FromClause fromClause = new FromClause(); + fromClause.addRoot( rootTableGroup ); return new UpdateStatement( cteContainer, (NamedTableReference) rootTableGroup.getPrimaryTableReference(), + fromClause, assignments, combinePredicates( suppliedPredicate, additionalRestrictions ), Collections.emptyList() @@ -870,7 +874,7 @@ public abstract class BaseSqmToSqlAstConverter extends Base } else { // otherwise... - throw new SemanticException( "Mutation query may only contain embeddable joins" ); + throw new SemanticException( "Mutation query may not contain joins in the SET clause" ); } } @@ -1119,9 +1123,12 @@ public abstract class BaseSqmToSqlAstConverter extends Base suppliedPredicate = visitWhereClause( whereClause.getPredicate() ); } + final FromClause fromClause = new FromClause(); + fromClause.addRoot( rootTableGroup ); return new DeleteStatement( cteContainer, (NamedTableReference) rootTableGroup.getPrimaryTableReference(), + fromClause, combinePredicates( suppliedPredicate, additionalRestrictions ), Collections.emptyList() ); @@ -3567,7 +3574,7 @@ public abstract class BaseSqmToSqlAstConverter extends Base private X prepareReusablePath(SqmPath sqmPath, FromClauseIndex fromClauseIndex, Supplier supplier) { final Consumer implicitJoinChecker; - if ( getCurrentProcessingState() instanceof SqlAstQueryPartProcessingState ) { + if ( getCurrentClauseStack().getCurrent() != Clause.SET_EXPRESSION ) { implicitJoinChecker = tg -> {}; } else { diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java index ea06c56ef3..e4ed2b2a1c 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java @@ -62,6 +62,7 @@ import org.hibernate.persister.internal.SqlFragmentPredicate; import org.hibernate.query.IllegalQueryOperationException; import org.hibernate.query.ReturnableType; import org.hibernate.query.SortDirection; +import org.hibernate.query.results.TableGroupImpl; import org.hibernate.query.spi.Limit; import org.hibernate.query.spi.QueryOptions; import org.hibernate.query.sqm.BinaryArithmeticOperator; @@ -87,6 +88,7 @@ import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.SqlTreeCreationException; import org.hibernate.sql.ast.internal.ParameterMarkerStrategyStandard; +import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement; import org.hibernate.sql.ast.tree.MutationStatement; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.Statement; @@ -1052,7 +1054,12 @@ public abstract class AbstractSqlAstTranslator implemen clauseStack.pop(); } - visitWhereClause( statement.getRestriction() ); + if ( statement.getFromClause().hasJoins() ) { + visitWhereClause( determineWhereClauseRestrictionWithJoinEmulation( statement ) ); + } + else { + visitWhereClause( statement.getRestriction() ); + } visitReturningColumns( statement.getReturningColumns() ); } @@ -1069,10 +1076,53 @@ public abstract class AbstractSqlAstTranslator implemen } renderSetClause( statement, clauseStack ); - visitWhereClause( statement.getRestriction() ); + if ( statement.getFromClause().hasJoins() ) { + visitWhereClause( determineWhereClauseRestrictionWithJoinEmulation( statement ) ); + } + else { + visitWhereClause( statement.getRestriction() ); + } visitReturningColumns( statement.getReturningColumns() ); } + protected Predicate determineWhereClauseRestrictionWithJoinEmulation(AbstractUpdateOrDeleteStatement statement) { + final QuerySpec querySpec = new QuerySpec( false ); + querySpec.getSelectClause().addSqlSelection( + new SqlSelectionImpl( new QueryLiteral<>( 1, getIntegerType() ) ) + ); + for ( TableGroup root : statement.getFromClause().getRoots() ) { + if ( root.getPrimaryTableReference() == statement.getTargetTable() ) { + for ( TableReferenceJoin tableReferenceJoin : root.getTableReferenceJoins() ) { + assert tableReferenceJoin.getJoinType() == SqlAstJoinType.INNER; + querySpec.getFromClause().addRoot( + new TableGroupImpl( + root.getNavigablePath(), + null, + tableReferenceJoin.getJoinedTableReference(), + root.getModelPart() + ) + ); + querySpec.applyPredicate( tableReferenceJoin.getPredicate() ); + } + for ( TableGroupJoin tableGroupJoin : root.getTableGroupJoins() ) { + assert tableGroupJoin.getJoinType() == SqlAstJoinType.INNER; + querySpec.getFromClause().addRoot( tableGroupJoin.getJoinedGroup() ); + querySpec.applyPredicate( tableGroupJoin.getPredicate() ); + } + for ( TableGroupJoin tableGroupJoin : root.getNestedTableGroupJoins() ) { + assert tableGroupJoin.getJoinType() == SqlAstJoinType.INNER; + querySpec.getFromClause().addRoot( tableGroupJoin.getJoinedGroup() ); + querySpec.applyPredicate( tableGroupJoin.getPredicate() ); + } + } + else { + querySpec.getFromClause().addRoot( root ); + } + } + querySpec.applyPredicate( statement.getRestriction() ); + return new ExistsPredicate( querySpec, false, getBooleanType() ); + } + protected void renderSetClause(UpdateStatement statement, Stack clauseStack) { appendSql( " set" ); char separator = ' '; diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/AbstractUpdateOrDeleteStatement.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/AbstractUpdateOrDeleteStatement.java new file mode 100644 index 0000000000..5be59d070c --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/AbstractUpdateOrDeleteStatement.java @@ -0,0 +1,76 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.sql.ast.tree; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.hibernate.sql.ast.tree.cte.CteContainer; +import org.hibernate.sql.ast.tree.cte.CteStatement; +import org.hibernate.sql.ast.tree.expression.ColumnReference; +import org.hibernate.sql.ast.tree.from.FromClause; +import org.hibernate.sql.ast.tree.from.NamedTableReference; +import org.hibernate.sql.ast.tree.predicate.Predicate; + +public abstract class AbstractUpdateOrDeleteStatement extends AbstractMutationStatement { + private final FromClause fromClause; + private final Predicate restriction; + + public AbstractUpdateOrDeleteStatement( + NamedTableReference targetTable, + FromClause fromClause, + Predicate restriction) { + super( targetTable ); + this.fromClause = fromClause; + this.restriction = restriction; + } + + public AbstractUpdateOrDeleteStatement( + NamedTableReference targetTable, + FromClause fromClause, + Predicate restriction, + List returningColumns) { + super( new LinkedHashMap<>(), targetTable, returningColumns ); + this.fromClause = fromClause; + this.restriction = restriction; + } + + public AbstractUpdateOrDeleteStatement( + CteContainer cteContainer, + NamedTableReference targetTable, + FromClause fromClause, + Predicate restriction, + List returningColumns) { + this( + cteContainer.getCteStatements(), + targetTable, + fromClause, + restriction, + returningColumns + ); + } + + public AbstractUpdateOrDeleteStatement( + Map cteStatements, + NamedTableReference targetTable, + FromClause fromClause, + Predicate restriction, + List returningColumns) { + super( cteStatements, targetTable, returningColumns ); + this.fromClause = fromClause; + this.restriction = restriction; + } + + public FromClause getFromClause() { + return fromClause; + } + + public Predicate getRestriction() { + return restriction; + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/delete/DeleteStatement.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/delete/DeleteStatement.java index 3cddb5ae9c..040479fe41 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/delete/DeleteStatement.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/delete/DeleteStatement.java @@ -6,16 +6,16 @@ */ package org.hibernate.sql.ast.tree.delete; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import org.hibernate.sql.ast.SqlAstWalker; import org.hibernate.sql.ast.spi.SqlAstHelper; -import org.hibernate.sql.ast.tree.AbstractMutationStatement; +import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement; import org.hibernate.sql.ast.tree.cte.CteContainer; import org.hibernate.sql.ast.tree.cte.CteStatement; import org.hibernate.sql.ast.tree.expression.ColumnReference; +import org.hibernate.sql.ast.tree.from.FromClause; import org.hibernate.sql.ast.tree.from.NamedTableReference; import org.hibernate.sql.ast.tree.predicate.Junction; import org.hibernate.sql.ast.tree.predicate.Predicate; @@ -23,32 +23,43 @@ import org.hibernate.sql.ast.tree.predicate.Predicate; /** * @author Steve Ebersole */ -public class DeleteStatement extends AbstractMutationStatement { +public class DeleteStatement extends AbstractUpdateOrDeleteStatement { public static final String DEFAULT_ALIAS = "to_delete_"; - private final Predicate restriction; public DeleteStatement(NamedTableReference targetTable, Predicate restriction) { - super( targetTable ); - this.restriction = restriction; + this( targetTable, new FromClause(), restriction ); } public DeleteStatement( NamedTableReference targetTable, Predicate restriction, List returningColumns) { - super( new LinkedHashMap<>(), targetTable, returningColumns ); - this.restriction = restriction; + this( targetTable, new FromClause(), restriction, returningColumns ); + } + + public DeleteStatement(NamedTableReference targetTable, FromClause fromClause, Predicate restriction) { + super( targetTable, fromClause, restriction ); + } + + public DeleteStatement( + NamedTableReference targetTable, + FromClause fromClause, + Predicate restriction, + List returningColumns) { + super( targetTable, fromClause, restriction, returningColumns ); } public DeleteStatement( CteContainer cteContainer, NamedTableReference targetTable, + FromClause fromClause, Predicate restriction, List returningColumns) { this( cteContainer.getCteStatements(), targetTable, + fromClause, restriction, returningColumns ); @@ -57,14 +68,10 @@ public class DeleteStatement extends AbstractMutationStatement { public DeleteStatement( Map cteStatements, NamedTableReference targetTable, + FromClause fromClause, Predicate restriction, List returningColumns) { - super( cteStatements, targetTable, returningColumns ); - this.restriction = restriction; - } - - public Predicate getRestriction() { - return restriction; + super( cteStatements, targetTable, fromClause, restriction, returningColumns ); } public static class DeleteStatementBuilder { @@ -88,6 +95,7 @@ public class DeleteStatement extends AbstractMutationStatement { public DeleteStatement createDeleteStatement() { return new DeleteStatement( targetTable, + new FromClause(), restriction != null ? restriction : new Junction( Junction.Nature.CONJUNCTION ) ); } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/update/UpdateStatement.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/update/UpdateStatement.java index 90643f2e62..02f594a663 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/update/UpdateStatement.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/update/UpdateStatement.java @@ -14,26 +14,25 @@ import java.util.Map; import org.hibernate.sql.ast.SqlAstWalker; import org.hibernate.sql.ast.spi.SqlAstTreeHelper; import org.hibernate.sql.ast.tree.AbstractMutationStatement; +import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement; import org.hibernate.sql.ast.tree.cte.CteContainer; import org.hibernate.sql.ast.tree.cte.CteStatement; import org.hibernate.sql.ast.tree.expression.ColumnReference; +import org.hibernate.sql.ast.tree.from.FromClause; import org.hibernate.sql.ast.tree.from.NamedTableReference; import org.hibernate.sql.ast.tree.predicate.Predicate; /** * @author Steve Ebersole */ -public class UpdateStatement extends AbstractMutationStatement { +public class UpdateStatement extends AbstractUpdateOrDeleteStatement { private final List assignments; - private final Predicate restriction; public UpdateStatement( NamedTableReference targetTable, List assignments, Predicate restriction) { - super( targetTable ); - this.assignments = assignments; - this.restriction = restriction; + this( targetTable, new FromClause(), assignments, restriction ); } public UpdateStatement( @@ -41,20 +40,39 @@ public class UpdateStatement extends AbstractMutationStatement { List assignments, Predicate restriction, List returningColumns) { - super( new LinkedHashMap<>(), targetTable, returningColumns ); + this( targetTable, new FromClause(), assignments, restriction, returningColumns ); + } + + public UpdateStatement( + NamedTableReference targetTable, + FromClause fromClause, + List assignments, + Predicate restriction) { + super( targetTable, fromClause, restriction ); + this.assignments = assignments; + } + + public UpdateStatement( + NamedTableReference targetTable, + FromClause fromClause, + List assignments, + Predicate restriction, + List returningColumns) { + super( targetTable, fromClause, restriction, returningColumns ); this.assignments = assignments; - this.restriction = restriction; } public UpdateStatement( CteContainer cteContainer, NamedTableReference targetTable, + FromClause fromClause, List assignments, Predicate restriction, List returningColumns) { this( cteContainer.getCteStatements(), targetTable, + fromClause, assignments, restriction, returningColumns @@ -64,22 +82,18 @@ public class UpdateStatement extends AbstractMutationStatement { public UpdateStatement( Map cteStatements, NamedTableReference targetTable, + FromClause fromClause, List assignments, Predicate restriction, List returningColumns) { - super( cteStatements, targetTable, returningColumns ); + super( cteStatements, targetTable, fromClause, restriction, returningColumns ); this.assignments = assignments; - this.restriction = restriction; } public List getAssignments() { return assignments; } - public Predicate getRestriction() { - return restriction; - } - public static class UpdateStatementBuilder { private final NamedTableReference targetTableRef; private List assignments; @@ -120,7 +134,7 @@ public class UpdateStatement extends AbstractMutationStatement { return null; } - return new UpdateStatement( targetTableRef, assignments, restriction ); + return new UpdateStatement( targetTableRef, new FromClause(), assignments, restriction ); } } diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/notfound/MutationQueriesAndNotFoundActionTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/notfound/MutationQueriesAndNotFoundActionTest.java index c299537a15..5e52e31554 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/notfound/MutationQueriesAndNotFoundActionTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/notfound/MutationQueriesAndNotFoundActionTest.java @@ -28,18 +28,18 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @SessionFactory @JiraKey("HHH-16878") public class MutationQueriesAndNotFoundActionTest { - private static final User user1 = new User( 1l, "test 1" ); @BeforeEach public void setUp(SessionFactoryScope scope) { scope.inTransaction( session -> { - User user2 = new User( 2l, "test 2" ); + final User user1 = new User( 1L, "test 1" ); + final User user2 = new User( 2L, "test 2" ); session.persist( user1 ); session.persist( user2 ); - session.persist( new Comment( 1l, "example 1", user1 ) ); - session.persist( new Comment( 2l, "example 2", user2 ) ); - session.persist( new Comment( 3l, "example 3", user1 ) ); + session.persist( new Comment( 1L, "example 1", user1 ) ); + session.persist( new Comment( 2L, "example 2", user2 ) ); + session.persist( new Comment( 3L, "example 3", user1 ) ); } ); } @@ -59,9 +59,9 @@ public class MutationQueriesAndNotFoundActionTest { scope.inTransaction( session -> { int affectedComments = session.createMutationQuery( - "UPDATE Comment c SET c.text = :text WHERE c.user = :user" ) + "update Comment c set c.text = :text where c.user = :user" ) .setParameter( "text", "updated" ) - .setParameter( "user", user1 ) + .setParameter( "user", session.getReference( User.class, 1L ) ) .executeUpdate(); assertThat( affectedComments ).isEqualTo( 2 ); @@ -69,12 +69,53 @@ public class MutationQueriesAndNotFoundActionTest { ); } + @Test + public void testUpdateWithImplicitJoin(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + int affectedComments = session.createMutationQuery( "update Comment c set c.text = :text where c.user.name = :userName" ) + .setParameter( "text", "updated" ) + .setParameter( "userName", "test 1" ) + .executeUpdate(); + + assertThat( affectedComments ).isEqualTo( 2 ); + } + ); + } + + @Test + public void testUpdateSet(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + int affectedComments = session.createMutationQuery( + "update Comment c set c.user = :user" ) + .setParameter( "user", session.getReference( User.class, 2L ) ) + .executeUpdate(); + + assertThat( affectedComments ).isEqualTo( 3 ); + } + ); + } + @Test public void testDelete(SessionFactoryScope scope) { scope.inTransaction( session -> { int affectedComments = session.createMutationQuery( "delete from Comment c where c.user = :user" ) - .setParameter( "user", user1 ) + .setParameter( "user", session.getReference( User.class, 1L ) ) + .executeUpdate(); + + assertThat( affectedComments ).isEqualTo( 2 ); + } + ); + } + + @Test + public void testDeleteWithImplicitJoin(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + int affectedComments = session.createMutationQuery( "delete from Comment c where c.user.name = :userName" ) + .setParameter( "userName", "test 1" ) .executeUpdate(); assertThat( affectedComments ).isEqualTo( 2 );