HHH-16878 Add support for joins in SQL DML AST

This commit is contained in:
Christian Beikov 2023-07-18 12:09:15 +02:00 committed by Andrea Boriero
parent 226d0f956f
commit 544e9a3fb3
6 changed files with 237 additions and 41 deletions

View File

@ -338,6 +338,7 @@ import org.hibernate.sql.ast.tree.expression.TrimSpecification;
import org.hibernate.sql.ast.tree.expression.UnaryOperation; import org.hibernate.sql.ast.tree.expression.UnaryOperation;
import org.hibernate.sql.ast.tree.from.CorrelatedPluralTableGroup; import org.hibernate.sql.ast.tree.from.CorrelatedPluralTableGroup;
import org.hibernate.sql.ast.tree.from.CorrelatedTableGroup; import org.hibernate.sql.ast.tree.from.CorrelatedTableGroup;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference; import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.from.PluralTableGroup; import org.hibernate.sql.ast.tree.from.PluralTableGroup;
import org.hibernate.sql.ast.tree.from.QueryPartTableGroup; import org.hibernate.sql.ast.tree.from.QueryPartTableGroup;
@ -846,9 +847,12 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
suppliedPredicate = visitWhereClause( whereClause.getPredicate() ); suppliedPredicate = visitWhereClause( whereClause.getPredicate() );
} }
final FromClause fromClause = new FromClause();
fromClause.addRoot( rootTableGroup );
return new UpdateStatement( return new UpdateStatement(
cteContainer, cteContainer,
(NamedTableReference) rootTableGroup.getPrimaryTableReference(), (NamedTableReference) rootTableGroup.getPrimaryTableReference(),
fromClause,
assignments, assignments,
combinePredicates( suppliedPredicate, additionalRestrictions ), combinePredicates( suppliedPredicate, additionalRestrictions ),
Collections.emptyList() Collections.emptyList()
@ -868,7 +872,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
} }
else { else {
// otherwise... // otherwise...
throw new QueryException( "Manipulation query may only contain embeddable joins" ); throw new SemanticException( "Mutation query may not contain joins in the SET clause" );
} }
} }
@ -1117,9 +1121,12 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
suppliedPredicate = visitWhereClause( whereClause.getPredicate() ); suppliedPredicate = visitWhereClause( whereClause.getPredicate() );
} }
final FromClause fromClause = new FromClause();
fromClause.addRoot( rootTableGroup );
return new DeleteStatement( return new DeleteStatement(
cteContainer, cteContainer,
(NamedTableReference) rootTableGroup.getPrimaryTableReference(), (NamedTableReference) rootTableGroup.getPrimaryTableReference(),
fromClause,
combinePredicates( suppliedPredicate, additionalRestrictions ), combinePredicates( suppliedPredicate, additionalRestrictions ),
Collections.emptyList() Collections.emptyList()
); );
@ -3542,7 +3549,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
private <X> X prepareReusablePath(SqmPath<?> sqmPath, FromClauseIndex fromClauseIndex, Supplier<X> supplier) { private <X> X prepareReusablePath(SqmPath<?> sqmPath, FromClauseIndex fromClauseIndex, Supplier<X> supplier) {
final Consumer<TableGroup> implicitJoinChecker; final Consumer<TableGroup> implicitJoinChecker;
if ( getCurrentProcessingState() instanceof SqlAstQueryPartProcessingState ) { if ( getCurrentClauseStack().getCurrent() != Clause.SET_EXPRESSION ) {
implicitJoinChecker = tg -> {}; implicitJoinChecker = tg -> {};
} }
else { else {

View File

@ -61,6 +61,7 @@ import org.hibernate.persister.entity.Loadable;
import org.hibernate.persister.internal.SqlFragmentPredicate; import org.hibernate.persister.internal.SqlFragmentPredicate;
import org.hibernate.query.IllegalQueryOperationException; import org.hibernate.query.IllegalQueryOperationException;
import org.hibernate.query.ReturnableType; import org.hibernate.query.ReturnableType;
import org.hibernate.query.results.TableGroupImpl;
import org.hibernate.query.spi.Limit; import org.hibernate.query.spi.Limit;
import org.hibernate.query.spi.QueryOptions; import org.hibernate.query.spi.QueryOptions;
import org.hibernate.query.sqm.BinaryArithmeticOperator; import org.hibernate.query.sqm.BinaryArithmeticOperator;
@ -87,6 +88,7 @@ import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.SqlTreeCreationException; import org.hibernate.sql.ast.SqlTreeCreationException;
import org.hibernate.sql.ast.internal.ParameterMarkerStrategyStandard; import org.hibernate.sql.ast.internal.ParameterMarkerStrategyStandard;
import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement;
import org.hibernate.sql.ast.tree.MutationStatement; import org.hibernate.sql.ast.tree.MutationStatement;
import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.Statement; import org.hibernate.sql.ast.tree.Statement;
@ -385,7 +387,7 @@ public abstract class AbstractSqlAstTranslator<T extends JdbcOperation> implemen
/** /**
* A lazy session implementation that is needed for rendering literals. * A lazy session implementation that is needed for rendering literals.
* Usually, only the {@link WrapperOptions} interface is needed, * Usually, only the {@link WrapperOptions} interface is needed,
* but for creating LOBs, it might be to have a full blown session. * but for creating LOBs, it might be to have a full-blown session.
*/ */
private static class LazySessionWrapperOptions extends AbstractDelegatingWrapperOptions { private static class LazySessionWrapperOptions extends AbstractDelegatingWrapperOptions {
@ -1047,7 +1049,12 @@ public abstract class AbstractSqlAstTranslator<T extends JdbcOperation> implemen
clauseStack.pop(); clauseStack.pop();
} }
if ( statement.getFromClause().hasJoins() ) {
visitWhereClause( determineWhereClauseRestrictionWithJoinEmulation( statement ) );
}
else {
visitWhereClause( statement.getRestriction() ); visitWhereClause( statement.getRestriction() );
}
visitReturningColumns( statement.getReturningColumns() ); visitReturningColumns( statement.getReturningColumns() );
} }
@ -1064,10 +1071,53 @@ public abstract class AbstractSqlAstTranslator<T extends JdbcOperation> implemen
} }
renderSetClause( statement, clauseStack ); renderSetClause( statement, clauseStack );
if ( statement.getFromClause().hasJoins() ) {
visitWhereClause( determineWhereClauseRestrictionWithJoinEmulation( statement ) );
}
else {
visitWhereClause( statement.getRestriction() ); visitWhereClause( statement.getRestriction() );
}
visitReturningColumns( statement.getReturningColumns() ); visitReturningColumns( statement.getReturningColumns() );
} }
protected Predicate determineWhereClauseRestrictionWithJoinEmulation(AbstractUpdateOrDeleteStatement statement) {
final QuerySpec querySpec = new QuerySpec( false );
querySpec.getSelectClause().addSqlSelection(
new SqlSelectionImpl( new QueryLiteral<>( 1, getIntegerType() ) )
);
for ( TableGroup root : statement.getFromClause().getRoots() ) {
if ( root.getPrimaryTableReference() == statement.getTargetTable() ) {
for ( TableReferenceJoin tableReferenceJoin : root.getTableReferenceJoins() ) {
assert tableReferenceJoin.getJoinType() == SqlAstJoinType.INNER;
querySpec.getFromClause().addRoot(
new TableGroupImpl(
root.getNavigablePath(),
null,
tableReferenceJoin.getJoinedTableReference(),
root.getModelPart()
)
);
querySpec.applyPredicate( tableReferenceJoin.getPredicate() );
}
for ( TableGroupJoin tableGroupJoin : root.getTableGroupJoins() ) {
assert tableGroupJoin.getJoinType() == SqlAstJoinType.INNER;
querySpec.getFromClause().addRoot( tableGroupJoin.getJoinedGroup() );
querySpec.applyPredicate( tableGroupJoin.getPredicate() );
}
for ( TableGroupJoin tableGroupJoin : root.getNestedTableGroupJoins() ) {
assert tableGroupJoin.getJoinType() == SqlAstJoinType.INNER;
querySpec.getFromClause().addRoot( tableGroupJoin.getJoinedGroup() );
querySpec.applyPredicate( tableGroupJoin.getPredicate() );
}
}
else {
querySpec.getFromClause().addRoot( root );
}
}
querySpec.applyPredicate( statement.getRestriction() );
return new ExistsPredicate( querySpec, false, getBooleanType() );
}
protected void renderSetClause(UpdateStatement statement, Stack<Clause> clauseStack) { protected void renderSetClause(UpdateStatement statement, Stack<Clause> clauseStack) {
appendSql( " set" ); appendSql( " set" );
char separator = ' '; char separator = ' ';

View File

@ -0,0 +1,76 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.sql.ast.tree;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.hibernate.sql.ast.tree.cte.CteContainer;
import org.hibernate.sql.ast.tree.cte.CteStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.predicate.Predicate;
public abstract class AbstractUpdateOrDeleteStatement extends AbstractMutationStatement {
private final FromClause fromClause;
private final Predicate restriction;
public AbstractUpdateOrDeleteStatement(
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction) {
super( targetTable );
this.fromClause = fromClause;
this.restriction = restriction;
}
public AbstractUpdateOrDeleteStatement(
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( new LinkedHashMap<>(), targetTable, returningColumns );
this.fromClause = fromClause;
this.restriction = restriction;
}
public AbstractUpdateOrDeleteStatement(
CteContainer cteContainer,
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
this(
cteContainer.getCteStatements(),
targetTable,
fromClause,
restriction,
returningColumns
);
}
public AbstractUpdateOrDeleteStatement(
Map<String, CteStatement> cteStatements,
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( cteStatements, targetTable, returningColumns );
this.fromClause = fromClause;
this.restriction = restriction;
}
public FromClause getFromClause() {
return fromClause;
}
public Predicate getRestriction() {
return restriction;
}
}

View File

@ -6,16 +6,16 @@
*/ */
package org.hibernate.sql.ast.tree.delete; package org.hibernate.sql.ast.tree.delete;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.hibernate.sql.ast.SqlAstWalker; import org.hibernate.sql.ast.SqlAstWalker;
import org.hibernate.sql.ast.spi.SqlAstHelper; import org.hibernate.sql.ast.spi.SqlAstHelper;
import org.hibernate.sql.ast.tree.AbstractMutationStatement; import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement;
import org.hibernate.sql.ast.tree.cte.CteContainer; import org.hibernate.sql.ast.tree.cte.CteContainer;
import org.hibernate.sql.ast.tree.cte.CteStatement; import org.hibernate.sql.ast.tree.cte.CteStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference; import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference; import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.predicate.Junction; import org.hibernate.sql.ast.tree.predicate.Junction;
import org.hibernate.sql.ast.tree.predicate.Predicate; import org.hibernate.sql.ast.tree.predicate.Predicate;
@ -23,32 +23,43 @@ import org.hibernate.sql.ast.tree.predicate.Predicate;
/** /**
* @author Steve Ebersole * @author Steve Ebersole
*/ */
public class DeleteStatement extends AbstractMutationStatement { public class DeleteStatement extends AbstractUpdateOrDeleteStatement {
public static final String DEFAULT_ALIAS = "to_delete_"; public static final String DEFAULT_ALIAS = "to_delete_";
private final Predicate restriction;
public DeleteStatement(NamedTableReference targetTable, Predicate restriction) { public DeleteStatement(NamedTableReference targetTable, Predicate restriction) {
super( targetTable ); this( targetTable, new FromClause(), restriction );
this.restriction = restriction;
} }
public DeleteStatement( public DeleteStatement(
NamedTableReference targetTable, NamedTableReference targetTable,
Predicate restriction, Predicate restriction,
List<ColumnReference> returningColumns) { List<ColumnReference> returningColumns) {
super( new LinkedHashMap<>(), targetTable, returningColumns ); this( targetTable, new FromClause(), restriction, returningColumns );
this.restriction = restriction; }
public DeleteStatement(NamedTableReference targetTable, FromClause fromClause, Predicate restriction) {
super( targetTable, fromClause, restriction );
}
public DeleteStatement(
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( targetTable, fromClause, restriction, returningColumns );
} }
public DeleteStatement( public DeleteStatement(
CteContainer cteContainer, CteContainer cteContainer,
NamedTableReference targetTable, NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction, Predicate restriction,
List<ColumnReference> returningColumns) { List<ColumnReference> returningColumns) {
this( this(
cteContainer.getCteStatements(), cteContainer.getCteStatements(),
targetTable, targetTable,
fromClause,
restriction, restriction,
returningColumns returningColumns
); );
@ -57,14 +68,10 @@ public class DeleteStatement extends AbstractMutationStatement {
public DeleteStatement( public DeleteStatement(
Map<String, CteStatement> cteStatements, Map<String, CteStatement> cteStatements,
NamedTableReference targetTable, NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction, Predicate restriction,
List<ColumnReference> returningColumns) { List<ColumnReference> returningColumns) {
super( cteStatements, targetTable, returningColumns ); super( cteStatements, targetTable, fromClause, restriction, returningColumns );
this.restriction = restriction;
}
public Predicate getRestriction() {
return restriction;
} }
public static class DeleteStatementBuilder { public static class DeleteStatementBuilder {
@ -88,6 +95,7 @@ public class DeleteStatement extends AbstractMutationStatement {
public DeleteStatement createDeleteStatement() { public DeleteStatement createDeleteStatement() {
return new DeleteStatement( return new DeleteStatement(
targetTable, targetTable,
new FromClause(),
restriction != null ? restriction : new Junction( Junction.Nature.CONJUNCTION ) restriction != null ? restriction : new Junction( Junction.Nature.CONJUNCTION )
); );
} }

View File

@ -14,26 +14,25 @@ import java.util.Map;
import org.hibernate.sql.ast.SqlAstWalker; import org.hibernate.sql.ast.SqlAstWalker;
import org.hibernate.sql.ast.spi.SqlAstTreeHelper; import org.hibernate.sql.ast.spi.SqlAstTreeHelper;
import org.hibernate.sql.ast.tree.AbstractMutationStatement; import org.hibernate.sql.ast.tree.AbstractMutationStatement;
import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement;
import org.hibernate.sql.ast.tree.cte.CteContainer; import org.hibernate.sql.ast.tree.cte.CteContainer;
import org.hibernate.sql.ast.tree.cte.CteStatement; import org.hibernate.sql.ast.tree.cte.CteStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference; import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference; import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.predicate.Predicate; import org.hibernate.sql.ast.tree.predicate.Predicate;
/** /**
* @author Steve Ebersole * @author Steve Ebersole
*/ */
public class UpdateStatement extends AbstractMutationStatement { public class UpdateStatement extends AbstractUpdateOrDeleteStatement {
private final List<Assignment> assignments; private final List<Assignment> assignments;
private final Predicate restriction;
public UpdateStatement( public UpdateStatement(
NamedTableReference targetTable, NamedTableReference targetTable,
List<Assignment> assignments, List<Assignment> assignments,
Predicate restriction) { Predicate restriction) {
super( targetTable ); this( targetTable, new FromClause(), assignments, restriction );
this.assignments = assignments;
this.restriction = restriction;
} }
public UpdateStatement( public UpdateStatement(
@ -41,20 +40,39 @@ public class UpdateStatement extends AbstractMutationStatement {
List<Assignment> assignments, List<Assignment> assignments,
Predicate restriction, Predicate restriction,
List<ColumnReference> returningColumns) { List<ColumnReference> returningColumns) {
super( new LinkedHashMap<>(), targetTable, returningColumns ); this( targetTable, new FromClause(), assignments, restriction, returningColumns );
}
public UpdateStatement(
NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments,
Predicate restriction) {
super( targetTable, fromClause, restriction );
this.assignments = assignments;
}
public UpdateStatement(
NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( targetTable, fromClause, restriction, returningColumns );
this.assignments = assignments; this.assignments = assignments;
this.restriction = restriction;
} }
public UpdateStatement( public UpdateStatement(
CteContainer cteContainer, CteContainer cteContainer,
NamedTableReference targetTable, NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments, List<Assignment> assignments,
Predicate restriction, Predicate restriction,
List<ColumnReference> returningColumns) { List<ColumnReference> returningColumns) {
this( this(
cteContainer.getCteStatements(), cteContainer.getCteStatements(),
targetTable, targetTable,
fromClause,
assignments, assignments,
restriction, restriction,
returningColumns returningColumns
@ -64,22 +82,18 @@ public class UpdateStatement extends AbstractMutationStatement {
public UpdateStatement( public UpdateStatement(
Map<String, CteStatement> cteStatements, Map<String, CteStatement> cteStatements,
NamedTableReference targetTable, NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments, List<Assignment> assignments,
Predicate restriction, Predicate restriction,
List<ColumnReference> returningColumns) { List<ColumnReference> returningColumns) {
super( cteStatements, targetTable, returningColumns ); super( cteStatements, targetTable, fromClause, restriction, returningColumns );
this.assignments = assignments; this.assignments = assignments;
this.restriction = restriction;
} }
public List<Assignment> getAssignments() { public List<Assignment> getAssignments() {
return assignments; return assignments;
} }
public Predicate getRestriction() {
return restriction;
}
public static class UpdateStatementBuilder { public static class UpdateStatementBuilder {
private final NamedTableReference targetTableRef; private final NamedTableReference targetTableRef;
private List<Assignment> assignments; private List<Assignment> assignments;
@ -120,7 +134,7 @@ public class UpdateStatement extends AbstractMutationStatement {
return null; return null;
} }
return new UpdateStatement( targetTableRef, assignments, restriction ); return new UpdateStatement( targetTableRef, new FromClause(), assignments, restriction );
} }
} }

View File

@ -28,18 +28,18 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@SessionFactory @SessionFactory
@JiraKey("HHH-16878") @JiraKey("HHH-16878")
public class MutationQueriesAndNotFoundActionTest { public class MutationQueriesAndNotFoundActionTest {
private static final User user1 = new User( 1l, "test 1" );
@BeforeEach @BeforeEach
public void setUp(SessionFactoryScope scope) { public void setUp(SessionFactoryScope scope) {
scope.inTransaction( scope.inTransaction(
session -> { session -> {
User user2 = new User( 2l, "test 2" ); final User user1 = new User( 1L, "test 1" );
final User user2 = new User( 2L, "test 2" );
session.persist( user1 ); session.persist( user1 );
session.persist( user2 ); session.persist( user2 );
session.persist( new Comment( 1l, "example 1", user1 ) ); session.persist( new Comment( 1L, "example 1", user1 ) );
session.persist( new Comment( 2l, "example 2", user2 ) ); session.persist( new Comment( 2L, "example 2", user2 ) );
session.persist( new Comment( 3l, "example 3", user1 ) ); session.persist( new Comment( 3L, "example 3", user1 ) );
} }
); );
} }
@ -59,9 +59,9 @@ public class MutationQueriesAndNotFoundActionTest {
scope.inTransaction( scope.inTransaction(
session -> { session -> {
int affectedComments = session.createMutationQuery( int affectedComments = session.createMutationQuery(
"UPDATE Comment c SET c.text = :text WHERE c.user = :user" ) "update Comment c set c.text = :text where c.user = :user" )
.setParameter( "text", "updated" ) .setParameter( "text", "updated" )
.setParameter( "user", user1 ) .setParameter( "user", session.getReference( User.class, 1L ) )
.executeUpdate(); .executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 ); assertThat( affectedComments ).isEqualTo( 2 );
@ -69,12 +69,53 @@ public class MutationQueriesAndNotFoundActionTest {
); );
} }
@Test
public void testUpdateWithImplicitJoin(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery( "update Comment c set c.text = :text where c.user.name = :userName" )
.setParameter( "text", "updated" )
.setParameter( "userName", "test 1" )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 );
}
);
}
@Test
public void testUpdateSet(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery(
"update Comment c set c.user = :user" )
.setParameter( "user", session.getReference( User.class, 2L ) )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 3 );
}
);
}
@Test @Test
public void testDelete(SessionFactoryScope scope) { public void testDelete(SessionFactoryScope scope) {
scope.inTransaction( scope.inTransaction(
session -> { session -> {
int affectedComments = session.createMutationQuery( "delete from Comment c where c.user = :user" ) int affectedComments = session.createMutationQuery( "delete from Comment c where c.user = :user" )
.setParameter( "user", user1 ) .setParameter( "user", session.getReference( User.class, 1L ) )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 );
}
);
}
@Test
public void testDeleteWithImplicitJoin(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery( "delete from Comment c where c.user.name = :userName" )
.setParameter( "userName", "test 1" )
.executeUpdate(); .executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 ); assertThat( affectedComments ).isEqualTo( 2 );