diff --git a/documentation/src/test/java/org/hibernate/userguide/pc/FilterJoinTableTest.java b/documentation/src/test/java/org/hibernate/userguide/pc/FilterJoinTableTest.java index a3d2e03540..aad30a91be 100644 --- a/documentation/src/test/java/org/hibernate/userguide/pc/FilterJoinTableTest.java +++ b/documentation/src/test/java/org/hibernate/userguide/pc/FilterJoinTableTest.java @@ -18,7 +18,6 @@ import javax.persistence.OneToMany; import javax.persistence.OrderColumn; import org.hibernate.Session; -import org.hibernate.annotations.Filter; import org.hibernate.annotations.FilterDef; import org.hibernate.annotations.FilterJoinTable; import org.hibernate.annotations.ParamDef; @@ -26,8 +25,6 @@ import org.hibernate.jpa.test.BaseEntityManagerFunctionalTestCase; import org.junit.Test; -import org.jboss.logging.Logger; - import static org.hibernate.testing.transaction.TransactionUtil.doInJPA; import static org.junit.Assert.assertEquals; @@ -118,10 +115,6 @@ public class FilterJoinTableTest extends BaseEntityManagerFunctionalTestCase { type="int" ) ) - @Filter( - name="firstAccounts", - condition="order_id <= :maxOrderId" - ) public static class Client { @Id diff --git a/documentation/src/test/java/org/hibernate/userguide/pc/FilterTest.java b/documentation/src/test/java/org/hibernate/userguide/pc/FilterTest.java index 92b60a1cd8..e5d432d3a5 100644 --- a/documentation/src/test/java/org/hibernate/userguide/pc/FilterTest.java +++ b/documentation/src/test/java/org/hibernate/userguide/pc/FilterTest.java @@ -18,11 +18,13 @@ import javax.persistence.Id; import javax.persistence.ManyToOne; import javax.persistence.NoResultException; import javax.persistence.OneToMany; +import javax.persistence.Table; import org.hibernate.Session; import org.hibernate.annotations.Filter; import org.hibernate.annotations.FilterDef; import org.hibernate.annotations.ParamDef; +import org.hibernate.annotations.SqlFragmentAlias; import org.hibernate.annotations.Where; import org.hibernate.jpa.test.BaseEntityManagerFunctionalTestCase; @@ -56,7 +58,8 @@ public class FilterTest extends BaseEntityManagerFunctionalTestCase { //tag::pc-filter-persistence-example[] Client client = new Client() .setId( 1L ) - .setName( "John Doe" ); + .setName( "John Doe" ) + .setType( AccountType.DEBIT ); client.addAccount( new Account() @@ -186,7 +189,7 @@ public class FilterTest extends BaseEntityManagerFunctionalTestCase { Client client = entityManager.find( Client.class, 1L ); - assertEquals( 2, client.getAccounts().size() ); + assertEquals( 1, client.getAccounts().size() ); //end::pc-filter-collection-query-example[] } ); } @@ -198,6 +201,7 @@ public class FilterTest extends BaseEntityManagerFunctionalTestCase { //tag::pc-filter-Client-example[] @Entity(name = "Client") + @Table(name = "client") public static class Client { @Id @@ -205,13 +209,19 @@ public class FilterTest extends BaseEntityManagerFunctionalTestCase { private String name; + private AccountType type; + @OneToMany( mappedBy = "client", cascade = CascadeType.ALL ) @Filter( name="activeAccount", - condition="active_status = :active" + condition="{a}.active_status = :active and {a}.type = {c}.type", + aliases = { + @SqlFragmentAlias( alias = "a", table= "account"), + @SqlFragmentAlias( alias = "c", table= "client"), + } ) private List accounts = new ArrayList<>( ); @@ -235,6 +245,15 @@ public class FilterTest extends BaseEntityManagerFunctionalTestCase { return this; } + public AccountType getType() { + return type; + } + + public Client setType(AccountType type) { + this.type = type; + return this; + } + public List getAccounts() { return accounts; } @@ -249,6 +268,7 @@ public class FilterTest extends BaseEntityManagerFunctionalTestCase { //tag::pc-filter-Account-example[] @Entity(name = "Account") + @Table(name = "account") @FilterDef( name="activeAccount", parameters = @ParamDef( diff --git a/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java b/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java index 62d3763ca5..f6c87bdf50 100644 --- a/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java +++ b/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java @@ -13,13 +13,22 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import org.hibernate.Filter; -import org.hibernate.HibernateException; +import org.hibernate.MappingException; +import org.hibernate.engine.spi.LoadQueryInfluencers; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.internal.util.StringHelper; import org.hibernate.internal.util.collections.CollectionHelper; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.PluralAttributeMapping; +import org.hibernate.persister.collection.AbstractCollectionPersister; +import org.hibernate.persister.entity.AbstractEntityPersister; +import org.hibernate.persister.entity.Joinable; import org.hibernate.sql.Template; +import org.hibernate.sql.ast.spi.SqlAstCreationContext; +import org.hibernate.sql.ast.tree.predicate.FilterPredicate; import org.hibernate.type.Type; +import static org.hibernate.internal.util.StringHelper.join; import static org.hibernate.internal.util.StringHelper.safeInterning; /** @@ -31,7 +40,7 @@ import static org.hibernate.internal.util.StringHelper.safeInterning; */ public class FilterHelper { - private static Pattern FILTER_PARAMETER_PATTERN = Pattern.compile( ":(\\w+)\\.(\\w+)" ); + private static final Pattern FILTER_PARAMETER_PATTERN = Pattern.compile( ":(\\w+)\\.(\\w+)" ); private final String[] filterNames; private final String[] filterConditions; @@ -141,49 +150,30 @@ public class FilterHelper { } } - public static class TypedValue { - private final Type type; - private final Object value; - - public TypedValue(Type type, Object value) { - this.type = type; - this.value = value; - } - - public Type getType() { - return type; - } - - public Object getValue() { - return value; + public static FilterPredicate createFilterPredicate(LoadQueryInfluencers loadQueryInfluencers, Joinable joinable, String alias) { + if ( loadQueryInfluencers.hasEnabledFilters() ) { + final String filterFragment; + if ( joinable instanceof AbstractCollectionPersister && ( (AbstractCollectionPersister) joinable ).isManyToMany() ) { + filterFragment = ( (AbstractCollectionPersister) joinable ).getManyToManyFilterFragment( + alias, + loadQueryInfluencers.getEnabledFilters() + ); + } + else { + filterFragment = joinable.filterFragment( alias, loadQueryInfluencers.getEnabledFilters() ); + } + if ( ! StringHelper.isEmptyOrWhiteSpace( filterFragment ) ) { + return doCreateFilterPredicate( filterFragment, loadQueryInfluencers.getEnabledFilters() ); + } } + return null; } - public static class TransformResult { - private final String transformedFilterFragment; - private final List parameters; - - public TransformResult( - String transformedFilterFragment, - List parameters) { - this.transformedFilterFragment = transformedFilterFragment; - this.parameters = parameters; - } - - public String getTransformedFilterFragment() { - return transformedFilterFragment; - } - - public List getParameters() { - return parameters; - } - } - - public static TransformResult transformToPositionalParameters(String filterFragment, Map enabledFilters) { + private static FilterPredicate doCreateFilterPredicate(String filterFragment, Map enabledFilters) { final Matcher matcher = FILTER_PARAMETER_PATTERN.matcher( filterFragment ); final StringBuilder sb = new StringBuilder(); int pos = 0; - final List parameters = new ArrayList<>( matcher.groupCount() ); + final List parameters = new ArrayList<>( matcher.groupCount() ); while( matcher.find() ) { sb.append( filterFragment, pos, matcher.start() ); pos = matcher.end(); @@ -192,16 +182,19 @@ public class FilterHelper { final String parameterName = matcher.group( 2 ); final FilterImpl enabledFilter = (FilterImpl) enabledFilters.get( filterName ); if ( enabledFilter == null ) { - throw new HibernateException( String.format( "unknown filter [%s]", filterName ) ); + throw new MappingException( String.format( "unknown filter [%s]", filterName ) ); } final Type parameterType = enabledFilter.getFilterDefinition().getParameterType( parameterName ); + if ( ! (parameterType instanceof JdbcMapping ) ) { + throw new MappingException( String.format( "parameter [%s] for filter [%s] is not of JdbcMapping type", parameterName, filterName ) ); + } final Object parameterValue = enabledFilter.getParameter( parameterName ); if ( parameterValue == null ) { - throw new HibernateException( String.format( "unknown parameter [%s] for filter [%s]", parameterName, filterName ) ); + throw new MappingException( String.format( "unknown parameter [%s] for filter [%s]", parameterName, filterName ) ); } - parameters.add( new TypedValue( parameterType, parameterValue ) ); + parameters.add( new FilterJdbcParameter( (JdbcMapping) parameterType, parameterValue ) ); } sb.append( filterFragment, pos, filterFragment.length() ); - return new TransformResult( sb.toString(), parameters ); + return new FilterPredicate( sb.toString(), parameters ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/internal/FilterJdbcParameter.java b/hibernate-core/src/main/java/org/hibernate/internal/FilterJdbcParameter.java new file mode 100644 index 0000000000..772c1ac237 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/internal/FilterJdbcParameter.java @@ -0,0 +1,50 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.internal; + +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; +import org.hibernate.sql.exec.internal.JdbcParameterImpl; +import org.hibernate.sql.exec.spi.JdbcParameterBinder; +import org.hibernate.sql.exec.spi.JdbcParameterBinding; + +/** + * @author Nathan Xu + */ +public class FilterJdbcParameter { + private final JdbcParameter parameter; + private final JdbcMapping jdbcMapping; + private final Object jdbcParameterValue; + + public FilterJdbcParameter(JdbcMapping jdbcMapping, Object jdbcParameterValue) { + this.parameter = new JdbcParameterImpl( jdbcMapping ); + this.jdbcMapping = jdbcMapping; + this.jdbcParameterValue = jdbcParameterValue; + } + + public JdbcParameter getParameter() { + return parameter; + } + + public JdbcParameterBinder getBinder() { + return parameter.getParameterBinder(); + } + + public JdbcParameterBinding getBinding() { + return new JdbcParameterBinding() { + @Override + public JdbcMapping getBindType() { + return jdbcMapping; + } + + @Override + public Object getBindValue() { + return jdbcParameterValue; + } + }; + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java index a436b9029a..849baf42fa 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java @@ -163,8 +163,7 @@ public class CollectionLoaderBatchKey implements CollectionLoader { final JdbcSelect jdbcSelect = sqlAstTranslatorFactory.buildSelectTranslator( sessionFactory ).translate( sqlAst ); final JdbcParameterBindings jdbcParameterBindings = new JdbcParameterBindingsImpl( keyJdbcCount * smallBatchLength ); - - sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings ); + jdbcSelect.registerFilterJdbcParameterBindings( jdbcParameterBindings ); final Iterator paramItr = jdbcParameters.iterator(); diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java index 2055850399..704d09b888 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java @@ -99,6 +99,7 @@ public class CollectionLoaderSingleKey implements CollectionLoader { final JdbcSelect jdbcSelect = sqlAstTranslatorFactory.buildSelectTranslator( sessionFactory ).translate( sqlAst ); final JdbcParameterBindings jdbcParameterBindings = new JdbcParameterBindingsImpl( keyJdbcCount ); + jdbcSelect.registerFilterJdbcParameterBindings( jdbcParameterBindings ); final Iterator paramItr = jdbcParameters.iterator(); @@ -127,8 +128,6 @@ public class CollectionLoaderSingleKey implements CollectionLoader { ); assert !paramItr.hasNext(); - sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings ); - jdbcServices.getJdbcSelectExecutor().list( jdbcSelect, jdbcParameterBindings, diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java index 06e8d0489f..9878f8a622 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java @@ -24,7 +24,6 @@ import org.hibernate.engine.spi.LoadQueryInfluencers; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.engine.spi.SubselectFetch; import org.hibernate.internal.FilterHelper; -import org.hibernate.internal.FilterHelper.TransformResult; import org.hibernate.loader.ast.spi.Loadable; import org.hibernate.loader.ast.spi.Loader; import org.hibernate.metamodel.mapping.BasicValuedModelPart; @@ -49,13 +48,11 @@ import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.SqlTuple; import org.hibernate.sql.ast.tree.from.TableGroup; -import org.hibernate.sql.ast.tree.from.TableGroupJoin; import org.hibernate.sql.ast.tree.from.TableReference; -import org.hibernate.sql.ast.tree.from.TableReferenceJoin; import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate; -import org.hibernate.sql.ast.tree.predicate.FilterPredicate; import org.hibernate.sql.ast.tree.predicate.InListPredicate; import org.hibernate.sql.ast.tree.predicate.InSubQueryPredicate; +import org.hibernate.sql.ast.tree.predicate.Predicate; import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.exec.internal.JdbcParameterImpl; @@ -231,7 +228,7 @@ public class LoaderSelectBuilder { sqlAstCreationState.getFromClauseAccess().registerTableGroup( rootNavigablePath, rootTableGroup ); if ( loadable instanceof PluralAttributeMapping ) { - applyFiltering( rootQuerySpec, loadQueryInfluencers, (PluralAttributeMapping) loadable ); + applyFiltering( rootQuerySpec, rootTableGroup, (PluralAttributeMapping) loadable ); applyOrdering( rootTableGroup, (PluralAttributeMapping) loadable ); } @@ -382,62 +379,22 @@ public class LoaderSelectBuilder { } } - private void applyFiltering( - QuerySpec querySpec, - LoadQueryInfluencers loadQueryInfluencers, - PluralAttributeMapping pluralAttributeMapping) { + private void applyFiltering(QuerySpec querySpec, TableGroup tableGroup, PluralAttributeMapping pluralAttributeMapping) { if ( loadQueryInfluencers.hasEnabledFilters() ) { final Joinable joinable = pluralAttributeMapping .getCollectionDescriptor() .getCollectionType() .getAssociatedJoinable( creationContext.getSessionFactory() ); assert joinable instanceof AbstractCollectionPersister; - final AbstractCollectionPersister collectionPersister = (AbstractCollectionPersister) joinable; - querySpec.getFromClause().getRoots().forEach( tableGroup -> consumeTableAliasByTableExpression( - tableGroup, - joinable.getTableName(), - alias -> { - final boolean isManyToMany = collectionPersister.isManyToMany(); - String filterFragment; - if ( isManyToMany ) { - filterFragment = collectionPersister.getManyToManyFilterFragment( - alias, - loadQueryInfluencers.getEnabledFilters() - ); - } - else { - filterFragment = collectionPersister.filterFragment( - alias, - loadQueryInfluencers.getEnabledFilters() - ); - } - final TransformResult transformResult = FilterHelper.transformToPositionalParameters( - filterFragment, loadQueryInfluencers.getEnabledFilters() - ); - filterFragment = transformResult.getTransformedFilterFragment(); - final FilterPredicate filterPredicate = new FilterPredicate( - filterFragment, transformResult.getParameters() - ); - querySpec.applyPredicate( filterPredicate ); - querySpec.addFilterPredicate( filterPredicate ); - } - ) + final String tableExpression = joinable.getTableName(); + final String tableAlias = tableGroup.resolveTableReference( tableExpression ).getIdentificationVariable(); + final Predicate filterPredicate = FilterHelper.createFilterPredicate( + loadQueryInfluencers, + joinable, + tableAlias ); - } - } - - private void consumeTableAliasByTableExpression(TableGroup tableGroup, String tableExpression, Consumer aliasConsumer) { - if ( tableExpression.equals( tableGroup.getPrimaryTableReference().getTableExpression() ) ) { - aliasConsumer.accept( tableGroup.getPrimaryTableReference().getIdentificationVariable() ); - } - else { - for ( TableReferenceJoin referenceJoin : tableGroup.getTableReferenceJoins() ) { - if ( tableExpression.equals( referenceJoin.getJoinedTableReference().getTableExpression() ) ) { - aliasConsumer.accept( referenceJoin.getJoinedTableReference().getIdentificationVariable() ); - } - } - for ( TableGroupJoin tableGroupJoin : tableGroup.getTableGroupJoins() ) { - consumeTableAliasByTableExpression( tableGroupJoin.getJoinedGroup(), tableExpression, aliasConsumer ); + if ( filterPredicate != null ) { + querySpec.applyPredicate( filterPredicate ); } } } @@ -554,7 +511,11 @@ public class LoaderSelectBuilder { fetches.add( fetch ); if ( fetchable instanceof PluralAttributeMapping && fetchTiming == FetchTiming.IMMEDIATE && joined ) { - applyFiltering( querySpec, loadQueryInfluencers, (PluralAttributeMapping) fetchable ); + applyFiltering( + querySpec, + creationState.getFromClauseAccess().getTableGroup( fetchablePath ), + ( (PluralAttributeMapping) fetchable ) + ); applyOrdering( querySpec, fetchablePath, @@ -630,7 +591,7 @@ public class LoaderSelectBuilder { sqlAstCreationState.getFromClauseAccess().registerTableGroup( rootNavigablePath, rootTableGroup ); // NOTE : no need to check - we are explicitly processing a plural-attribute - applyFiltering( rootQuerySpec, loadQueryInfluencers, (PluralAttributeMapping) loadable ); + applyFiltering( rootQuerySpec, rootTableGroup, attributeMapping ); applyOrdering( rootTableGroup, attributeMapping ); // generate and apply the restriction diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java index 159cde0926..2f335ba0e7 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java @@ -103,6 +103,7 @@ public class SingleIdLoadPlan implements SingleEntityLoadPlan { assert jdbcParameters.size() % jdbcTypeCount == 0; final JdbcParameterBindings jdbcParameterBindings = new JdbcParameterBindingsImpl( jdbcTypeCount ); + jdbcSelect.registerFilterJdbcParameterBindings( jdbcParameterBindings ); final Iterator paramItr = jdbcParameters.iterator(); @@ -132,8 +133,6 @@ public class SingleIdLoadPlan implements SingleEntityLoadPlan { ); } - sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings ); - final List list = JdbcSelectExecutorStandardImpl.INSTANCE.list( jdbcSelect, jdbcParameterBindings, diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/ConcreteSqmSelectQueryPlan.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/ConcreteSqmSelectQueryPlan.java index 1fa109eb6b..720696399d 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/ConcreteSqmSelectQueryPlan.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/ConcreteSqmSelectQueryPlan.java @@ -160,6 +160,7 @@ public class ConcreteSqmSelectQueryPlan implements SelectQueryPlan { sqmInterpretation.getTableGroupAccess()::findTableGroup, session ); + sqmInterpretation.getJdbcSelect().registerFilterJdbcParameterBindings( jdbcParameterBindings ); try { return session.getFactory().getJdbcServices().getJdbcSelectExecutor().list( diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmInterpretationsKey.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmInterpretationsKey.java index 8b3813fa07..6ea9d26d07 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmInterpretationsKey.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmInterpretationsKey.java @@ -45,6 +45,11 @@ public class SqmInterpretationsKey implements QueryInterpretationCache.Key { private static boolean isCacheable(QuerySqmImpl query) { assert query.getQueryOptions().getAppliedGraph() != null; + if ( query.getSession().getLoadQueryInfluencers().hasEnabledFilters() ) { + // At the moment we cannot cache query plan if there is filter enabled. + return false; + } + if ( query.getQueryOptions().getAppliedGraph().getSemantic() != null ) { // At the moment we cannot cache query plan if there is an // EntityGraph enabled. diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/StandardSqmSelectTranslator.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/StandardSqmSelectTranslator.java index dab75686c9..c53ccbd903 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/StandardSqmSelectTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/StandardSqmSelectTranslator.java @@ -6,6 +6,7 @@ */ package org.hibernate.query.sqm.sql.internal; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; @@ -19,15 +20,19 @@ import org.hibernate.NotYetImplementedFor6Exception; import org.hibernate.engine.FetchTiming; import org.hibernate.engine.profile.FetchProfile; import org.hibernate.engine.spi.LoadQueryInfluencers; +import org.hibernate.internal.FilterHelper; import org.hibernate.internal.util.collections.CollectionHelper; import org.hibernate.internal.util.collections.Stack; import org.hibernate.internal.util.collections.StandardStack; import org.hibernate.metamodel.mapping.EntityMappingType; import org.hibernate.metamodel.mapping.ModelPart; +import org.hibernate.metamodel.mapping.ModelPartContainer; import org.hibernate.metamodel.mapping.PluralAttributeMapping; import org.hibernate.metamodel.mapping.ordering.OrderByFragment; import org.hibernate.metamodel.model.domain.EntityDomainType; +import org.hibernate.persister.entity.AbstractEntityPersister; import org.hibernate.persister.entity.EntityPersister; +import org.hibernate.persister.entity.Joinable; import org.hibernate.query.DynamicInstantiationNature; import org.hibernate.query.NavigablePath; import org.hibernate.query.spi.QueryOptions; @@ -57,6 +62,7 @@ import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.sql.ast.tree.from.TableGroup; import org.hibernate.sql.ast.tree.from.TableGroupJoin; import org.hibernate.sql.ast.tree.from.TableGroupJoinProducer; +import org.hibernate.sql.ast.tree.predicate.FilterPredicate; import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.results.graph.DomainResult; @@ -89,6 +95,8 @@ public class StandardSqmSelectTranslator private int fetchDepth; + private List collectionFieldFilterPredicates; + public StandardSqmSelectTranslator( QueryOptions queryOptions, DomainParameterXref domainParameterXref, @@ -187,6 +195,24 @@ public class StandardSqmSelectTranslator @Override protected void postProcessQuerySpec(QuerySpec sqlQuerySpec) { + final List roots = sqlQuerySpec.getFromClause().getRoots(); + if ( roots != null && roots.size() == 1 ) { + final TableGroup root = roots.get( 0 ); + final ModelPartContainer modelPartContainer = root.getModelPart(); + final EntityPersister entityPersister = modelPartContainer.findContainingEntityMapping().getEntityPersister(); + assert entityPersister instanceof AbstractEntityPersister; + final String primaryTableAlias = root.getPrimaryTableReference().getIdentificationVariable(); + final FilterPredicate filterPredicate = FilterHelper.createFilterPredicate( + fetchInfluencers, (AbstractEntityPersister) entityPersister, primaryTableAlias + ); + if ( filterPredicate != null ) { + sqlQuerySpec.applyPredicate( filterPredicate ); + } + if ( !CollectionHelper.isEmpty( collectionFieldFilterPredicates ) ) { + collectionFieldFilterPredicates.forEach( sqlQuerySpec::applyPredicate ); + } + } + try { final OrderByFragmentConsumer orderByFragmentConsumer = orderByFragmentConsumerStack.getCurrent(); if ( orderByFragmentConsumer != null ) { @@ -384,10 +410,32 @@ public class StandardSqmSelectTranslator StandardSqmSelectTranslator.this ); - final OrderByFragmentConsumer orderByFragmentConsumer = orderByFragmentConsumerStack.getCurrent(); - if ( orderByFragmentConsumer != null ) { - if ( fetchable instanceof PluralAttributeMapping && fetch.getTiming() == FetchTiming.IMMEDIATE ) { - final PluralAttributeMapping pluralAttributeMapping = (PluralAttributeMapping) fetchable; + if ( fetchable instanceof PluralAttributeMapping && fetch.getTiming() == FetchTiming.IMMEDIATE && joined ) { + final PluralAttributeMapping pluralAttributeMapping = (PluralAttributeMapping) fetchable; + + String tableAlias = alias; + if ( tableAlias == null ) { + tableAlias = getFromClauseAccess().getTableGroup( fetchablePath ).getPrimaryTableReference().getIdentificationVariable(); + } + final Joinable joinable = pluralAttributeMapping + .getCollectionDescriptor() + .getCollectionType() + .getAssociatedJoinable( getCreationContext().getSessionFactory() ); + final FilterPredicate collectionFieldFilterPredicate = FilterHelper.createFilterPredicate( + fetchInfluencers, + joinable, + tableAlias + ); + if ( collectionFieldFilterPredicate != null ) { + if ( collectionFieldFilterPredicates == null ) { + collectionFieldFilterPredicates = new ArrayList<>(); + } + collectionFieldFilterPredicates.add( collectionFieldFilterPredicate ); + } + + final OrderByFragmentConsumer orderByFragmentConsumer = orderByFragmentConsumerStack.getCurrent(); + if ( orderByFragmentConsumer != null ) { + final TableGroup tableGroup = getFromClauseIndex().getTableGroup( fetchablePath ); assert tableGroup.getModelPart() == pluralAttributeMapping; diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java index 82aa5289ed..66df2e2f61 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java @@ -14,6 +14,8 @@ import org.hibernate.SortOrder; import org.hibernate.dialect.Dialect; import org.hibernate.engine.jdbc.spi.JdbcServices; import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.internal.FilterJdbcParameter; +import org.hibernate.internal.util.StringHelper; import org.hibernate.internal.util.collections.Stack; import org.hibernate.internal.util.collections.StandardStack; import org.hibernate.metamodel.mapping.JdbcMapping; @@ -70,11 +72,10 @@ import org.hibernate.sql.ast.tree.predicate.SelfRenderingPredicate; import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.sql.ast.tree.select.SelectClause; import org.hibernate.sql.ast.tree.select.SortSpecification; -import org.hibernate.sql.exec.internal.AbstractJdbcParameter; import org.hibernate.sql.exec.internal.JdbcParametersImpl; import org.hibernate.sql.exec.spi.JdbcParameterBinder; -import org.hibernate.type.descriptor.sql.SqlTypeDescriptorIndicators; import org.hibernate.type.descriptor.sql.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.sql.SqlTypeDescriptorIndicators; import org.hibernate.type.spi.TypeConfiguration; import static org.hibernate.query.TemporalUnit.NANOSECOND; @@ -98,10 +99,12 @@ public abstract class AbstractSqlAstWalker // In-flight state private final StringBuilder sqlBuffer = new StringBuilder(); - private final List parameterBinders = new ArrayList<>(); + private final List parameterBinders = new ArrayList<>(); private final JdbcParametersImpl jdbcParameters = new JdbcParametersImpl(); + protected final List filterJdbcParameters = new ArrayList<>(); + private final Stack clauseStack = new StandardStack<>(); private final Dialect dialect; @@ -1034,12 +1037,12 @@ public abstract class AbstractSqlAstWalker @Override public void visitFilterPredicate(FilterPredicate filterPredicate) { - if ( filterPredicate.getFilterFragment() != null ) { - appendSql( filterPredicate.getFilterFragment() ); - for (JdbcParameter jdbcParameter : filterPredicate.getJdbcParameters()) { - parameterBinders.add( (AbstractJdbcParameter) jdbcParameter ); - jdbcParameters.addParameter( jdbcParameter ); - } + assert StringHelper.isNotEmpty( filterPredicate.getFilterFragment() ); + appendSql( filterPredicate.getFilterFragment() ); + for ( FilterJdbcParameter filterJdbcParameter : filterPredicate.getFilterJdbcParameters() ) { + parameterBinders.add( filterJdbcParameter.getBinder() ); + jdbcParameters.addParameter( filterJdbcParameter.getParameter() ); + filterJdbcParameters.add( filterJdbcParameter ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StandardSqlAstSelectTranslator.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StandardSqlAstSelectTranslator.java index c264e4e139..d834aeedb9 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StandardSqlAstSelectTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StandardSqlAstSelectTranslator.java @@ -74,7 +74,8 @@ public class StandardSqlAstSelectTranslator querySpec.getSelectClause().getSqlSelections(), Collections.emptyList() ), - getAffectedTableNames() + getAffectedTableNames(), + filterJdbcParameters ); } @@ -93,7 +94,8 @@ public class StandardSqlAstSelectTranslator sqlAstSelect.getQuerySpec().getSelectClause().getSqlSelections(), sqlAstSelect.getDomainResultDescriptors() ), - getAffectedTableNames() + getAffectedTableNames(), + filterJdbcParameters ); } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java index 9bf6e5506c..b5d4b700fd 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java @@ -6,13 +6,10 @@ */ package org.hibernate.sql.ast.tree.predicate; -import java.util.ArrayList; import java.util.List; -import org.hibernate.internal.FilterHelper; +import org.hibernate.internal.FilterJdbcParameter; import org.hibernate.sql.ast.SqlAstWalker; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; -import org.hibernate.sql.exec.internal.JdbcParameterImpl; /** * Represents a filter applied to an entity/collection. @@ -20,19 +17,15 @@ import org.hibernate.sql.exec.internal.JdbcParameterImpl; * Note, we do not attempt to parse the filter * * @author Steve Ebersole + * @author Nathan Xu */ public class FilterPredicate implements Predicate { private final String filterFragment; - private final List jdbcParameters; - private final List jdbcParameterTypedValues; + private final List filterJdbcParameters; - public FilterPredicate(String filterFragment, List jdbcParameterTypedValues) { + public FilterPredicate(String filterFragment, List filterJdbcParameters) { this.filterFragment = filterFragment; - jdbcParameters = new ArrayList<>( jdbcParameterTypedValues.size() ); - this.jdbcParameterTypedValues = jdbcParameterTypedValues; - for (int i = 0; i < jdbcParameterTypedValues.size(); i++) { - jdbcParameters.add( new JdbcParameterImpl( null ) ); - } + this.filterJdbcParameters = filterJdbcParameters; } @Override @@ -49,11 +42,7 @@ public class FilterPredicate implements Predicate { return filterFragment; } - public List getJdbcParameters() { - return jdbcParameters; - } - - public List getJdbcParameterTypedValues() { - return jdbcParameterTypedValues; + public List getFilterJdbcParameters() { + return filterJdbcParameters; } } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java index ed4fb7cc83..1760a3ab6e 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java @@ -10,9 +10,6 @@ import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; -import org.hibernate.HibernateException; -import org.hibernate.internal.FilterHelper; -import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.metamodel.mapping.MappingModelExpressable; import org.hibernate.query.sqm.sql.internal.DomainResultProducer; import org.hibernate.sql.ast.SqlAstWalker; @@ -22,13 +19,9 @@ import org.hibernate.sql.ast.spi.SqlSelection; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.cte.CteConsumer; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.from.FromClause; -import org.hibernate.sql.ast.tree.predicate.FilterPredicate; import org.hibernate.sql.ast.tree.predicate.Predicate; import org.hibernate.sql.ast.tree.predicate.PredicateContainer; -import org.hibernate.sql.exec.spi.JdbcParameterBinding; -import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.results.graph.DomainResult; import org.hibernate.sql.results.graph.DomainResultCreationState; import org.hibernate.sql.results.graph.basic.BasicResult; @@ -45,7 +38,6 @@ public class QuerySpec implements SqlAstNode, PredicateContainer, Expression, Ct private final SelectClause selectClause = new SelectClause(); private Predicate whereClauseRestrictions; - private List filterPredicates; private List sortSpecifications; private Expression limitClauseExpression; private Expression offsetClauseExpression; @@ -85,13 +77,6 @@ public class QuerySpec implements SqlAstNode, PredicateContainer, Expression, Ct this.whereClauseRestrictions = SqlAstTreeHelper.combinePredicates( this.whereClauseRestrictions, predicate ); } - public void addFilterPredicate(FilterPredicate filterPredicate) { - if ( filterPredicates == null ) { - filterPredicates = new ArrayList<>(); - } - filterPredicates.add( filterPredicate ); - } - public List getSortSpecifications() { return sortSpecifications; } @@ -170,32 +155,4 @@ public class QuerySpec implements SqlAstNode, PredicateContainer, Expression, Ct descriptor ); } - - public void bindFilterPredicateParameters(JdbcParameterBindings jdbcParameterBindings) { - if ( filterPredicates != null && !filterPredicates.isEmpty() ) { - for ( FilterPredicate filterPredicate : filterPredicates ) { - for ( int i = 0; i < filterPredicate.getJdbcParameters().size(); i++ ) { - final JdbcParameter parameter = filterPredicate.getJdbcParameters().get( i ); - final FilterHelper.TypedValue parameterTypedValue = filterPredicate.getJdbcParameterTypedValues().get( i ); - if ( !(parameterTypedValue.getType() instanceof JdbcMapping ) ) { - throw new HibernateException( String.format( "Filter parameter type [%s] did not implement JdbcMapping", parameterTypedValue.getType() ) ); - } - jdbcParameterBindings.addBinding( - parameter, - new JdbcParameterBinding() { - @Override - public JdbcMapping getBindType() { - return (JdbcMapping) parameterTypedValue.getType(); - } - - @Override - public Object getBindValue() { - return parameterTypedValue.getValue(); - } - } - ); - } - } - } - } } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/exec/spi/JdbcSelect.java b/hibernate-core/src/main/java/org/hibernate/sql/exec/spi/JdbcSelect.java index 2dd1cc2417..c7efe7fcbf 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/exec/spi/JdbcSelect.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/exec/spi/JdbcSelect.java @@ -9,6 +9,8 @@ package org.hibernate.sql.exec.spi; import java.util.List; import java.util.Set; +import org.hibernate.internal.FilterJdbcParameter; +import org.hibernate.internal.util.collections.CollectionHelper; import org.hibernate.sql.results.jdbc.spi.JdbcValuesMappingProducer; /** @@ -21,16 +23,19 @@ public class JdbcSelect implements JdbcOperation { private final List parameterBinders; private final JdbcValuesMappingProducer jdbcValuesMappingProducer; private final Set affectedTableNames; + private final List filterJdbcParameters; public JdbcSelect( String sql, List parameterBinders, JdbcValuesMappingProducer jdbcValuesMappingProducer, - Set affectedTableNames) { + Set affectedTableNames, + List filterJdbcParameters) { this.sql = sql; this.parameterBinders = parameterBinders; this.jdbcValuesMappingProducer = jdbcValuesMappingProducer; this.affectedTableNames = affectedTableNames; + this.filterJdbcParameters = filterJdbcParameters; } @Override @@ -51,4 +56,12 @@ public class JdbcSelect implements JdbcOperation { public JdbcValuesMappingProducer getJdbcValuesMappingProducer() { return jdbcValuesMappingProducer; } + + public void registerFilterJdbcParameterBindings(JdbcParameterBindings jdbcParameterBindings) { + if ( CollectionHelper.isNotEmpty( filterJdbcParameters ) ) { + for ( FilterJdbcParameter filterJdbcParameter : filterJdbcParameters ) { + jdbcParameterBindings.addBinding( filterJdbcParameter.getParameter(), filterJdbcParameter.getBinding() ); + } + } + } } diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterBasicsTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterBasicsTests.java new file mode 100644 index 0000000000..9a860d0272 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterBasicsTests.java @@ -0,0 +1,300 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.criteria.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; +import javax.persistence.criteria.CriteriaBuilder; +import javax.persistence.criteria.CriteriaQuery; +import javax.persistence.criteria.Root; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterBasicsTests.Client.class, + FilterBasicsTests.Account.class + } +) +@SessionFactory +public class FilterBasicsTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + + // ensure query plan cache won't interfere + scope.getSessionFactory().getQueryEngine().getInterpretationCache().close(); + + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnEntity(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + final CriteriaBuilder criteriaBuilder = scope.getSessionFactory().getCriteriaBuilder(); + + final CriteriaQuery criteriaQuery1 = createCriteriaQuery( criteriaBuilder, Account.class, "id", 1L ); + Account account1 = session.createQuery( criteriaQuery1 ).uniqueResult(); + assertThat( account1, notNullValue() ); + + final CriteriaQuery criteriaQuery2 = createCriteriaQuery( criteriaBuilder, Account.class, "id", 2L ); + Account account2 = session.createQuery( criteriaQuery2 ).uniqueResult(); + assertThat( account2, enableFilter ? nullValue() : notNullValue() ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnCollectionField(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + + final CriteriaBuilder criteriaBuilder = scope.getSessionFactory().getCriteriaBuilder(); + final CriteriaQuery criteriaQuery = createCriteriaQuery( criteriaBuilder, Client.class, "id", 1L ); + final Client client = session.createQuery(criteriaQuery).uniqueResult(); + + if ( enableFilter ) { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 3L ) ) ) ); + } + else { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L, 3L ) ) ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + public static class Client { + + @Id + private Long id; + + private String name; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } + + private static CriteriaQuery createCriteriaQuery(CriteriaBuilder criteriaBuilder, Class entityClass, String idFieldName, Object idValue) { + final CriteriaQuery criteria = criteriaBuilder.createQuery( entityClass ); + Root root = criteria.from( entityClass ); + criteria.select( root ); + criteria.where( criteriaBuilder.equal( root.get( idFieldName ), criteriaBuilder.literal( idValue ) ) ); + return criteria; + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterJoinTableTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterJoinTableTests.java new file mode 100644 index 0000000000..005f011e72 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterJoinTableTests.java @@ -0,0 +1,218 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.criteria.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.Id; +import javax.persistence.JoinTable; +import javax.persistence.ManyToMany; +import javax.persistence.OrderColumn; +import javax.persistence.criteria.CriteriaBuilder; +import javax.persistence.criteria.CriteriaQuery; +import javax.persistence.criteria.Root; + +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.FilterJoinTable; +import org.hibernate.annotations.ParamDef; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterJoinTableTests.Client.class, + FilterJoinTableTests.Account.class + } +) +@SessionFactory +public class FilterJoinTableTests { + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + + // ensure query plan cache won't interfere + scope.getSessionFactory().getQueryEngine().getInterpretationCache().close(); + + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + ); + + session.persist( client ); + } ); + } + + @Test + void testLoadFilterOnCollectionField(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.enableFilter( "firstAccounts" ).setParameter( "maxOrderId", 1); + + final CriteriaBuilder criteriaBuilder = scope.getSessionFactory().getCriteriaBuilder(); + final CriteriaQuery criteriaQuery = createCriteriaQuery( criteriaBuilder, Client.class, "id", 1L ); + final Client client = session.createQuery( criteriaQuery ).uniqueResult(); + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L ) ) ) ); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + @FilterDef( + name="firstAccounts", + parameters=@ParamDef( + name="maxOrderId", + type="int" + ) + ) + public static class Client { + + @Id + private Long id; + + private String name; + + @ManyToMany(cascade = CascadeType.ALL) + @JoinTable + @OrderColumn(name = "order_id") + @FilterJoinTable( + name="firstAccounts", + condition="order_id <= :maxOrderId" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + public static class Account { + + @Id + private Long id; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + } + + private static CriteriaQuery createCriteriaQuery(CriteriaBuilder criteriaBuilder, Class entityClass, String idFieldName, Object idValue) { + final CriteriaQuery criteria = criteriaBuilder.createQuery( entityClass ); + Root root = criteria.from( entityClass ); + criteria.select( root ); + criteria.where( criteriaBuilder.equal( root.get( idFieldName ), criteriaBuilder.literal( idValue ) ) ); + return criteria; + } +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterOnJoinFetchedCollectionTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterOnJoinFetchedCollectionTests.java new file mode 100644 index 0000000000..887b79ac04 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterOnJoinFetchedCollectionTests.java @@ -0,0 +1,278 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.criteria.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; +import javax.persistence.criteria.CriteriaBuilder; +import javax.persistence.criteria.CriteriaQuery; +import javax.persistence.criteria.Root; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterOnJoinFetchedCollectionTests.Client.class, + FilterOnJoinFetchedCollectionTests.Account.class + } +) +@SessionFactory +public class FilterOnJoinFetchedCollectionTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + + // ensure query plan cache won't interfere + scope.getSessionFactory().getQueryEngine().getInterpretationCache().close(); + + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testJoinFetchedCollectionField(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + final CriteriaBuilder criteriaBuilder = scope.getSessionFactory().getCriteriaBuilder(); + final CriteriaQuery criteriaQuery = createCriteriaQuery( criteriaBuilder, Client.class, "id", 1L ); + final Client client = session.createQuery( criteriaQuery ).uniqueResult(); + + if ( enableFilter ) { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 3L ) ) ) ); + } + else { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L, 3L ) ) ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + public static class Client { + + @Id + private Long id; + + private String name; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } + + private static CriteriaQuery createCriteriaQuery(CriteriaBuilder criteriaBuilder, Class entityClass, String idFieldName, Object idValue) { + final CriteriaQuery criteria = criteriaBuilder.createQuery( entityClass ); + Root root = criteria.from( entityClass ); + criteria.select( root ); + criteria.where( criteriaBuilder.equal( root.get( idFieldName ), criteriaBuilder.literal( idValue ) ) ); + return criteria; + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterWithSqlFragmentAliasTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterWithSqlFragmentAliasTests.java new file mode 100644 index 0000000000..b3f4dcddec --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/criteria/filter/FilterWithSqlFragmentAliasTests.java @@ -0,0 +1,291 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.criteria.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; +import javax.persistence.Table; +import javax.persistence.criteria.CriteriaBuilder; +import javax.persistence.criteria.CriteriaQuery; +import javax.persistence.criteria.Root; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; +import org.hibernate.annotations.SqlFragmentAlias; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterWithSqlFragmentAliasTests.Client.class, + FilterWithSqlFragmentAliasTests.Account.class + } +) +@SessionFactory +public class FilterWithSqlFragmentAliasTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + // ensure query plan cache won't interfere + scope.getSessionFactory().getQueryEngine().getInterpretationCache().close(); + + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnCollectionField(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + + final CriteriaBuilder criteriaBuilder = scope.getSessionFactory().getCriteriaBuilder(); + final CriteriaQuery criteriaQuery = createCriteriaQuery( criteriaBuilder, Client.class, "id", 1L ); + final Client client = session.createQuery( criteriaQuery ).uniqueResult(); + + if ( enableFilter ) { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 3L ) ) ) ); + } + else { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L, 3L ) ) ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + @Table(name = "client") + public static class Client { + + @Id + private Long id; + + private String name; + + private AccountType type; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL + ) + @Filter( + name="activeAccount", + condition="{a}.active_status = :active", + aliases = { + @SqlFragmentAlias( alias = "a", table= "account") + } + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public AccountType getType() { + return type; + } + + public void setType(AccountType type) { + this.type = type; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + @Table(name = "account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } + + private static CriteriaQuery createCriteriaQuery(CriteriaBuilder criteriaBuilder, Class entityClass, String idFieldName, Object idValue) { + final CriteriaQuery criteria = criteriaBuilder.createQuery( entityClass ); + Root root = criteria.from( entityClass ); + criteria.select( root ); + criteria.where( criteriaBuilder.equal( root.get( idFieldName ), criteriaBuilder.literal( idValue ) ) ); + return criteria; + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterBasicsTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterBasicsTests.java new file mode 100644 index 0000000000..47a4d8bd9b --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterBasicsTests.java @@ -0,0 +1,284 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.hql.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterBasicsTests.Client.class, + FilterBasicsTests.Account.class + } +) +@SessionFactory +public class FilterBasicsTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + // ensure query plan cache won't interfere + scope.getSessionFactory().getQueryEngine().getInterpretationCache().close(); + + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnEntity(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + final String hqlString = "select a from Account a where a.id = :id"; + final Account account1 = session.createQuery( hqlString, Account.class ) + .setParameter( "id", 1L ).uniqueResult(); + final Account account2 = session.createQuery( hqlString, Account.class ) + .setParameter( "id", 2L ).uniqueResult(); + assertThat( account1, notNullValue() ); + assertThat( account2, enableFilter ? nullValue() : notNullValue() ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnCollectionField(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + Client client = session.createQuery( "select c from Client c where c.id = :id", Client.class ) + .setParameter( "id", 1L ).uniqueResult(); + + if ( enableFilter ) { + assertThat( client.getAccounts().stream().map(Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 3L ) ) ) ); + } + else { + assertThat( client.getAccounts().stream().map(Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L, 3L ) ) ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + public static class Client { + + @Id + private Long id; + + private String name; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterJoinTableTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterJoinTableTests.java new file mode 100644 index 0000000000..9e5ea80b0a --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterJoinTableTests.java @@ -0,0 +1,202 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.hql.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.Id; +import javax.persistence.JoinTable; +import javax.persistence.ManyToMany; +import javax.persistence.OrderColumn; + +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.FilterJoinTable; +import org.hibernate.annotations.ParamDef; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterJoinTableTests.Client.class, + FilterJoinTableTests.Account.class + } +) +@SessionFactory +public class FilterJoinTableTests { + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + ); + + session.persist( client ); + } ); + } + + @Test + void testFilterJoinableOnCollectionField(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.enableFilter( "firstAccounts" ).setParameter( "maxOrderId", 1 ); + final Client client = session.createQuery( "select c from Client c where c.id = :id", Client.class ) + .setParameter( "id", 1L ).uniqueResult(); + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L ) ) ) ); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + @FilterDef( + name="firstAccounts", + parameters=@ParamDef( + name="maxOrderId", + type="int" + ) + ) + public static class Client { + + @Id + private Long id; + + private String name; + + @ManyToMany(cascade = CascadeType.ALL) + @JoinTable + @OrderColumn(name = "order_id") + @FilterJoinTable( + name="firstAccounts", + condition="order_id <= :maxOrderId" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + public static class Account { + + @Id + private Long id; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterOnJoinFetchedCollectionTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterOnJoinFetchedCollectionTests.java new file mode 100644 index 0000000000..617008e3a2 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterOnJoinFetchedCollectionTests.java @@ -0,0 +1,262 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.hql.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterOnJoinFetchedCollectionTests.Client.class, + FilterOnJoinFetchedCollectionTests.Account.class + } +) +@SessionFactory +public class FilterOnJoinFetchedCollectionTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnCollectionField(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + final Client client = session.createQuery( "select c from Client c join fetch c.accounts where c.id = :id", Client.class ) + .setParameter( "id", 1L ).uniqueResult(); + + if ( enableFilter ) { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 3L ) ) ) ); + } + else { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L, 3L ) ) ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + public static class Client { + + @Id + private Long id; + + private String name; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterWithSqlFragmentAliasTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterWithSqlFragmentAliasTests.java new file mode 100644 index 0000000000..749dab04c7 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hql/filter/FilterWithSqlFragmentAliasTests.java @@ -0,0 +1,279 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.query.hql.filter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; +import javax.persistence.Table; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; +import org.hibernate.annotations.SqlFragmentAlias; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterWithSqlFragmentAliasTests.Client.class, + FilterWithSqlFragmentAliasTests.Account.class + } +) +@SessionFactory +public class FilterWithSqlFragmentAliasTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + // ensure query plan cache won't interfere + scope.getSessionFactory().getQueryEngine().getInterpretationCache().close(); + + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ) + .setType( AccountType.DEBIT ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testSqlFragmentAlias(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ).setParameter( "active", true ); + } + Client client = session.createQuery( "select c from Client c where c.id = :id", Client.class ) + .setParameter( "id", 1L ).uniqueResult(); + + if ( enableFilter ) { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 3L ) ) ) ); + } + else { + assertThat( client.getAccounts().stream().map( Account::getId ).collect( Collectors.toSet() ), + equalTo( new HashSet<>( Arrays.asList( 1L, 2L, 3L ) ) ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + @Table(name = "client") + public static class Client { + + @Id + private Long id; + + private String name; + + private AccountType type; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL + ) + @Filter( + name="activeAccount", + condition="{a}.active_status = :active", + aliases = { + @SqlFragmentAlias( alias = "a", table= "account") + } + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public AccountType getType() { + return type; + } + + public Client setType(AccountType type) { + this.type = type; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + @Table(name = "account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } +}