HHH-17410 Support creating count query from existing query

This commit is contained in:
Christian Beikov 2023-11-10 11:41:35 +01:00
parent 7a5219b52a
commit e9d08ca18e
10 changed files with 364 additions and 7 deletions

View File

@ -400,4 +400,9 @@ public abstract class CriteriaDefinition<R>
public <X> JpaRoot<X> from(JpaCteCriteria<X> cte) {
return query.from(cte);
}
@Override
public JpaCriteriaQuery<Long> createCountQuery() {
return query.createCountQuery();
}
}

View File

@ -26,6 +26,13 @@ import org.hibernate.query.sqm.FetchClauseType;
*/
public interface JpaCriteriaQuery<T> extends CriteriaQuery<T>, JpaQueryableCriteria<T>, JpaSelectCriteria<T> {
/**
* Wraps this query in a subquery and returns a count query based on that subquery in the from clause.
*
* @since 6.4
*/
JpaCriteriaQuery<Long> createCountQuery();
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Limit/Offset/Fetch clause

View File

@ -6,6 +6,7 @@
*/
package org.hibernate.query.sqm.tree;
import org.hibernate.Incubating;
import org.hibernate.query.sqm.internal.NoParamSqmCopyContext;
import org.hibernate.query.sqm.internal.SimpleSqmCopyContext;
@ -18,6 +19,16 @@ public interface SqmCopyContext {
<T> T registerCopy(T original, T copy);
/**
* Returns whether the {@code fetch} flag for attribute joins should be copied over.
*
* @since 6.4
*/
@Incubating
default boolean copyFetchedFlag() {
return true;
}
static SqmCopyContext simpleContext() {
return new SimpleSqmCopyContext();
}

View File

@ -65,7 +65,7 @@ public class SqmBagJoin<O, E> extends AbstractSqmPluralJoin<O,Collection<E>, E>
getAttribute(),
getExplicitAlias(),
getSqmJoinType(),
isFetched(),
context.copyFetchedFlag() && isFetched(),
nodeBuilder()
)
);

View File

@ -68,7 +68,7 @@ public class SqmListJoin<O,E>
getAttribute(),
getExplicitAlias(),
getSqmJoinType(),
isFetched(),
context.copyFetchedFlag() && isFetched(),
nodeBuilder()
)
);

View File

@ -8,7 +8,6 @@ package org.hibernate.query.sqm.tree.domain;
import java.util.Map;
import jakarta.persistence.criteria.Expression;
import jakarta.persistence.criteria.Path;
import jakarta.persistence.criteria.Predicate;
import org.hibernate.metamodel.model.domain.EntityDomainType;
@ -67,7 +66,7 @@ public class SqmMapJoin<O, K, V>
getAttribute(),
getExplicitAlias(),
getSqmJoinType(),
isFetched(),
context.copyFetchedFlag() && isFetched(),
nodeBuilder()
)
);

View File

@ -67,7 +67,7 @@ public class SqmSetJoin<O, E>
getModel(),
getExplicitAlias(),
getSqmJoinType(),
isFetched(),
context.copyFetchedFlag() && isFetched(),
nodeBuilder()
)
);

View File

@ -11,7 +11,6 @@ import java.util.Locale;
import org.hibernate.metamodel.model.domain.EntityDomainType;
import org.hibernate.metamodel.model.domain.SingularPersistentAttribute;
import org.hibernate.query.sqm.SemanticQueryWalker;
import org.hibernate.query.sqm.SqmPathSource;
import org.hibernate.spi.NavigablePath;
import org.hibernate.query.hql.spi.SqmCreationProcessingState;
import org.hibernate.query.sqm.NodeBuilder;
@ -76,7 +75,7 @@ public class SqmSingularJoin<O,T> extends AbstractSqmAttributeJoin<O,T> {
getAttribute(),
getExplicitAlias(),
getSqmJoinType(),
isFetched(),
context.copyFetchedFlag() && isFetched(),
nodeBuilder()
)
);

View File

@ -6,6 +6,7 @@
*/
package org.hibernate.query.sqm.tree.select;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
@ -27,10 +28,12 @@ import org.hibernate.query.criteria.JpaSelection;
import org.hibernate.query.sqm.NodeBuilder;
import org.hibernate.query.sqm.SemanticQueryWalker;
import org.hibernate.query.sqm.SqmQuerySource;
import org.hibernate.query.sqm.internal.NoParamSqmCopyContext;
import org.hibernate.query.sqm.internal.SqmUtil;
import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.query.sqm.tree.SqmStatement;
import org.hibernate.query.sqm.tree.cte.SqmCteStatement;
import org.hibernate.query.sqm.tree.expression.SqmStar;
import org.hibernate.query.sqm.tree.expression.ValueBindJpaCriteriaParameter;
import org.hibernate.query.sqm.tree.expression.SqmParameter;
import org.hibernate.query.sqm.tree.from.SqmFromClause;
@ -455,4 +458,83 @@ public class SqmSelectStatement<T> extends AbstractSqmSelectQuery<T> implements
"Please disable the JPA query compliance if you want to use this feature." );
}
}
@Override
public JpaCriteriaQuery<Long> createCountQuery() {
final SqmCopyContext context = new NoParamSqmCopyContext() {
@Override
public boolean copyFetchedFlag() {
return false;
}
};
final NodeBuilder nodeBuilder = nodeBuilder();
final Set<SqmParameter<?>> parameters;
if ( this.parameters == null ) {
parameters = null;
}
else {
parameters = new LinkedHashSet<>( this.parameters.size() );
for ( SqmParameter<?> parameter : this.parameters ) {
parameters.add( parameter.copy( context ) );
}
}
final SqmSelectStatement<Long> selectStatement = new SqmSelectStatement<>(
nodeBuilder,
copyCteStatements( context ),
Long.class,
SqmQuerySource.CRITERIA,
parameters
);
final SqmQuerySpec<Long> querySpec = new SqmQuerySpec<>( nodeBuilder );
final SqmSubQuery<Tuple> subquery = new SqmSubQuery<>( selectStatement, Tuple.class, nodeBuilder );
final SqmQueryPart<T> queryPart = getQueryPart().copy( context );
resetSelections( queryPart );
// Reset the
if ( queryPart.getFetch() == null && queryPart.getOffset() == null ) {
queryPart.setOrderByClause( null );
}
//noinspection unchecked
subquery.setQueryPart( (SqmQueryPart<Tuple>) queryPart );
querySpec.setFromClause( new SqmFromClause( 1 ) );
querySpec.setSelectClause( new SqmSelectClause( false, 1, nodeBuilder ) );
selectStatement.setQueryPart( querySpec );
selectStatement.select( nodeBuilder.count( new SqmStar( nodeBuilder ) ) );
selectStatement.from( subquery );
return selectStatement;
}
private void resetSelections(SqmQueryPart<?> queryPart) {
if ( queryPart instanceof SqmQuerySpec<?> ) {
resetSelections( (SqmQuerySpec<?>) queryPart );
}
else {
final SqmQueryGroup<?> group = (SqmQueryGroup<?>) queryPart;
for ( SqmQueryPart<?> part : group.getQueryParts() ) {
resetSelections( part );
}
}
}
private void resetSelections(SqmQuerySpec<?> querySpec) {
final NodeBuilder nodeBuilder = nodeBuilder();
final List<SqmSelection<?>> selections = querySpec.getSelectClause().getSelections();
final List<SqmSelectableNode<?>> subSelections = new ArrayList<>();
if ( selections.isEmpty() ) {
subSelections.add( (SqmSelectableNode<?>) nodeBuilder.literal( 1 ).alias( "c0" ) );
}
else {
for ( SqmSelection<?> selection : selections ) {
selection.getSelectableNode().visitSubSelectableNodes( e -> {
e.alias( "c" + subSelections.size() );
subSelections.add( e );
} );
}
}
querySpec.getSelectClause().setSelection( (SqmSelectableNode<?>) nodeBuilder.tuple( subSelections ) );
}
}

View File

@ -0,0 +1,254 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.orm.test.query;
import java.time.LocalDate;
import java.util.List;
import org.hibernate.engine.spi.SessionImplementor;
import org.hibernate.query.criteria.HibernateCriteriaBuilder;
import org.hibernate.query.criteria.JpaCriteriaQuery;
import org.hibernate.query.criteria.JpaParameterExpression;
import org.hibernate.query.criteria.JpaRoot;
import org.hibernate.testing.orm.domain.StandardDomainModel;
import org.hibernate.testing.orm.domain.contacts.Address;
import org.hibernate.testing.orm.domain.contacts.Contact;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.JiraKey;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Christian Beikov
*/
@DomainModel(standardModels = StandardDomainModel.CONTACTS)
@SessionFactory
@JiraKey("HHH-17410")
public class CountQueryTests {
@Test
public void testBasic(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
final HibernateCriteriaBuilder cb = session.getCriteriaBuilder();
verifyCount( session, cb.createQuery(
"select e.id, e.name from Contact e where e.gender is null",
Tuple.class
) );
verifyCount( session, cb.createQuery(
"select e.id as id, e.name as name from Contact e where e.gender = FEMALE",
Tuple.class
) );
}
);
}
@Test
public void testFetches(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
final HibernateCriteriaBuilder cb = session.getCriteriaBuilder();
verifyCount( session, cb.createQuery(
"select e from Contact e join fetch e.alternativeContact",
Contact.class
) );
verifyCollectionCount( session, cb.createQuery(
"select e from Contact e left join fetch e.addresses",
Contact.class
) );
}
);
}
@Test
public void testConstructor(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
final HibernateCriteriaBuilder cb = session.getCriteriaBuilder();
verifyCount( session, cb.createQuery(
"select new " + Contact.class.getName() + "(e.id, e.name, e.gender, e.birthDay) from Contact e",
Tuple.class
) );
}
);
}
@Test
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsRecursiveCtes.class)
public void testCte(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
final HibernateCriteriaBuilder cb = session.getCriteriaBuilder();
verifyCount( session, cb.createQuery(
"with alternativeContacts as (" +
"select c.alternativeContact alt from Contact c where c.id = 1 " +
"union all " +
"select c.alt.alternativeContact alt from alternativeContacts c where c.alt.alternativeContact.id <> 1" +
")" +
"select ac from alternativeContacts c join c.alt ac order by ac.id",
Tuple.class
) );
}
);
}
@Test
public void testValues(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
final HibernateCriteriaBuilder cb = session.getCriteriaBuilder();
final JpaCriteriaQuery<Tuple> cq = cb.createTupleQuery();
final JpaRoot<Contact> root = cq.from( Contact.class );
cq.multiselect(
root.get( "id" ),
root.get( "name" )
);
cq.where(
root.get( "gender" ).equalTo( Contact.Gender.FEMALE )
);
verifyCount( session, cq );
}
);
}
@Test
public void testParameters(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
final HibernateCriteriaBuilder cb = session.getCriteriaBuilder();
final JpaCriteriaQuery<Tuple> cq = cb.createTupleQuery();
final JpaRoot<Contact> root = cq.from( Contact.class );
cq.multiselect(
root.get( "id" ),
root.get( "name" )
);
final JpaParameterExpression<Contact.Gender> parameter = cb.parameter( Contact.Gender.class );
cq.where( root.get( "gender" ).equalTo( parameter ) );
final List<Tuple> resultList = session.createQuery( cq )
.setParameter( parameter, Contact.Gender.FEMALE )
.getResultList();
final Long count = session.createQuery( cq.createCountQuery() )
.setParameter( parameter, Contact.Gender.FEMALE )
.getSingleResult();
assertEquals( resultList.size(), count.intValue() );
}
);
}
@BeforeEach
public void prepareTestData(SessionFactoryScope scope) {
scope.inTransaction( (session) -> {
final Contact contact = new Contact(
1,
new Contact.Name( "John", "Doe" ),
Contact.Gender.MALE,
LocalDate.of( 1970, 1, 1 )
);
final Contact alternativeContact = new Contact(
2,
new Contact.Name( "Jane", "Doe" ),
Contact.Gender.FEMALE,
LocalDate.of( 1970, 1, 1 )
);
final Contact alternativeContact2 = new Contact(
3,
new Contact.Name( "Granny", "Doe" ),
Contact.Gender.FEMALE,
LocalDate.of( 1970, 1, 1 )
);
alternativeContact.setAlternativeContact( alternativeContact2 );
contact.setAlternativeContact( alternativeContact );
contact.setAddresses(
List.of(
new Address( "Street 1", 1234 ),
new Address( "Street 2", 5678 )
)
);
session.persist( alternativeContact2 );
session.persist( alternativeContact );
session.persist( contact );
alternativeContact2.setAlternativeContact( contact );
final Contact c4 = new Contact(
4,
new Contact.Name( "C4", "Doe" ),
Contact.Gender.OTHER,
LocalDate.of( 1970, 1, 1 )
);
final Contact c5 = new Contact(
5,
new Contact.Name( "C5", "Doe" ),
Contact.Gender.OTHER,
LocalDate.of( 1970, 1, 1 )
);
final Contact c6 = new Contact(
6,
new Contact.Name( "C6", "Doe" ),
Contact.Gender.OTHER,
LocalDate.of( 1970, 1, 1 )
);
final Contact c7 = new Contact(
7,
new Contact.Name( "C7", "Doe" ),
Contact.Gender.OTHER,
LocalDate.of( 1970, 1, 1 )
);
final Contact c8 = new Contact(
8,
new Contact.Name( "C8", "Doe" ),
Contact.Gender.OTHER,
LocalDate.of( 1970, 1, 1 )
);
c4.setAlternativeContact( c5 );
c5.setAlternativeContact( c6 );
c7.setAlternativeContact( c8 );
session.persist( c6 );
session.persist( c5 );
session.persist( c4 );
session.persist( c8 );
session.persist( c7 );
} );
}
private <T> void verifyCount(SessionImplementor session, JpaCriteriaQuery<?> query) {
final List<?> resultList = session.createQuery( query ).getResultList();
final Long count = session.createQuery( query.createCountQuery() ).getSingleResult();
assertEquals( resultList.size(), count.intValue() );
}
private <T> void verifyCollectionCount(SessionImplementor session, JpaCriteriaQuery<Contact> query) {
final List<Contact> resultList = session.createQuery( query ).getResultList();
final Long count = session.createQuery( query.createCountQuery() ).getSingleResult();
int ormSize = 0;
for ( Contact contact : resultList ) {
ormSize++;
ormSize += Math.max( contact.getAddresses().size() - 1, 0 );
ormSize += Math.max( contact.getPhoneNumbers().size() - 1, 0 );
}
assertEquals( ormSize, count.intValue() );
}
@AfterEach
public void dropTestData(SessionFactoryScope scope) {
scope.inTransaction( (session) -> {
session.createMutationQuery( "update Contact set alternativeContact = null" ).executeUpdate();
session.createMutationQuery( "delete Contact" ).executeUpdate();
} );
}
}