HHH-16878 Add support for joins in SQL DML AST

This commit is contained in:
Christian Beikov 2023-07-18 12:09:15 +02:00 committed by Andrea Boriero
parent d204916c86
commit 6808f451ba
6 changed files with 236 additions and 40 deletions

View File

@ -340,6 +340,7 @@ import org.hibernate.sql.ast.tree.expression.TrimSpecification;
import org.hibernate.sql.ast.tree.expression.UnaryOperation;
import org.hibernate.sql.ast.tree.from.CorrelatedPluralTableGroup;
import org.hibernate.sql.ast.tree.from.CorrelatedTableGroup;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.from.PluralTableGroup;
import org.hibernate.sql.ast.tree.from.QueryPartTableGroup;
@ -848,9 +849,12 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
suppliedPredicate = visitWhereClause( whereClause.getPredicate() );
}
final FromClause fromClause = new FromClause();
fromClause.addRoot( rootTableGroup );
return new UpdateStatement(
cteContainer,
(NamedTableReference) rootTableGroup.getPrimaryTableReference(),
fromClause,
assignments,
combinePredicates( suppliedPredicate, additionalRestrictions ),
Collections.emptyList()
@ -870,7 +874,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
}
else {
// otherwise...
throw new SemanticException( "Mutation query may only contain embeddable joins" );
throw new SemanticException( "Mutation query may not contain joins in the SET clause" );
}
}
@ -1119,9 +1123,12 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
suppliedPredicate = visitWhereClause( whereClause.getPredicate() );
}
final FromClause fromClause = new FromClause();
fromClause.addRoot( rootTableGroup );
return new DeleteStatement(
cteContainer,
(NamedTableReference) rootTableGroup.getPrimaryTableReference(),
fromClause,
combinePredicates( suppliedPredicate, additionalRestrictions ),
Collections.emptyList()
);
@ -3567,7 +3574,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
private <X> X prepareReusablePath(SqmPath<?> sqmPath, FromClauseIndex fromClauseIndex, Supplier<X> supplier) {
final Consumer<TableGroup> implicitJoinChecker;
if ( getCurrentProcessingState() instanceof SqlAstQueryPartProcessingState ) {
if ( getCurrentClauseStack().getCurrent() != Clause.SET_EXPRESSION ) {
implicitJoinChecker = tg -> {};
}
else {

View File

@ -62,6 +62,7 @@ import org.hibernate.persister.internal.SqlFragmentPredicate;
import org.hibernate.query.IllegalQueryOperationException;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.SortDirection;
import org.hibernate.query.results.TableGroupImpl;
import org.hibernate.query.spi.Limit;
import org.hibernate.query.spi.QueryOptions;
import org.hibernate.query.sqm.BinaryArithmeticOperator;
@ -87,6 +88,7 @@ import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.SqlTreeCreationException;
import org.hibernate.sql.ast.internal.ParameterMarkerStrategyStandard;
import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement;
import org.hibernate.sql.ast.tree.MutationStatement;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.Statement;
@ -1052,7 +1054,12 @@ public abstract class AbstractSqlAstTranslator<T extends JdbcOperation> implemen
clauseStack.pop();
}
if ( statement.getFromClause().hasJoins() ) {
visitWhereClause( determineWhereClauseRestrictionWithJoinEmulation( statement ) );
}
else {
visitWhereClause( statement.getRestriction() );
}
visitReturningColumns( statement.getReturningColumns() );
}
@ -1069,10 +1076,53 @@ public abstract class AbstractSqlAstTranslator<T extends JdbcOperation> implemen
}
renderSetClause( statement, clauseStack );
if ( statement.getFromClause().hasJoins() ) {
visitWhereClause( determineWhereClauseRestrictionWithJoinEmulation( statement ) );
}
else {
visitWhereClause( statement.getRestriction() );
}
visitReturningColumns( statement.getReturningColumns() );
}
protected Predicate determineWhereClauseRestrictionWithJoinEmulation(AbstractUpdateOrDeleteStatement statement) {
final QuerySpec querySpec = new QuerySpec( false );
querySpec.getSelectClause().addSqlSelection(
new SqlSelectionImpl( new QueryLiteral<>( 1, getIntegerType() ) )
);
for ( TableGroup root : statement.getFromClause().getRoots() ) {
if ( root.getPrimaryTableReference() == statement.getTargetTable() ) {
for ( TableReferenceJoin tableReferenceJoin : root.getTableReferenceJoins() ) {
assert tableReferenceJoin.getJoinType() == SqlAstJoinType.INNER;
querySpec.getFromClause().addRoot(
new TableGroupImpl(
root.getNavigablePath(),
null,
tableReferenceJoin.getJoinedTableReference(),
root.getModelPart()
)
);
querySpec.applyPredicate( tableReferenceJoin.getPredicate() );
}
for ( TableGroupJoin tableGroupJoin : root.getTableGroupJoins() ) {
assert tableGroupJoin.getJoinType() == SqlAstJoinType.INNER;
querySpec.getFromClause().addRoot( tableGroupJoin.getJoinedGroup() );
querySpec.applyPredicate( tableGroupJoin.getPredicate() );
}
for ( TableGroupJoin tableGroupJoin : root.getNestedTableGroupJoins() ) {
assert tableGroupJoin.getJoinType() == SqlAstJoinType.INNER;
querySpec.getFromClause().addRoot( tableGroupJoin.getJoinedGroup() );
querySpec.applyPredicate( tableGroupJoin.getPredicate() );
}
}
else {
querySpec.getFromClause().addRoot( root );
}
}
querySpec.applyPredicate( statement.getRestriction() );
return new ExistsPredicate( querySpec, false, getBooleanType() );
}
protected void renderSetClause(UpdateStatement statement, Stack<Clause> clauseStack) {
appendSql( " set" );
char separator = ' ';

View File

@ -0,0 +1,76 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.sql.ast.tree;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.hibernate.sql.ast.tree.cte.CteContainer;
import org.hibernate.sql.ast.tree.cte.CteStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.predicate.Predicate;
public abstract class AbstractUpdateOrDeleteStatement extends AbstractMutationStatement {
private final FromClause fromClause;
private final Predicate restriction;
public AbstractUpdateOrDeleteStatement(
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction) {
super( targetTable );
this.fromClause = fromClause;
this.restriction = restriction;
}
public AbstractUpdateOrDeleteStatement(
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( new LinkedHashMap<>(), targetTable, returningColumns );
this.fromClause = fromClause;
this.restriction = restriction;
}
public AbstractUpdateOrDeleteStatement(
CteContainer cteContainer,
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
this(
cteContainer.getCteStatements(),
targetTable,
fromClause,
restriction,
returningColumns
);
}
public AbstractUpdateOrDeleteStatement(
Map<String, CteStatement> cteStatements,
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( cteStatements, targetTable, returningColumns );
this.fromClause = fromClause;
this.restriction = restriction;
}
public FromClause getFromClause() {
return fromClause;
}
public Predicate getRestriction() {
return restriction;
}
}

View File

@ -6,16 +6,16 @@
*/
package org.hibernate.sql.ast.tree.delete;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.hibernate.sql.ast.SqlAstWalker;
import org.hibernate.sql.ast.spi.SqlAstHelper;
import org.hibernate.sql.ast.tree.AbstractMutationStatement;
import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement;
import org.hibernate.sql.ast.tree.cte.CteContainer;
import org.hibernate.sql.ast.tree.cte.CteStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.predicate.Junction;
import org.hibernate.sql.ast.tree.predicate.Predicate;
@ -23,32 +23,43 @@ import org.hibernate.sql.ast.tree.predicate.Predicate;
/**
* @author Steve Ebersole
*/
public class DeleteStatement extends AbstractMutationStatement {
public class DeleteStatement extends AbstractUpdateOrDeleteStatement {
public static final String DEFAULT_ALIAS = "to_delete_";
private final Predicate restriction;
public DeleteStatement(NamedTableReference targetTable, Predicate restriction) {
super( targetTable );
this.restriction = restriction;
this( targetTable, new FromClause(), restriction );
}
public DeleteStatement(
NamedTableReference targetTable,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( new LinkedHashMap<>(), targetTable, returningColumns );
this.restriction = restriction;
this( targetTable, new FromClause(), restriction, returningColumns );
}
public DeleteStatement(NamedTableReference targetTable, FromClause fromClause, Predicate restriction) {
super( targetTable, fromClause, restriction );
}
public DeleteStatement(
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( targetTable, fromClause, restriction, returningColumns );
}
public DeleteStatement(
CteContainer cteContainer,
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
this(
cteContainer.getCteStatements(),
targetTable,
fromClause,
restriction,
returningColumns
);
@ -57,14 +68,10 @@ public class DeleteStatement extends AbstractMutationStatement {
public DeleteStatement(
Map<String, CteStatement> cteStatements,
NamedTableReference targetTable,
FromClause fromClause,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( cteStatements, targetTable, returningColumns );
this.restriction = restriction;
}
public Predicate getRestriction() {
return restriction;
super( cteStatements, targetTable, fromClause, restriction, returningColumns );
}
public static class DeleteStatementBuilder {
@ -88,6 +95,7 @@ public class DeleteStatement extends AbstractMutationStatement {
public DeleteStatement createDeleteStatement() {
return new DeleteStatement(
targetTable,
new FromClause(),
restriction != null ? restriction : new Junction( Junction.Nature.CONJUNCTION )
);
}

View File

@ -14,26 +14,25 @@ import java.util.Map;
import org.hibernate.sql.ast.SqlAstWalker;
import org.hibernate.sql.ast.spi.SqlAstTreeHelper;
import org.hibernate.sql.ast.tree.AbstractMutationStatement;
import org.hibernate.sql.ast.tree.AbstractUpdateOrDeleteStatement;
import org.hibernate.sql.ast.tree.cte.CteContainer;
import org.hibernate.sql.ast.tree.cte.CteStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.from.FromClause;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.predicate.Predicate;
/**
* @author Steve Ebersole
*/
public class UpdateStatement extends AbstractMutationStatement {
public class UpdateStatement extends AbstractUpdateOrDeleteStatement {
private final List<Assignment> assignments;
private final Predicate restriction;
public UpdateStatement(
NamedTableReference targetTable,
List<Assignment> assignments,
Predicate restriction) {
super( targetTable );
this.assignments = assignments;
this.restriction = restriction;
this( targetTable, new FromClause(), assignments, restriction );
}
public UpdateStatement(
@ -41,20 +40,39 @@ public class UpdateStatement extends AbstractMutationStatement {
List<Assignment> assignments,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( new LinkedHashMap<>(), targetTable, returningColumns );
this( targetTable, new FromClause(), assignments, restriction, returningColumns );
}
public UpdateStatement(
NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments,
Predicate restriction) {
super( targetTable, fromClause, restriction );
this.assignments = assignments;
}
public UpdateStatement(
NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( targetTable, fromClause, restriction, returningColumns );
this.assignments = assignments;
this.restriction = restriction;
}
public UpdateStatement(
CteContainer cteContainer,
NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments,
Predicate restriction,
List<ColumnReference> returningColumns) {
this(
cteContainer.getCteStatements(),
targetTable,
fromClause,
assignments,
restriction,
returningColumns
@ -64,22 +82,18 @@ public class UpdateStatement extends AbstractMutationStatement {
public UpdateStatement(
Map<String, CteStatement> cteStatements,
NamedTableReference targetTable,
FromClause fromClause,
List<Assignment> assignments,
Predicate restriction,
List<ColumnReference> returningColumns) {
super( cteStatements, targetTable, returningColumns );
super( cteStatements, targetTable, fromClause, restriction, returningColumns );
this.assignments = assignments;
this.restriction = restriction;
}
public List<Assignment> getAssignments() {
return assignments;
}
public Predicate getRestriction() {
return restriction;
}
public static class UpdateStatementBuilder {
private final NamedTableReference targetTableRef;
private List<Assignment> assignments;
@ -120,7 +134,7 @@ public class UpdateStatement extends AbstractMutationStatement {
return null;
}
return new UpdateStatement( targetTableRef, assignments, restriction );
return new UpdateStatement( targetTableRef, new FromClause(), assignments, restriction );
}
}

View File

@ -28,18 +28,18 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@SessionFactory
@JiraKey("HHH-16878")
public class MutationQueriesAndNotFoundActionTest {
private static final User user1 = new User( 1l, "test 1" );
@BeforeEach
public void setUp(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
User user2 = new User( 2l, "test 2" );
final User user1 = new User( 1L, "test 1" );
final User user2 = new User( 2L, "test 2" );
session.persist( user1 );
session.persist( user2 );
session.persist( new Comment( 1l, "example 1", user1 ) );
session.persist( new Comment( 2l, "example 2", user2 ) );
session.persist( new Comment( 3l, "example 3", user1 ) );
session.persist( new Comment( 1L, "example 1", user1 ) );
session.persist( new Comment( 2L, "example 2", user2 ) );
session.persist( new Comment( 3L, "example 3", user1 ) );
}
);
}
@ -59,9 +59,9 @@ public class MutationQueriesAndNotFoundActionTest {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery(
"UPDATE Comment c SET c.text = :text WHERE c.user = :user" )
"update Comment c set c.text = :text where c.user = :user" )
.setParameter( "text", "updated" )
.setParameter( "user", user1 )
.setParameter( "user", session.getReference( User.class, 1L ) )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 );
@ -69,12 +69,53 @@ public class MutationQueriesAndNotFoundActionTest {
);
}
@Test
public void testUpdateWithImplicitJoin(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery( "update Comment c set c.text = :text where c.user.name = :userName" )
.setParameter( "text", "updated" )
.setParameter( "userName", "test 1" )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 );
}
);
}
@Test
public void testUpdateSet(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery(
"update Comment c set c.user = :user" )
.setParameter( "user", session.getReference( User.class, 2L ) )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 3 );
}
);
}
@Test
public void testDelete(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery( "delete from Comment c where c.user = :user" )
.setParameter( "user", user1 )
.setParameter( "user", session.getReference( User.class, 1L ) )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 );
}
);
}
@Test
public void testDeleteWithImplicitJoin(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
int affectedComments = session.createMutationQuery( "delete from Comment c where c.user.name = :userName" )
.setParameter( "userName", "test 1" )
.executeUpdate();
assertThat( affectedComments ).isEqualTo( 2 );